TensorFlow 实现 Mixture Density Network (MDN) 的完整说明

news2025/7/16 17:09:46

本文档详细解释了一段使用 TensorFlow 构建和训练混合密度网络(Mixture Density Network, MDN)的代码,涵盖数据生成、模型构建、自定义损失函数与预测可视化等各个环节。


1. 导入库与设置超参数

import numpy as np 
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import math

说明

  • 引入用于数值运算(NumPy)、构建深度学习模型(TensorFlow/Keras)和绘图(Matplotlib)的基础工具包。

超参数定义

N_HIDDEN = 15         # 隐藏层神经元数量
N_MIXES = 10          # GMM 中混合成分数量
OUTPUT_DIMS = 1       # 输出维度(目标变量维度)

2. 自定义 MDN 层

class MDN(layers.Layer):
    def __init__(self, output_dims, num_mixtures, **kwargs):
        super(MDN, self).__init__(**kwargs)
        self.output_dims = output_dims
        self.num_mixtures = num_mixtures
        self.params = self.num_mixtures * (2 * self.output_dims + 1)  # pi, mu, sigma
        self.dense = layers.Dense(self.params)

    def call(self, inputs):
        output = self.dense(inputs)
        return output

说明

  • params 表示 GMM 每个分量包含 mu(均值)、sigma(标准差)和 pi(权重),共 2*D + 1 个参数。
  • 输出维度为 (batch_size, num_mixtures * (2*output_dims + 1))

3. 自定义 MDN 损失函数

def get_mixture_loss_func(output_dims, num_mixtures):
    def mdn_loss(y_true, y_pred):
        y_true = tf.reshape(y_true, [-1, 1])
        out_mu = y_pred[:, :num_mixtures * output_dims]
        out_sigma = y_pred[:, num_mixtures * output_dims:2 * num_mixtures * output_dims]
        out_pi = y_pred[:, -num_mixtures:]

        mu = tf.reshape(out_mu, [-1, num_mixtures, output_dims])
        sigma = tf.exp(tf.reshape(out_sigma, [-1, num_mixtures, output_dims]))
        pi = tf.nn.softmax(out_pi)

        y_true = tf.tile(y_true[:, tf.newaxis, :], [1, num_mixtures, 1])
        normal_dist = tf.exp(-0.5 * tf.square((y_true - mu) / sigma)) / (sigma * tf.sqrt(2.0 * np.pi))
        prob = tf.reduce_prod(normal_dist, axis=2)
        weighted_prob = prob * pi
        loss = -tf.math.log(tf.reduce_sum(weighted_prob, axis=1) + 1e-8)
        return tf.reduce_mean(loss)
    return mdn_loss

说明

  • 通过概率密度函数计算目标值属于 GMM 各个分布的概率,并取加权平均。
  • 对数似然函数取负作为损失。

4. 从输出分布中采样

def sample_from_output(y_pred, output_dims, num_mixtures, temp=1.0):
    out_mu = y_pred[:num_mixtures * output_dims]
    out_sigma = y_pred[num_mixtures * output_dims:2 * num_mixtures * output_dims]
    out_pi = y_pred[-num_mixtures:]

    out_sigma = np.exp(out_sigma)
    out_pi = np.exp(out_pi / temp)
    out_pi /= np.sum(out_pi)

    mixture_idx = np.random.choice(np.arange(num_mixtures), p=out_pi)
    mu = out_mu[mixture_idx * output_dims:(mixture_idx + 1) * output_dims]
    sigma = out_sigma[mixture_idx * output_dims:(mixture_idx + 1) * output_dims]
    sample = np.random.normal(mu, sigma)
    return sample

说明

  • 使用 softmax 处理 pi,选择一个分布后按对应的 musigma 采样。
  • temp 控制采样温度(温度越高分布越平坦)。

5. 生成训练数据

NSAMPLE = 3000
y_data = np.float32(np.random.uniform(-10.5, 10.5, NSAMPLE))
r_data = np.random.normal(size=NSAMPLE)
x_data = np.sin(0.75 * y_data) * 7.0 + y_data * 0.5 + r_data * 1.0
x_data = x_data.reshape((NSAMPLE, 1))
y_data = y_data.reshape((NSAMPLE, 1))

说明

  • 构造非线性映射关系的合成数据:x = sin(0.75y)*7 + 0.5y + 噪声
  • x 是输入,y 是目标。

6. 构建模型

model = keras.Sequential([
    layers.Dense(N_HIDDEN, input_shape=(1,), activation='relu'),
    layers.Dense(N_HIDDEN, activation='relu'),
    MDN(OUTPUT_DIMS, N_MIXES)
])
model.compile(loss=get_mixture_loss_func(OUTPUT_DIMS, N_MIXES), optimizer=keras.optimizers.Adam())
model.summary()

