人工智能(pytorch)搭建模型13-pytorch搭建RBM(受限玻尔兹曼机)模型,调通模型的训练与测试

news2024/12/23 22:08:39

大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型13-pytorch搭建RBM(受限玻尔兹曼机)模型,调通模型的训练与测试。RBM(受限玻尔兹曼机)可以在没有人工标注的情况下对数据进行学习。其原理类似于我们人类学习的过程,即通过观察、感知和记忆不同事物的特点,从而形成对这些事物的认知模型。本文将介绍RBM(受限玻尔兹曼机)模型的原理,并使用PyTorch框架实现一个简单的RBM模型。我们将介绍如何构建模型,加载样例数据进行训练,以及在训练完成后进行测试。

文章目录结构:

  1. RBM模型简介
  2. RBM模型原理
  3. 使用PyTorch搭建RBM模型
  4. 数据样例及加载
  5. 模型训练
  6. 模型测试
  7. 总结

1. RBM模型简介

受限玻尔兹曼机(RBM)是一种生成式随机神经网络,广泛应用于图像识别、语音识别、推荐系统等领域。RBM能够学习到数据的潜在表示,是深度学习的重要组成部分。

RBM 由一些可见变量和一些隐藏变量组成。它的基本思想是用一个二分图表示这些变量之间的关系。可见变量与隐藏变量之间没有边相连,而可见变量与其他可见变量、隐藏变量与其他隐藏变量之间都存在边相连。这种二分图结构使得 RBM 可以很好地对输入数据进行建模。

在训练阶段,RBM 的目标是学习一个能量模型,使得训练数据的概率最大化。为了实现这个目标,通常使用下降梯度的方法来最小化负对数似然函数(Negative Log-Likelihood,NLL),从而得到隐含层向量和可见层向量之间的权重和偏置值。当模型参数学习完成后,我们可以使用 RMB 对新的数据进行生成、降噪等处理。

RBM 能够有效地应用于很多领域,例如语音识别、图像处理、自然语言处理等。同时,它还是其他深度学习模型的基础,例如深度信念网络(Deep Belief Network,DBN)和深度玻尔兹曼机(Deep Boltzmann Machine,DBM)等。
在这里插入图片描述

2. RBM模型原理

RBM是一个二部图模型,包括可见层(visible layer)和隐藏层(hidden layer),两层之间存在连接权重。可见层负责接收原始数据,隐藏层负责提取特征。与其他神经网络不同,RBM没有输出层,其学习过程是无监督的。

RBM的训练过程包括正向传播(从可见层到隐藏层)和反向传播(从隐藏层到可见层)。训练目标是最大化数据的对数似然,通过对比散度(Contrastive Divergence,简称CD)算法进行权重更新。

受限玻尔兹曼机(RBM)是一种用于无监督学习的概率生成模型。它由可见层和隐藏层组成,通过学习数据的分布来捕捉数据中的特征。

RBM的数学原理可以通过以下公式表示:

可见层的状态:
P ( v ) = 1 Z ∑ h e − E ( v , h ) P(v) = \frac{1}{Z} \sum_h e^{-E(v,h)} P(v)=Z1heE(v,h)

隐藏层的状态:
P ( h ) = 1 Z ∑ v e − E ( v , h ) P(h) = \frac{1}{Z} \sum_v e^{-E(v,h)} P(h)=Z1veE(v,h)

其中,$ E(v,h) $ 是能量函数,$ Z $是归一化常数。

RBM的学习过程主要包括两个步骤:正向传播和反向传播。

正向传播(Positive Phase):

在正向传播中,给定一个可见层的输入样本,通过计算隐藏层的激活概率来更新隐藏层的状态。

P ( h j = 1 ∣ v ) = σ ( ∑ i = 1 n v w i j v i + c j ) P(h_j=1|v) = \sigma\left(\sum_{i=1}^{n_v} w_{ij} v_i + c_j\right) P(hj=1∣v)=σ(i=1nvwijvi+cj)

其中,$ \sigma(x) $是sigmoid函数。

反向传播(Negative Phase):

在反向传播中,通过计算可见层的激活概率来更新可见层和隐藏层的状态。

