⌈ 传知代码 ⌋ 基于矩阵乘积态的生成模型

news2025/1/16 13:56:28

💛前情提要💛

本文是传知代码平台中的相关前沿知识与技术的分享~

接下来我们即将进入一个全新的空间,对技术有一个全新的视角~

本文所涉及所有资源均在传知代码平台可获取

以下的内容一定会让你对AI 赋能时代有一个颠覆性的认识哦!!!

以下内容干货满满,跟上步伐吧~


📌导航小助手📌

  • 💡本章重点
  • 🍞一. 概述
  • 🍞二. 方法
  • 🍞三.实现
  • 🍞四.训练结果
  • 🫓总结


💡本章重点

  • 基于矩阵乘积态的生成模型

🍞一. 概述

生成模型,通过从数据中学习联合概率分布并据此生成样本,是机器学习和人工智能中的一个重要任务。受量子物理学中概率解释的启发,该文章提出了一种使用矩阵积状态的生成模型,这是一种最初用于描述(特别是一维)纠缠量子态的张量网络。其模型享有类似于密度矩阵重正化群方法的高效学习能力,该方法允许动态调整张量的维度,并提供了一种高效的直接采样方法用于生成任务。本文试图复现该文章的工作,利用该文章的思想,方法去实现MNIST手写数字的生成任务。

  • Han Z-Y, Wang J, Fan H, et al. Unsupervised Generative Modeling Using Matrix Product States[J]. Physical Review X, 2018, 8(3): 031012

在这里插入图片描述


🍞二. 方法

量子力学的概率解释自然地建议使用量子态来建模数据分布。假设我们将概率分布编码到一个量子波函数:

在这里插入图片描述
又在一定程度上能够表示更多不同种类的构型成为现在需要解决的问题。许多已经开发的表示方法和算法可以用于高效的概率建模。在这里,我们使用矩阵积状态(MPS)对波函数进行参数化:

在这里插入图片描述
上面的图示意思为,左边是我们需要表示的波函数,线代表它依赖的指标(或者变量),右边则是对应的MPS表示,两个方括号直接的连线代表求和,即将对应的指标(或者变量求和,类似于矩阵的乘积)进行收缩。我们可以看出我们把一个复杂的波函数变成了有限个3指标张量的收缩。


🍞三.实现

导入训练集(MNIST)

1000 张 MNIST 图像已存储为 mnist784_bin_1000.npy。

每张图像包含:n = 28 * 28 个像素,每个像素的取值为0或1。每张图像被视为维度为 2^n 的希尔伯特空间中的一个乘积态。

n = 784 
m = 1000
data = np.load("mnist784_bin_1000.npy").astype(np.int32)
data = data[:m,:]
data = torch.LongTensor(data)\
plt.figure(figsize=(10,2))
imgs = data.cpu().reshape([-1,28,28])
_, ax = plt.subplots(2,10)
for i in range(2): 
    for j in range(10):
        index = i * 2 + j
        if(a >= imgs.shape[0]):
            break
        ax[i][j].imshow(imgs[index,:,:],cmap='bone')
        ax[i][j].set_xticks([])
        ax[i][j].set_yticks([])
plt.show()

这可以让我们观察以下MNIST数据集的样子

在这里插入图片描述
定义MPS

现在我们要构造一个初始的MPS, 根据上面的阐述,我们的MPS是由一系列3指标的张量的所构成的,如下所示:

在这里插入图片描述

chi = 30 
mydevice = 'cuda' if torch.cuda.is_available() else torch.device("cpu")
print(mydevice)
data = data.to(mydevice)
bond_dims = [chi for i in range(n-1)]+[1]
tensors= [ torch.randn(bond_dims[i-1],2,bond_dims[i],device=mydevice) for i in range(n)]
  • 我们可以输出从而看到这些张量的输出维度

在这里插入图片描述
概率计算

概率计算可以遵循前面的Born公式,即:

在这里插入图片描述
在这里,带有一个小边(常称之为脚)是一个向量,代表的是对应像素的状态,是一个二维向量,用来表示对应的像素是黑还是白

现在难以计算的是配分函数,即:

在这里插入图片描述
这个东西,这涉及到张量网络的缩并,在张量网络这个领域中由非常多的缩并方式,一个常用的方法是正交化,即把MPS右边的那些三阶张量全部正交化使得他们收缩刚好是一个单位张量。这个过程如下:

在这里插入图片描述

通过不断的对左边的张量作用QR分解从而使得左边张量全部正交化(黄色的)。据此我们可以计算出对应的波函数:

def getPsi():
    psi = torch.ones([m, 1, 1], device=mydevice)
    for site in range(n):
        selected_tensor = tensors[site][:, data[:, site], :].permute(1, 0, 2)
        psi = torch.matmul(psi, selected_tensor)
    return psi

生成图片

生成图片的过程可以采用条件概率的方法,即先采样一个边缘概率,再从这个边缘概率对应的变量继续采样,重复这个过程即可:

在这里插入图片描述

核心代码为:

def generateSamples(batch):
    n = 784
    samples = torch.zeros([batch, n],device=mydevice)
    for site in range(n - 1):
        orthogonalize(site, True) 
    for s in range(batch):
        vec = torch.ones(1,1,device=mydevice)
        for site in range(n-1, -1, -1):
            vec = (tensors[site].view(-1, bond_dims[site]) @ vec).view(-1, 2)
            p0 = vec[:, 0].norm()**2 / (vec.norm()**2)
            x = (0 if np.random.rand() < p0 else 1)
            vec = vec[:, x]
            samples[s][site] = x
    return samples

🍞四.训练结果

在这里插入图片描述


🫓总结

综上,我们基本了解了“一项全新的技术啦” 🍭 ~~

恭喜你的内功又双叒叕得到了提高!!!

感谢你们的阅读😆

后续还会继续更新💓,欢迎持续关注📌哟~

💫如果有错误❌,欢迎指正呀💫

✨如果觉得收获满满,可以点点赞👍支持一下哟~✨

【传知科技 – 了解更多新知识】

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1968067.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SQL语句分类;查看MySQL存储引擎

文章目录 SQL语句分类查看MySQL存储引擎存储引擎对比 SQL语句分类 数据本身权限&#xff1a;定义 修改 DDL【data definition language】 数据定义语言&#xff0c;用来维护存储数据的结构 代表指令: create, drop, alterDML【data manipulation language】 数据操纵语言&…

第十九天内容

上午 1、构建vue发行版本 2、java环境配置 jdk软件包路径&#xff1a; https://download.oracle.com/java/22/latest/jdk-22_linux-x64_bin.tar.gz 下午 1、安装tomcat软件 tomcat软件包路径&#xff1a; https://dlcdn.apache.org/tomcat/tomcat-10/v10.1.26/bin/apache-to…

水库大坝安全自动监测系统位移测点布设

水库大坝安全自动监测系统中的位移测点布设是大坝安全监测的重要环节&#xff0c;其目的是为了及时、准确地获取大坝的位移信息&#xff0c;评估大坝的稳定性&#xff0c;确保大坝的安全运行。位移测点的布设需要综合考虑大坝的结构特点、地质条件、运行工况及监测需求等多方面…

tomcat多实例配置-Linux(CentOS)

多实例配置 一、安装 tomcat二、多实例配置 tomcat 官网 tomcat 安装包下载地址 一、安装 tomcat 解压tomcat压缩包到 /usr/local 下 tar xf apache-tomcat-*.gz -C /usr/local/# 可选 添加一个软链接&#xff0c;方便查找 ln -s /usr/local/apache-tomcat-* /usr/local/tom…

python实现发票信息识别和处理

公司需要发票报销&#xff0c;一定周期的发票攒在一起&#xff0c;处理报销单特别繁琐&#xff0c;遂萌生用python简化报销流程。 明确需求 公司报销单需要发票代码(短码)&#xff0c;金额&#xff0c;总计金额&#xff0c;如下图 开始编码 首先需要一个读取pdf的类库 pdf…

AEAD:AES-CCM简介

目录 1. CCM模式 2.认证加密过程 3.校验解密过程 1. CCM模式 CCM&#xff08;Counter with CBC-MAC&#xff09; 首先使用 CBC-MAC 来保证数据完整性和真实性&#xff0c;然后使用 CTR 模式来保证数据机密性。 在CCM中&#xff0c;受保护的数据被称为payload&#xff0c;简…

虚拟机(CentOS7)安装gitlab

GitLab官方安装教程 链接&#xff1a;https://gitlab.cn/install/ 1、关闭虚拟机防火墙 # 关闭防火墙命令 systemctl stop firewalld # 查看当前防火墙的状态信息 systemctl status firewalld成功关闭 2、GitLab安装包下载 # windows下载地址&#xff1a; https://mirrors.t…

JVM—对象已死?

在堆里面存放着 Java 世界中几乎所有的对象实例,垃圾收集器在对堆进行回收前,第一件事情就是要确定这些对象之中哪些还“存活”着,哪些已经“死去”。 1、如何判断对象存活 1.1 引用计数法 给对象增加一个引用计数器&#xff0c;当对象被引用一次计数器加一、当引用失效时计数…