说明

  • 构建一个两层隐层的前馈神经网络,输出 MDN 层。
  • 使用自定义的 MDN 损失函数训练模型。

7. 模型训练

model.fit(x_data, y_data, batch_size=128, epochs=200, validation_split=0.15, verbose=1)
  • 批量大小 128,训练 200 个 epoch,保留 15% 数据用于验证。

8. 模型测试与预测可视化

x_test = np.linspace(-15, 15, 1000).astype(np.float32).reshape(-1, 1)
y_pred = model.predict(x_test)
y_samples = np.array([sample_from_output(p, OUTPUT_DIMS, N_MIXES) for p in y_pred])
  • 对连续输入进行预测并从预测的 GMM 中采样。

可视化预测结果

plt.figure()
plt.scatter(x_test, y_samples, alpha=0.3, s=10)
plt.title("MDN Predictions")
plt.xlabel("x")
plt.ylabel("y")
plt.show()

原始数据与预测对比

plt.figure(figsize=(8, 5))
plt.scatter(x_data, y_data, label="Original Data", alpha=0.2, s=10)
plt.scatter(x_test, y_samples, label="MDN Samples", alpha=0.5, s=10, color='r')
plt.title("MDN Prediction vs Training Data")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.grid(True)
plt.show()

总代码如下

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import math

# 超参数
N_HIDDEN = 15
N_MIXES = 10
OUTPUT_DIMS = 1

# === 1. 自定义 MDN 层 ===
class MDN(layers.Layer):
    def __init__(self, output_dims, num_mixtures, **kwargs):
        super(MDN, self).__init__(**kwargs)
        self.output_dims = output_dims
        self.num_mixtures = num_mixtures
        self.params = self.num_mixtures * (2 * self.output_dims + 1)  # pi, mu, sigma
        self.dense = layers.Dense(self.params)

    def call(self, inputs):
        output = self.dense(inputs)
        return output

# === 2. 自定义损失函数 ===
def get_mixture_loss_func(output_dims, num_mixtures):
    def mdn_loss(y_true, y_pred):
        y_true = tf.reshape(y_true, [-1, 1])
        out_mu = y_pred[:, :num_mixtures * output_dims]
        out_sigma = y_pred[:, num_mixtures * output_dims:2 * num_mixtures * output_dims]
        out_pi = y_pred[:, -num_mixtures:]

        mu = tf.reshape(out_mu, [-1, num_mixtures, output_dims])
        sigma = tf.exp(tf.reshape(out_sigma, [-1, num_mixtures, output_dims]))
        pi = tf.nn.softmax(out_pi)

        y_true = tf.tile(y_true[:, tf.newaxis, :], [1, num_mixtures, 1])
        normal_dist = tf.exp(-0.5 * tf.square((y_true - mu) / sigma)) / (sigma * tf.sqrt(2.0 * np.pi))
        prob = tf.reduce_prod(normal_dist, axis=2)
        weighted_prob = prob * pi
        loss = -tf.math.log(tf.reduce_sum(weighted_prob, axis=1) + 1e-8)
        return tf.reduce_mean(loss)
    return mdn_loss

# === 3. 从输出采样函数 ===
def sample_from_output(y_pred, output_dims, num_mixtures, temp=1.0):
    out_mu = y_pred[:num_mixtures * output_dims]
    out_sigma = y_pred[num_mixtures * output_dims:2 * num_mixtures * output_dims]
    out_pi = y_pred[-num_mixtures:]

    out_sigma = np.exp(out_sigma)
    out_pi = np.exp(out_pi / temp)
    out_pi /= np.sum(out_pi)

    mixture_idx = np.random.choice(np.arange(num_mixtures), p=out_pi)
    mu = out_mu[mixture_idx * output_dims:(mixture_idx + 1) * output_dims]
    sigma = out_sigma[mixture_idx * output_dims:(mixture_idx + 1) * output_dims]
    sample = np.random.normal(mu, sigma)
    return sample

# === 4. 生成训练数据 ===
NSAMPLE = 3000
y_data = np.float32(np.random.uniform(-10.5, 10.5, NSAMPLE))
r_data = np.random.normal(size=NSAMPLE)
x_data = np.sin(0.75 * y_data) * 7.0 + y_data * 0.5 + r_data * 1.0
x_data = x_data.reshape((NSAMPLE, 1))
y_data = y_data.reshape((NSAMPLE, 1))

plt.figure()
plt.scatter(x_data, y_data, alpha=0.3)
plt.title("Training Data")
plt.show()