P ( v i = 1 ∣ h ) = σ ( ∑ j = 1 n h w i j h j + b i ) P(v_i=1|h) = \sigma\left(\sum_{j=1}^{n_h} w_{ij} h_j + b_i\right) P(vi=1∣h)=σ(j=1nhwijhj+bi)

通过交替进行正向传播和反向传播,RBM可以学习到数据的分布,并用于生成新的样本。

3. 使用PyTorch搭建RBM模型

首先导入需要的库:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

接下来定义RBM模型:

class RBM(nn.Module):
    def __init__(self, visible_dim, hidden_dim, k=1):
        super(RBM, self).__init__()
        self.visible_dim = visible_dim
        self.hidden_dim = hidden_dim
        self.k = k
        self.W = nn.Parameter(torch.randn(visible_dim, hidden_dim) * 0.01)
        self.visible_bias = nn.Parameter(torch.zeros(visible_dim))
        self.hidden_bias = nn.Parameter(torch.zeros(hidden_dim))

    def sample_hidden(self, visible_probs):
        hidden_probs = torch.sigmoid(torch.matmul(visible_probs, self.W) + self.hidden_bias)
        return torch.bernoulli(hidden_probs)

    def sample_visible(self, hidden_probs):
        visible_probs = torch.sigmoid(torch.matmul(hidden_probs, self.W.t()) + self.visible_bias)
        return torch.bernoulli(visible_probs)

    def contrastive_divergence(self, visible):
        v0 = visible
        h0 = self.sample_hidden(v0)
        v_k = v0.clone()
        for _ in range(self.k):
            h_k = self.sample_hidden(v_k)
            v_k = self.sample_visible(h_k)
        return v0, h0, v_k

    def forward(self, visible):
        v0, h0, v_k = self.contrastive_divergence(visible)
        positive_association = torch.matmul(v0.t(), h0)
        negative_association = torch.matmul(v_k.t(), self.sample_hidden(v_k))
        return positive_association - negative_association

4. 数据样例及加载

为了简化问题,我们使用二值化的MNIST数据集作为示例。数据集包含手写数字0-9的灰度图像,每个图像的大小为28x28。我们需要将数据转换为可见层的形式。

from torchvision import datasets, transforms

def bernoulli(x):
    return torch.bernoulli(x)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(bernoulli)
])

mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=5, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=5, shuffle=False, num_workers=2)

5. 模型训练

接下来,我们将训练RBM模型。设置超参数,实例化RBM模型,然后使用随机梯度下降(SGD)优化器进行训练。SGD 是一种常用的优化算法,其基本思想是在每个迭代步骤中,通过计算当前样本的梯度来更新模型参数,以逐步寻找最小化损失函数的全局最优解。

visible_dim = 28 * 28
hidden_dim = 128
k = 1
learning_rate = 0.01
epochs = 10

rbm = RBM(visible_dim, hidden_dim, k)
optimizer = optim.SGD(rbm.parameters(), lr=learning_rate)

for epoch in range(epochs):
    train_loss = 0
    for i, (data, _) in enumerate(train_loader):
        data = data.view(-1, visible_dim)
        optimizer.zero_grad()
        delta_W = rbm(data)
        loss = -torch.mean(delta_W)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {train_loss / (i + 1)}")

6. 模型测试

在模型训练完成后,我们可以将其应用于实际任务,如特征提取、分类等。这里我们简单地展示如何使用训练好的RBM模型重构测试数据。

import matplotlib.pyplot as plt

def display_reconstruction(rbm, test_loader, num_images=5):
    _, (test_data, _) = next(enumerate(test_loader))
    test_data = test_data[:num_images].view(-1, visible_dim)
    v0, _, v_k = rbm.contrastive_divergence(test_data)

    fig, axes = plt.subplots(nrows=2, ncols=num_images, figsize=(10, 4))
    for i in range(num_images):
        axes[0, i].imshow(v0[i].view(28, 28).detach().numpy(), cmap='gray')
        axes[1, i].imshow(v_k[i].view(28, 28).detach().numpy(), cmap='gray')
        axes[0, i].axis('off')
        axes[1, i].axis('off')
    plt.show()

display_reconstruction(rbm, test_loader)

运行结果:

Epoch 1/10, Loss: 0.927296216373558
Epoch 2/10, Loss: 0.9289948132250097
Epoch 3/10, Loss: 0.9284022589268146
Epoch 4/10, Loss: 0.9277208608952198
Epoch 5/10, Loss: 0.9270475412525021
Epoch 6/10, Loss: 0.9267477485059382
Epoch 7/10, Loss: 0.9266238975358176
Epoch 8/10, Loss: 0.9262511341960042
Epoch 9/10, Loss: 0.9246195605427593
Epoch 10/10, Loss: 0.9238044374525011

在这里插入图片描述

7. 总结

本文介绍了RBM模型的原理,并使用PyTorch框架实现了一个简单的RBM模型。我们展示了如何构建模型,加载样例数据进行训练,并在训练完成后进行测试。

需要注意的是,RBM模型在现代深度学习中的应用已经较少,很多任务可以通过其他神经网络模型(如卷积神经网络、循环神经网络)达到更好的效果。但了解RBM模型及其原理对理解深度学习的发展历程具有重要意义。

受限玻尔兹曼被广泛运用在各种领域中,以下是其中的一些应用场景:

图像处理和计算机视觉:RBM 可以用于图像特征提取、图像分类、图像生成等任务,例如人脸识别、手写数字识别等。

语音识别:RBM 可以用于建立声学模型,从而提高语音识别的准确性和鲁棒性。

自然语言处理:RBM 可以用于语义表示、文本分类、机器翻译等任务。

推荐系统:RBM 可以用于用户画像建模、商品推荐等场景,从而提供更精准的个性化推荐服务。

数据分析和挖掘:RBM 可以用于数据特征提取、异常检测、聚类分析等任务,例如金融数据分析、医疗数据分析等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/670750.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Redis简单动态字符串SDS

目录 前言 一.SDS定义 二.SDS与C字符串的区别 2.1 常数复杂度获取字符串的长度 2.2 杜绝缓冲区溢出 2.3 减少修改字符串时带来的内存重分配次数 2.3.1 空间预分配 2.3.2 惰性空间释放 2.4 二进制安全 2.5 兼容部分C字符串函数 2.6 总结 三.SDS缺点 前言 Redis没有直接使用C语…

gRPC教程与应用

gRPC是是谷歌一个开源的跨语言的RPC框架,面向移动和 HTTP/2 设计。 grpc中文网 在 gRPC 里客户端应用可以像调用本地对象一样直接调用另一台不同的机器上服务端应用的方法,使得您能够更容易地创建分布式应用和服务。 gRPC 也是基于以下理念&#xff1…

python3+requests+unittest接口自动化测试

1.环境准备 python3 pycharm编辑器 2.框架目录展示 (该套代码只是简单入门,有兴趣的可以不断后期完善) (1)run.py主运行文件,运行之后可以生成相应的测试报告,并以邮件形式发送;…

【C++进阶】红黑树实现

文章目录 红黑树的概念红黑树的性质红黑树节点的定义红黑树结构红黑树的插入1.按照二叉搜索的树规则插入新节点2.进行旋转和变色源码 红黑树的验证中序遍历判断是否满足二叉搜索树判断是否满足红黑树 完整源码 红黑树的概念 红黑树,是一种二叉搜索树,但…

基于spss的多元统计分析 之 单/双因素方差分析 多元回归分析(1/8)

实验目的: 1.掌握单样本t检验、两样本t检验、配对样本t检验、单因素方差分析、多元回归分析的基本原理; 2.熟悉掌握SPSS软件或者R软件关于单因素、多因素方差分析、多元回归分析的基本操作; 3.利用实验指导…

2.3C++保护成员

C 保护成员 在C中,可以使用保护成员 protected,来提高代码的安全性。 我用大白话解释一下什么是保护成员:说白了就是为了防止其他类直接访问或修改其成员加的一个措施。 目的是保护,成员的私有性和可见性。 C 类的保护 可以为…

web 语音通话 jssip