深入源码P3C-PMD:使用流程(1)

PMD开源组件启动流程介绍 在软件开发领域&#xff0c;代码质量是项目成功的关键因素之一。为了提升代码质量&#xff0c;开发者们常常借助各种工具进行代码分析和检查。PMD作为一款开源的静态代码分析工具&#xff0c;在Java、JavaScript、PLSQL等语言项目中得到了广泛应用。本…

虚拟主机与vue项目、samba磁盘映射、nfs共享

1、复习 &#xff08;1&#xff09;tomcat服务器需要jdk环境 版本对应 tomcat9》jdk1.8 tomcat10》jdk17 配置系统变量JAVA_HOME sed -i $a export JAVA_HOME/usr/local/jdk22/ /etc/profile sed -i $a export PATHJAVA_HOME/bin:$PATH /etc/profile source /etc/profile…

基于FPGA的出租车计费系统设计---第一版--郝旭帅电子设计团队

欢迎各位朋友关注“郝旭帅电子设计团队”&#xff0c;本篇为各位朋友介绍基于FPGA的出租车计费系统设计—第一版 功能说明&#xff1a; 收费标准&#xff08;里程&#xff09;&#xff1a;起步价5元&#xff0c;包括三公里&#xff1b;三公里之后&#xff0c;每公里2元&#x…

JVM: 堆上的数据存储

文章目录 一、对象在堆中的内存布局1、对象在堆中的内存布局 - 标记字段2、JOL打印内存布局 二、元数据指针 一、对象在堆中的内存布局 对象在堆中的内存布局&#xff0c;指的是对象在堆中存放时的各个组成部分&#xff0c;主要分为以下几个部分&#xff1a; 1、对象在堆中的…

Java SpringTask定时自动化处理

目录 一、自动化处理 1.1 什么是自动化处理 1.2 SpringTask介绍 二、SpringTask的基本使用 2.1 引入依赖 2.2 通过控制台加入注解启用SpringTask 2.3 使用Cron表达式规定时间 2.4 通过Schedule(Cron表达式) 实现定时任务&#xff08;每两秒执行一次&#xff09; 三、实…

【完美解决】 TypeError: ‘str’ object does not support item assignment

【完美解决】 TypeError: ‘str’ object does not support item assignment 在Python编程中&#xff0c;遇到TypeError: str object does not support item assignment这样的错误通常意味着你试图修改字符串中的某个字符&#xff0c;但字符串是不可变类型&#xff0c;不支持这…

【每日一题 | 组成原理】补码溢出判断

题目 题型总结 带符号的定点数表示方式有4种&#xff0c;分别是原码、反码、补码和移码&#xff0c;他们都由两部分组成&#xff0c;分别是符号位和数值位&#xff0c;这四种编码方式非常重要&#xff0c;要熟练掌握他们之间的转换和与真值间的转换。这里我们重点看一下补码&a…

408-部分知识点笔记(自用)

一、操作系统部分 1.内中断&#xff08;异常&#xff09;和外中断&#xff08;中断&#xff09; 1.1 异常&#xff08;内中断&#xff09; 异常就是指CPU内部发生的中断&#xff0c;与当前正在执行的程序有关。类似的内中断有&#xff1a;缺页中断、算法溢出、除以0错误、存…

可视化目标检测算法推理部署(三)YOLOv8模型视频推理

在上一章节中博主利用Gradio完成了YOLOv8模型的图像推理&#xff0c;那么在本章节中将进行视频推理&#xff0c;其代码十分简单&#xff0c;只需要将原本的视频切分为一帧帧图像再去检测即可&#xff0c;代码如下&#xff1a; def detectio_video(input_path):output_path&quo…

[C++]多态与虚函数

一、多态的概念 顾名思义&#xff0c;多态的意思就是一个事物有多种形态&#xff0c;在完成某个行为的时候&#xff0c;当不同的对象去完成时会产生不同的状态。在面向对象方法中一般是这样表示多态的&#xff1a;向不同的对象发送同一条消息&#xff0c;不同的对象在接收时会产…

记录|Stock编程

目录 前言一、Stock编程&#xff1f;二、聊天工具开发1. 目的2. 服务器端开启对端口的监听3. VS创建服务器端ServiceStep1. 创建Step2. Listener对象监听事件Step1~2效果展示 4. 创建客户端&#xff0c;与服务器端链接5. VS创建客户端ClientStep1. 创建Step2. Client对象Step1~…