# === 5. 构建模型 ===
model = keras.Sequential([
    layers.Dense(N_HIDDEN, input_shape=(1,), activation='relu'),
    layers.Dense(N_HIDDEN, activation='relu'),
    MDN(OUTPUT_DIMS, N_MIXES)
])
model.compile(loss=get_mixture_loss_func(OUTPUT_DIMS, N_MIXES), optimizer=keras.optimizers.Adam())
model.summary()

# === 6. 模型训练 ===
model.fit(x_data, y_data, batch_size=128, epochs=200, validation_split=0.15, verbose=1)

# === 7. 测试与可视化 ===
x_test = np.linspace(-15, 15, 1000).astype(np.float32).reshape(-1, 1)
y_pred = model.predict(x_test)
y_samples = np.array([sample_from_output(p, OUTPUT_DIMS, N_MIXES) for p in y_pred])

plt.figure()
plt.scatter(x_test, y_samples, alpha=0.3, s=10)
plt.title("MDN Predictions")
plt.xlabel("x")
plt.ylabel("y")
plt.show()
# === 8. 测试数据与预测对比图 ===

plt.figure(figsize=(8, 5))
plt.scatter(x_data, y_data, label="Original Data", alpha=0.2, s=10)
plt.scatter(x_test, y_samples, label="MDN Samples", alpha=0.5, s=10, color='r')
plt.title("MDN Prediction vs Training Data")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.grid(True)
plt.show()

总结

本项目展示了如何使用 TensorFlow 构建混合密度网络,用以建模复杂的条件分布。相比传统回归模型,MDN 能够生成多峰预测结果,适用于不确定性高、输出存在多解的场景。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2338696.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

xml+html 概述

1.什么是xml xml 是可扩展标记语言的缩写&#xff1a; Extensible Markup Language。 <root><h1> text 1</h1> </root> web 应用开发&#xff0c;需要配置 web.xml&#xff0c;就是个典型的 xml文件 <web-app><servlet><servlet-name&…

Java从入门到“放弃”(精通)之旅——数组的定义与使用⑥

Java从入门到“放弃”&#xff08;精通&#xff09;之旅&#x1f680;——数组⑥ 前言——什么是数组&#xff1f; 数组&#xff1a;可以看成是相同类型元素的一个集合&#xff0c;在内存中是一段连续的空间。比如现实中的车库&#xff0c;在java中&#xff0c;包含6个整形类…

如何对docker镜像存在的gosu安全漏洞进行修复——筑梦之路

这里以mysql的官方镜像为例进行说明&#xff0c;主要流程为&#xff1a; 1. 分析镜像存在的安全漏洞具体是什么 2. 根据分析结果有针对性地进行修复处理 3. 基于当前镜像进行修复安全漏洞并复核验证 # 镜像地址mysql:8.0.42 安全漏洞现状分析 dockerhub网站上获取该镜像的…

基于springboot的老年医疗保健系统

博主介绍&#xff1a;java高级开发&#xff0c;从事互联网行业六年&#xff0c;熟悉各种主流语言&#xff0c;精通java、python、php、爬虫、web开发&#xff0c;已经做了六年的毕业设计程序开发&#xff0c;开发过上千套毕业设计程序&#xff0c;没有什么华丽的语言&#xff0…

使用Ollama本地运行deepseek模型

Ollama 是一个用于管理 AI 模型的工具 下载 Ollama Ollama 选择版本 下载模型 安装好后&#xff0c;下载模型 选择模型 选择模型大小&#xff0c;复制对应命令&#xff08;越大越聪明&#xff0c;但是内存要求越高&#xff09; 打开控制台运行命令&#xff0c;第一次运行会自动…

网络编程 - 3

目录 UDP 连接拓展&#xff08;业务逻辑&#xff09; 词典服务器实现 完 UDP 连接拓展&#xff08;业务逻辑&#xff09; 我们上一篇文章实现了一个回显服务器&#xff0c;在服务端中业务方法 process 中&#xff0c;只是单纯的将客户端输入的东西 return 了一下&#xff0…

5G 毫米波滤波器的最优选择是什么?

新的选择有很多&#xff0c;但到目前为止还没有明确的赢家。 蜂窝电话技术利用大量的带带&#xff0c;为移动用途提供不断增加的带宽。 其中的每一个频带都需要透过滤波器将信号与其他频带分开&#xff0c;但目前用于手机的滤波器技术可能无法扩展到5G所规划的全部毫米波&#…

【HDFS入门】HDFS性能调优实战:压缩与编码技术深度解析

目录 1 HDFS性能调优概述 2 HDFS压缩技术原理与应用 2.1 常见压缩算法比较 2.2 压缩流程架构 2.3 压缩配置实践 3 列式存储编码技术 3.1 ORC与Parquet对比 3.2 ORC文件结构 3.3 Parquet编码流程 4 性能调优实战建议 4.1 压缩选择策略 4.2 编码优化技巧 5 性能测试…

