最优传输学习及问题总结

news2024/10/7 18:25:13

文章目录

  • 参考内容
  • lam=0.1
  • lam=3
  • lam=10
  • lam=50
  • lam=100
  • lam=300
  • 画图
  • 线性规划
    • matlab
    • python代码

参考内容

https://blog.csdn.net/qq_41129489/article/details/128830589
https://zhuanlan.zhihu.com/p/542379144

我主要想强调的是这个例子的解法存在的一些细节问题

lam=0.1

lam = 0.1

P, d = compute_optimal_transport(M,
        r,
        c, lam=lam)

partition = pd.DataFrame(P, index=np.arange(1, 9), columns=np.arange(1, 6))
ax = partition.plot(kind='bar', stacked=True)
print('Sinkhorn distance: {}'.format(d))
ax.set_ylabel('portions')
ax.set_title('Optimal distribution ($\lambda={}$)'.format(lam))

print(d)

PP = np.around(P,3) 
print(PP)

print("*"*100)
print(np.sum(PP,axis=0))
print(np.sum(PP,axis=1))
print("*"*100)
## 这个很接近了

结果如下
在这里插入图片描述

lam=3

lam = 3

P, d = compute_optimal_transport(M,
        r,
        c, lam=lam)

partition = pd.DataFrame(P, index=np.arange(1, 9), columns=np.arange(1, 6))
ax = partition.plot(kind='bar', stacked=True)
print('Sinkhorn distance: {}'.format(d))
ax.set_ylabel('portions')
ax.set_title('Optimal distribution ($\lambda={}$)'.format(lam))
print(d)

PP = np.around(P,3) 
print(PP)

print("*"*100)
print(np.sum(PP,axis=0))
print(np.sum(PP,axis=1))
print("*"*100)
## 这个很接近了

在这里插入图片描述

lam=10

lam = 10

P, d = compute_optimal_transport(M,
        r,
        c, lam=lam)

partition = pd.DataFrame(P, index=np.arange(1, 9), columns=np.arange(1, 6))
ax = partition.plot(kind='bar', stacked=True)
print('Sinkhorn distance: {}'.format(d))
ax.set_ylabel('portions')
ax.set_title('Optimal distribution ($\lambda={}$)'.format(lam))

print(d)

PP = np.around(P,3) 
print(PP)

print("*"*100)
print(np.sum(PP,axis=0))
print(np.sum(PP,axis=1))
print("*"*100)
## 这个很接近了

在这里插入图片描述

lam=50

lam = 50

P, d = compute_optimal_transport(M,
        r,
        c, lam=lam)

partition = pd.DataFrame(P, index=np.arange(1, 9), columns=np.arange(1, 6))
ax = partition.plot(kind='bar', stacked=True)
print('Sinkhorn distance: {}'.format(d))
ax.set_ylabel('portions')
ax.set_title('Optimal distribution ($\lambda={}$)'.format(lam))

PP = np.around(P,3) 
print(PP)

print("*"*100)
print(np.sum(PP,axis=0))
print(np.sum(PP,axis=1))
print("*"*100)
## 这个很接近了

在这里插入图片描述

lam=100

lam = 100

P, d = compute_optimal_transport(M,
        r,
        c, lam=lam)

partition = pd.DataFrame(P, index=np.arange(1, 9), columns=np.arange(1, 6))
ax = partition.plot(kind='bar', stacked=True)
print('Sinkhorn distance: {}'.format(d))
ax.set_ylabel('portions')
ax.set_title('Optimal distribution ($\lambda={}$)'.format(lam))

print(d)
PP = np.around(P,3) 
print(PP)

print("*"*100)
print(np.sum(PP,axis=0))
print(np.sum(PP,axis=1))
print("*"*100)
## 这个很接近了

在这里插入图片描述

lam=300

lam = 300

P, d = compute_optimal_transport(M,
        r,
        c, lam=lam)

partition = pd.DataFrame(P, index=np.arange(1, 9), columns=np.arange(1, 6))
ax = partition.plot(kind='bar', stacked=True)
print('Sinkhorn distance: {}'.format(d))
ax.set_ylabel('portions')
ax.set_title('Optimal distribution ($\lambda={}$)'.format(lam))

print(d)
PP = np.around(P,3) 
print(PP)

print("*"*100)
print(np.sum(PP,axis=0))
print(np.sum(PP,axis=1))
print("*"*100)
## 这个就不接近了,之前的求和都是相差在0.001左右,可以近似看作相等
## 但是这个行和是 [2.    1.714 3.75  2.286 2.5   2.5   4.    1.25 ]
## 很明显是 [3. 3. 3. 4. 2. 2. 2. 1.]这个是不对的,所以lam=300时这个值已经发散了,
## 虽然此时的Sinkhorn distance是小于24的,但也不起作用