先把封装好的地址安上(非本人封装):webrtc-webphone: 基于JsSIP开发的webrtc软电话 jssip中文文档:jssip中文开发文档(完整版) - 简书 jssip使用文档:(我没有运行过,但…

Nginx服务器,在window系统中的使用(前端,nginx的应用)

简介:Nginx是一个轻量级、高性能的HTTP和反向代理web服务器,且支持电子邮件(IMAP/POP3)代理服务,特点是占用内存少,并发能力强,给我们来了很多的便利,国内大部分网站都有使用nginx&a…

18款奔驰S350升级后排座椅记忆功能,提升您乘坐舒适性

带有记忆功能的座椅可以存储三个的座椅设置以及行车电脑中的舒适性设置。只要按一下按钮就可以跳到记忆模式,让座椅回到上一次设置。

使用 BigQuery Omni,发现跨云地理空间分析的优势

【本文由 Cloud Ace 整理发布。Cloud Ace 是谷歌云全球战略合作伙伴,拥有 300 多名工程师,也是谷歌最高级别合作伙伴,多次获得 Google Cloud 合作伙伴奖。作为谷歌托管服务商,我们提供谷歌云、谷歌地图、谷歌办公套件、谷歌云认证…

第十章详解synchronized锁升级

文章目录 升级的流程为什么要引入锁升级这套流程多线程访问情况具体流程 轻量级锁如何使用CAS实现轻量级锁CAS加锁成功CAS加锁失败CAS进行解锁 总结何时变为重量级锁 锁膨胀自旋优化 偏向锁主要作用偏向状态测试撤销偏向锁 撤销 - 调用对象 hashCode撤销 - 其它线程使用对象撤销…

js:codemirror实现在线代码编辑器代码高亮显示

CodeMirror is a versatile text editor implemented in JavaScript for the browser. It is specialized for editing code, and comes with a number of language modes and addons that implement more advanced editing functionality. 译文:CodeMirror是一个多…

第二章:软件工程师必备的网络基础

目录 一、网线的制作 二、集线器、交换机介绍 三、路由器的配置 一、网线的制作 1.1、水晶头 ​​​ 1.2、网线钳 1.3、网线的标准 T568A标准(交叉线): 适用链接场合:电脑-电脑、交换机-交换机、集线器-集线器 接线顺序&…

【正点原子STM32连载】第三十九章 触摸屏实验 摘自【正点原子】STM32F103 战舰开发指南V1.2

1)实验平台:正点原子stm32f103战舰开发板V4 2)平台购买地址:https://detail.tmall.com/item.htm?id609294757420 3)全套实验源码手册视频下载地址: http://www.openedv.com/thread-340252-1-1.html# 第三…

有源电力滤波器及配电能效平台在污水处理厂中的应用

【摘要】为减少污水处理设备产生的各次谐波,通过确定主要谐波源,检测和计算谐波分量,采用有源电力滤波器进行谐波治理,大幅降低了电力系统中的三相电流畸变率,提高了电能质量;抑制了谐波分量,减…

doris docker部署和本地化部署 1.2.4.1版本

写在前面 以下操作语句按顺序执行即可,注意切换目录的命令一定记得执行,如果需要改动的地方会有${}注释,其余不需要任何改动,默认安装版本为1.12.4(稳定版) 本地化部署 下载 # 创建目录 mkdir /data/sof…

软件测试日常工作和前景是怎么样的?

笔者从测试的工作情况,职业发展,还有测试的工作日常等等来给大家讲解一下软件测试到底是什么样的工作? 通俗来说软件测试工程师就相当于一个质检员,专门处理软件测试质量的工作,不管是功能测试也好,性能测…

BK7231N开发平台原厂烧录工具使用说明

BK7231N开发平台原厂烧录工具使用说明 烧录流程介绍 1.打开原厂烧录工具 以管理员身份打开名为 bk_writer_gui_V1.6.3.exe 的可执行文件。 2. 烧录对象 烧录对象选择 BK7231n 3.烧录地址 当我们烧录UA文件的时候,需要把起始地址设置为: 0X00011000。…

Windows提示“找不到rgss202j.dll”怎么办?

Rgss202j.dll文件是Windows操作系统最重要的系统文件之一,它包含了一组程序和驱动函数。如果此文件丢失或损坏,驱动程序将无法正常工作,并且相应的应用程序也将无法正常启动且运行。通常情况下,造成Rgss202j.dll文件无法找到的原因…

爬虫 - ProtoBuf 协议

一、抓取请求 以下是请求的大致内容: 是乱码,需要解析。 二、解析 通过分析 request 和 response 的 Content-Type: application/x-protobuf 得知:使用了谷歌的 protobuf 协议来传输数据,需要破解。 大致破解过程&#xff…