如何在 IntelliJ IDEA 中安装通义灵码 - AI编程助手提升开发效率

随着人工智能技术的飞速发展&#xff0c;AI 编程助手已成为提升开发效率和代码质量的强大工具。在众多 AI 编程助手之中&#xff0c;阿里云推出的通义灵码凭借其智能代码补全、代码解释、生成单元测试等丰富功能&#xff0c;脱颖而出&#xff0c;为开发者带来了全新的编程体验。…

从零到一:管理系统设计新手如何快速上手?

管理系统设计是一项复杂而富有挑战性的任务&#xff0c;它要求设计者具备多方面的知识和技能&#xff0c;包括需求分析、架构设计、数据管理、用户界面设计等。对于初次接触这一领域的新手而言&#xff0c;如何快速上手并成为一名合格的管理系统设计者呢&#xff1f;本文将从管…

WSL (ext4.vhdx文件)占用空间过大,清理方式记录,同时更改 WSL 保存位置

一、问题 之前使用 WSL Ubuntu 进行过开发板的 Yocto 项目编译&#xff0c;占用空间达到了 70GB 多的空间。后来进行了项目迁移&#xff0c;删除了 WSL 中的所有文件&#xff0c;但是从 Windows 查看空间占用却没有减少&#xff1a; 占用依然是 70 多&#xff0c;查阅发现 vhdx…

《软件设计师》复习笔记(14.2)——统一建模语言UML、事务关系图

目录 1. UML概述 2. UML构造块 (1) 事物&#xff08;Things&#xff09; (2) 关系&#xff08;Relationships&#xff09; 真题示例&#xff1a; 3. UML图分类 (1) 结构图&#xff08;静态&#xff09; (2) 行为图&#xff08;动态&#xff09; 4. 核心UML图详解 5.…

[文献阅读] EnCodec - High Fidelity Neural Audio Compression

[文献信息]&#xff1a;[2210.13438] High Fidelity Neural Audio Compression facebook团队提出的一个用于高质量音频高效压缩的模型&#xff0c;称为EnCodec。Encodec是VALL-E的重要前置工作&#xff0c;正是Encodec的压缩量化使得VALL-E能够出现&#xff0c;把语音领域带向大…

【操作系统原理01】操作系统引论

文章目录 大纲一、中断与异常0.大纲1. 中断的作用2. 中断类型2.1 内中断2.2 外中断2.3 判断内外中断 3. 中断机制原理 二、系统调用0. 大纲1.什么是系统调用2.系统调用分类 三、操作性系统内核(了解)0.大纲1.内核2.各种操作系统结构特性 四、操作系统引论0.大纲1.磁盘存储 图片…

最新得物小程序sign签名加密,请求参数解密,响应数据解密逆向分析

点击精选&#xff0c;出现https://app.dewu.com/api/v1/h5/index/fire/index 这个请求 直接搜索sign的话不容易定位 直接搜newAdvForH5就一个&#xff0c;进去再搜sign&#xff0c;打上断点 可以看到t.params就是没有sign的请求参数&#xff0c; 经过Object(a.default)该函数…

Day2—3:前端项目uniapp壁纸实战

接下来我们做一个专题精选 <view class"theme"><common-title><template #name>专题精选</template><template #custom><navigator url"" class"more">More</navigator></template></common…

Python基于知识图谱的医疗问答系统【附源码、文档说明】

博主介绍&#xff1a;✌Java老徐、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;&…

股指期货跨期套利是如何赚取价差利润的?

股指期货跨期套利&#xff0c;简单来说&#xff0c;就是在同一交易所内&#xff0c;针对同一股指期货品种的不同交割月份合约进行的套利交易。投资者会同时买入某一月份的股指期货合约&#xff0c;并卖出另一月份的股指期货合约&#xff0c;待未来某个时间点&#xff0c;再将这…

w297毕业生实习与就业管理系统

&#x1f64a;作者简介&#xff1a;多年一线开发工作经验&#xff0c;原创团队&#xff0c;分享技术代码帮助学生学习&#xff0c;独立完成自己的网站项目。 代码可以查看文章末尾⬇️联系方式获取&#xff0c;记得注明来意哦~&#x1f339;赠送计算机毕业设计600个选题excel文…

Java集合框架中的List、Map、Set详解

在Java开发中&#xff0c;集合框架是处理数据时不可或缺的工具之一。今天&#xff0c;我们来深入了解一下Java集合框架中的List、Map和Set&#xff0c;并探讨它们的常见方法操作。 目录 一、List集合 1.1 List集合介绍 1.2 List集合的常见方法 添加元素 获取元素 修改元素…