在这里插入图片描述

画图

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def compute_optimal_transport(M=None, r=None, c=None, lam=None, eplison=1e-8):
    """
    Computes the optimal transport matrix and Slinkhorn distance using the
    Sinkhorn-Knopp algorithm

    Inputs:
        - M : cost matrix (n x m)
        - r : vector of marginals (n, )
        - c : vector of marginals (m, )
        - lam : strength of the entropic regularization
        - epsilon : convergence parameter

    Outputs:
        - P : optimal transport matrix (n x m)
        - dist : Sinkhorn distance
    """
    r = np.array([3, 3, 3, 4, 2, 2, 2, 1])
    c = np.array([4, 2, 6, 4, 4])
    M = np.array(
        [[2, 2, 1, 0, 0], 
        [0, -2, -2, -2, -2], 
        [1, 2, 2, 2, -1], 
        [2, 1, 0, 1, -1],
        [0.5, 2, 2, 1, 0], 
        [0, 1, 1, 1, -1], 
        [-2, 2, 2, 1, 1], 
        [2, 1, 2, 1, -1]],
        dtype=float) 
    M = -M # 将M变号,从偏好转为代价
    
    
    n, m = M.shape  # 8, 5
    P = np.exp(-lam * M) # (8, 5)
    P /= P.sum()  # 归一化
    u = np.zeros(n) # (8, )
    # normalize this matrix
    while np.max(np.abs(u - P.sum(1))) > eplison: # 这里是用行和判断收敛
        # 对行和列进行缩放,使用到了numpy的广播机制,不了解广播机制的同学可以去百度一下
        u = P.sum(1) # 行和 (8, )
        P *= (r / u).reshape((-1, 1)) # 缩放行元素,使行和逼近r
        v = P.sum(0) # 列和 (5, )
        P *= (c / v).reshape((1, -1)) # 缩放列元素,使列和逼近c
    return P, np.sum(P * M) # 返回分配矩阵和Sinkhorn距离

lam_list=[1,5,10,20,30,40,50,60,70,80,90,100,110,120,130,140,150]

cost_list=[]
for lam in lam_list:
    P, d = compute_optimal_transport(lam=lam)
    cost_list.append(d)
print(cost_list)
plt.plot(np.array(lam_list),np.array(cost_list),c="g")
plt.show()

## 现在这个地方也有的

在这里插入图片描述
这个地方其实有一个画图的小问题,我待会要再写一下

可以看到大概是在lam =150的时候,就已经不稳定了,所以这个例子的问题的解的最小花费约等于24,但是我发现一个更有意思的问题,就是这个分配矩阵是唯一的吗,很显然不是的, 利用我上篇文章学到的线性规划,我发现matlab和python找到的是两个不同的解,

线性规划

matlab

clc;
clear;

r = [3, 3, 3, 4, 2, 2, 2, 1];
c = [4, 2, 6, 4, 4];
cost_matrix =  [2, 2, 1, 0, 0;
             0, -2, -2, -2, -2; 
               1, 2, 2, 2, -1;
               2, 1, 0, 1, -1;
              0.5, 2, 2, 1, 0;
               0, 1, 1, 1, -1;
               -2, 2, 2, 1, 1;
              2, 1, 2, 1, -1];

cost_matrix_t = (-1)*transpose(cost_matrix);% 需要有符号
cost_vec = cost_matrix_t(:);

raw_equ = zeros(8,40);
for i =1:8
    raw_equ(i,((i-1)*5+1):((i-1)*5+5))=1;
end

col_equ = zeros(5,40);
for i =1:5
    for j =1:8
        col_equ(i,i+(j-1)*5)=1;
    end
end

equ = [raw_equ;col_equ];
equ_value = horzcat(r, c);
% x1,x2,x3,x4,x5
% x6,x7,x8,x9,x10
% x11,x12,x13,x14,x15
% x16,x17,x18,x19,x20
% x21,x22,x23,x24,x25
% x26,x27,x28,x29,x30
% x31,x32,x33,x34,x35
% x36,x37,x38,x39,x40

% 现在我要求的变量是这样的,
f=cost_vec;			% 价值向量
a=[];	% a、b对应不等式的左边和右边
b=[];
aeq=equ;	% aeq和beq对应等式的左边和右边
beq=equ_value;
[x,y]=linprog(f,a,b,aeq,beq,zeros(40,1));

arr_mat = transpose(reshape(x',5,8));

结果如下
在这里插入图片描述
分配矩阵如下在这里插入图片描述

python代码

# Define parameters
m = 8
n = 5

p = np.array([3, 3, 3, 4, 2, 2, 2, 1])
q = np.array([4, 2, 6, 4, 4])

C = -1*np.array(
    [[2, 2, 1, 0, 0], 
    [0, -2, -2, -2, -2], 
    [1, 2, 2, 2, -1], 
    [2, 1, 0, 1, -1],
    [0.5, 2, 2, 1, 0], 
    [0, 1, 1, 1, -1], 
    [-2, 2, 2, 1, 1], 
    [2, 1, 2, 1, -1]],
    dtype=float)

# Vectorize matrix C
C_vec = C.reshape((m*n, 1), order='F')

# Construct matrix A by Kronecker product
A1 = np.kron(np.ones((1, n)), np.identity(m))
A2 = np.kron(np.identity(n), np.ones((1, m)))
A = np.vstack([A1, A2])

# Construct vector b
b = np.hstack([p, q])

# Solve the primal problem
res = linprog(C_vec, A_eq=A, b_eq=b)

# Print results
print("message:", res.message)
print("nit:", res.nit)
print("fun:", res.fun)
print("z:", res.x)
print("X:", res.x.reshape((m,n), order='F'))

结果如下
在这里插入图片描述
可以看到花费都是24,但是两者的分配矩阵并不一样哈

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1400031.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

《WebKit 技术内幕》之五(3): HTML解释器和DOM 模型

3 DOM的事件机制 基于 WebKit 的浏览器事件处理过程:首先检测事件发生处的元素有无监听者,如果网页的相关节点注册了事件的监听者则浏览器会将事件派发给 WebKit 内核来处理。另外浏览器可能也需要处理这样的事件(浏览器对于有些事件必须响应…

D - Left Right Operation

思路: 1、求前缀和 2、从后往前遍历,把某个后缀都变为R,记录最多让数组和减小多少 3、从前往后遍历,把某个前缀都变为L,记录最小答案(前i个变为L,后面的n-i个数让减小最多的后缀变为R&#x…

项目管理十大知识领域之项目采购管理

一、项目采购管理的定义与概述 项目采购管理是指在项目实施过程中,对相关产品、服务或工程进行采购的管理活动。其概述包括确定采购需求、制定采购计划、供应商选择、合同签订、供应管理和结算支付等环节。项目采购管理的定义还涉及对采购目标的明确界定&#xff0…

[小程序]使用代码渲染页面

一、条件渲染 1.单个控制 使用wx:if"{{条件}}"来判断是否需要渲染这段代码&#xff0c;同时可以结合wx:elif和wx:else来判断 <view wx:if"{{type0}}">0</view> <view wx:elif"{{type1}}">1</view> <view wx:else>…

数字IC后端设计实现 | PR工具中到底应该如何控制density和congestion?(ICC2Innovus)

吾爱IC社区星友提问&#xff1a;请教星主和各位大佬&#xff0c;对于一个模块如果不加干预工具会让inst挤成一团&#xff0c;后面eco修时序就没有空间了。如果全都加instPadding会导致面积不够overlap&#xff0c;大家一般怎么处理这种问题&#xff1f; 在数字IC后端设计实现中…

C#中ArrayList运行机制及其涉及的装箱拆箱

C#中ArrayList运行机制及其涉及的装箱拆箱 1.1 基本用法1.1.1 属性1.1.2 方法 1.2 内部实现1.3 装箱1.4 拆箱1.5 object对象的相等性比较1.6 总结1.7 其他简单结构类 1.1 基本用法 命名空间&#xff1a; using System.Collections; 1.1.1 属性 Capacity&#xff1a;获取或设…

Barrel Shifter RTL Combinational Circuit

在本博客中&#xff0c;将围绕许多设计中存在的非常有用的电路&#xff08;桶形移位器电路&#xff09;设计电路。将从最简单的方法开始实现固定位宽字的单向旋转桶形移位器&#xff0c;最后设计一个具有可参数化字宽的多功能双向桶形移位器电路。 Barrel Shifter 桶形移位器…

【Python学习】Python学习22- CGI编程

目录 【Python学习】Python学习22- CGI编程 前言CGI工作流程Web 服务器支持及配置Http头部参考 文章所属专区 Python学习 前言 本章节主要说明Python的CGI接口 CGI 目前由 NCSA 维护&#xff0c;NCSA 定义 CGI 如下&#xff1a; CGI(Common Gateway Interface)&#xff0c;通…

机器学习没那么难,Azure AutoML帮你简单3步实现自动化模型训练

在Machine Learning 这个领域&#xff0c;通常训练一个业务模型的难点并不在于算法的选择&#xff0c;而在于前期的数据清理和特征工程这些纷繁复杂的工作&#xff0c;训练过程中的问题在于参数的反复迭代优化。 AutoML 是 Azure Databricks 的一项功能&#xff0c;它自动的对…

GRU门控循环单元神经网络的MATLAB实现(含源代码)

在深度学习领域&#xff0c;循环神经网络&#xff08;RNN&#xff09;因其在处理序列数据方面的卓越能力而受到广泛关注。GRU&#xff08;门控循环单元&#xff09;作为RNN的一种变体&#xff0c;以其在捕捉时间序列长距离依赖关系方面的高效性而备受推崇。在本文中&#xff0c…

ros2学习笔记-CLI工具,记录命令对应操作。

目录 环境变量turtlesim和rqt以初始状态打开rqt node启动节点查看节点列表查看节点更多信息命令行参数 --ros-args topic话题列表话题类型话题列表&#xff0c;附加话题类型根据类型查找话题名查看话题发布的数据查看话题的详细信息查看类型的详细信息给话题发布消息&#xff0…

推荐两个工具:DeepSpeed-FastGen和DataTrove

DeepSpeed-FastGen 通过 MII 和 DeepSpeed-Inference 加速LLM生成文本 仓库地址&#xff1a;https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen GPT-4 和 LLaMA 等大型语言模型 (LLM) 已成为服务于各个级别的人工智能应用程序的主要工作负载。从一…

UE5 独立程序的网络TCP/UDP服务器与客户端基础流程

引擎源码版&#xff0c;复制\Engine\Source\Programs\路径下的BlankProgram空项目示例。 重命名BlankProgram&#xff0c;例如CustomTcpProgram&#xff0c;并修改项目名称。 修改.Build.cs内容 修改Target.cs内容 修改Private文件夹内.h.cpp文件名并修改.cpp内容 刷新引擎 …

SpringMVC获取参数与页面跳转

获取参数 第一种 直接当成方法的参数&#xff0c;需要与前台的name一致 相当于Request.getAttribute("username") Controller 第二种 使用对象接收 页面的name也要和对象的字段一致 创建一个对应的实体类 Controller 将参数更换为User对象就行 SpringMVC获取到…

【设计模式】你知道游戏SL大法是什么设计模式吗?

什么是备忘录模式&#xff1f; 老规矩&#xff0c;我们先来看看备忘录模式 (Memento) 的定义&#xff1a;在不破坏封装性的前提下&#xff0c;捕获一个对象的内部状态&#xff0c;并在该对象之外保存这个状态。这样以后就可将该对象恢复到原先保存的状态。 它的UML类图如下&a…

keep-alive组件缓存

keep-alive组件缓存 从a跳b&#xff0c;a已经销毁&#xff0c;b重新渲染&#xff1b;b跳a&#xff0c;b销毁a重新渲染 源组件销毁&#xff0c;目标组件渲染 组件缓存&#xff1a;组件实例等相关&#xff08; 包括vnode&#xff09;存储起来 重新渲染指的是&#xff1a;把视图重…

MySQL---多表查询综合练习

创建dept表 CREATE TABLE dept ( deptno INT(2) NOT NULL COMMENT 部门编号, dname VARCHAR (15) COMMENT 部门名称, loc VARCHAR (20) COMMENT 地理位置 ); 添加dept表主键 mysql> alter table dept add primary key(deptno); Query OK, 0 rows affected (0.02 s…

正则表达式初版

一、简介 REGEXP&#xff1a; Regular Expressions&#xff0c;由一类特殊字符及文本字符所编写的模式&#xff0c;其中有些字符&#xff08;元字符&#xff09;不表示字符字面意义&#xff0c;而表示控制或通配的功能&#xff0c;类似于增强版的通配符功能&#xff0c;但与通…

YOLOv5改进 | 主干篇 | 华为GhostnetV1一种移动端的专用特征提取网络

一、本文介绍 本文给大家带来的改进机制是华为移动端模型Ghostnetv1,华为GhostnetV1一种移动端的专用特征提取网络,旨在在计算资源有限的嵌入式设备上实现高性能的图像分类。GhostNet的关键思想在于通过引入Ghost模块,以较低的计算成本增加了特征图的数量,从而提高了模型的…

消除噪音:Chain-of-Note (CoN) 强大的方法为您的 RAG 管道提供强大动力

论文地址&#xff1a;https://arxiv.org/abs/2311.09210 英文原文地址&#xff1a;https://praveengovindaraj.com/cutting-through-the-noise-chain-of-notes-con-robust-approach-to-super-power-your-rag-pipelines-0df5f1ce7952 在快速发展的人工智能和机器学习领域&#x…