【机器学习】基于t-SNE的MNIST数据集可视化探索

news2025/3/16 22:55:39

 一、前言

        在机器学习和数据科学领域,高维数据的可视化是一个极具挑战但又至关重要的问题。高维数据难以直观地理解和分析,而有效的可视化方法能够帮助我们发现数据中的潜在结构、模式和关系。本文以经典的MNIST手写数字数据集为例,探讨如何利用t-分布随机邻域嵌入(t-SNE)这一强大的降维技术,将高维的图像数据降维到二维空间,并进行可视化展示。通过本文,我们将深入了解t-SNE的原理、算法步骤,以及如何在Python中实现并应用它,从而更好地理解和探索高维数据的内在特性。


二、技术与原理简介

        在深入探讨t-SNE之前,我们首先需要区分机器学习中的两大主要范畴:监督学习和无监督学习。

        1. 监督学习

        监督学习是指在已知输入数据和对应标签的情况下,训练模型学习输入与输出之间的映射关系。模型通过学习大量的带标签数据,能够对新的、未见过的数据进行预测或分类。常见的监督学习算法包括:

  • 线性回归: 用于预测连续型变量。
  • 逻辑回归: 用于分类问题。
  • 支持向量机 (SVM): 用于分类和回归问题,尤其擅长处理高维数据。
  • 决策树: 基于树状结构进行决策,易于理解和解释。
  • 随机森林: 集成多个决策树,提高预测的准确性和鲁棒性。
  • 神经网络: 模拟人脑神经元结构,能够学习复杂的非线性关系。

        2. 无监督学习

        无监督学习是指在没有标签的情况下,训练模型发现数据中的潜在结构和模式。模型通过分析数据的内在特征,能够进行聚类、降维、关联规则挖掘等任务。常见的无监督学习算法包括:

  • 聚类: 将数据划分为不同的簇,使得同一簇内的数据相似度较高,不同簇之间的数据相似度较低。常见的聚类算法包括K-means、层次聚类、DBSCAN等。
  • 降维: 将高维数据降维到低维空间,同时尽可能保留数据的关键信息。常见的降维算法包括主成分分析 (PCA)、t-SNE、UMAP等。
  • 关联规则挖掘: 发现数据中不同项之间的关联关系,例如购物篮分析。

        3. 监督学习与无监督学习的区别

        4. MNIST数据集简介

        MNIST (Modified National Institute of Standards and Technology database) 是一个经典的手写数字数据集,广泛应用于机器学习和深度学习领域。它包含60,000个训练样本和10,000个测试样本,每个样本都是一个28x28像素的灰度图像,代表0到9之间的手写数字。

        4.1 数据格式

        MNIST数据集通常以两种格式提供:

  • 图像格式: 每个样本都是一个图像文件,例如PNG或JPEG格式。
  • 数值格式: 每个样本都被转换为一个784维的向量,其中每个元素代表一个像素的灰度值 (0到255)。

       4.2 数据集特点

  • 规模适中: MNIST数据集的规模适中,既可以用于快速原型验证,又可以用于训练复杂的模型。
  • 易于获取: MNIST数据集可以从多个来源免费获取,例如Scikit-learn、TensorFlow等。
  • 广泛应用: MNIST数据集被广泛应用于各种机器学习和深度学习算法的评估和比较。

        5. t-SNE算法原理与数学推导

        5.1 算法核心思想

        t-SNE(t-Distributed Stochastic Neighbor Embedding)是一种非线性降维技术,通过以下步骤实现高维数据到低维空间的映射:

  1. 计算高维相似度:在原始空间中,计算每对样本间的相似度
  2. 构建低维嵌入空间:在目标空间(如2D)中,通过优化使相似度分布匹配
  3. 梯度下降优化:最小化两空间分布的KL散度

        5.2 数学公式详解

        5.2.1 高维相似度计算

        对原始空间中的样本对(𝑥𝑖,𝑥𝑗) ,定义条件概率:

其中𝜎𝑖 ​为高斯核带宽,通过二分查找确定以满足** perplexity **参数(控制邻域大小)。

        5.2.2 低维相似度建模

        在目标空间中,定义联合概率:

        采用t-分布(自由度为1的Student分布)以增强对异常值的鲁棒性。

        5.2.3 目标函数优化

        通过最小化KL散度实现分布匹配:

其中

        优化过程使用梯度下降:

        5.3 算法步骤流程

  1. 参数初始化:设置降维维度(如2D)、perplexity(通常5-50)、学习率等
  2. 高维相似度计算:为每个样本计算条件概率矩阵𝑃P
  3. 低维初始化:随机生成初始嵌入坐标𝑌Y
  4. 梯度下降优化:迭代更新𝑌Y以最小化KL散度
  5. 结果输出:返回低维坐标矩阵

三、代码详解

        本文的代码主要分为以下几个部分:

        1. 导入库

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn import datasets
from sklearn import manifold
%matplotlib inline

说明

  • import matplotlib.pyplot as plt: 导入matplotlib库,用于绘制图像。
  • import numpy as np: 导入numpy库,用于进行数值计算。
  • import pandas as pd: 导入pandas库,用于数据处理。
  • import seaborn as sns: 导入seaborn库,用于数据可视化。
  • from sklearn import datasets: 导入sklearn库中的datasets模块,用于加载数据集。
  • from sklearn import manifold: 导入sklearn库中的manifold模块,用于降维。
  • %matplotlib inline: 在Jupyter Notebook中显示图像。

        2. 加载数据

# 加载数据
data = datasets.fetch_openml('mnist_784', version=1, return_X_y=True)
pixel_values, targets = data
targets = targets.astype(int)

# 将DataFrame转换为numpy数组以便更容易操作
# 如果pixel_values已经是numpy数组,这一步可以跳过
if isinstance(pixel_values, pd.DataFrame):
    pixel_values_array = pixel_values.values
else:
    pixel_values_array = pixel_values

说明

  • data = datasets.fetch_openml('mnist_784', version=1, return_X_y=True): 使用datasets.fetch_openml函数加载MNIST数据集。'mnist_784'表示数据集的名称,version=1表示数据集的版本,return_X_y=True表示返回输入数据和标签。
  • pixel_values, targets = data: 将返回的数据解包为pixel_valuestargetspixel_values包含图像的像素值,targets包含图像的标签。
  • targets = targets.astype(int): 将标签转换为整数类型。
  • if isinstance(pixel_values, pd.DataFrame):: 检查pixel_values是否为pandas DataFrame类型。
  • pixel_values_array = pixel_values.values: 如果pixel_values为pandas DataFrame类型,则将其转换为numpy数组。
  • else: pixel_values_array = pixel_values: 否则,直接使用pixel_values

        3. 显示单个图像

# 显示单个图像
single_image = pixel_values_array[1].reshape(28, 28)
plt.imshow(single_image, cmap='gray')

说明

  • single_image = pixel_values_array[1].reshape(28, 28): 选择第一个图像,并将其reshape为28x28的矩阵。
  • plt.imshow(single_image, cmap='gray'): 使用plt.imshow函数显示图像,cmap='gray'表示使用灰度颜色映射。

        4. t-SNE降维

# t-SNE降维
tsne = manifold.TSNE(n_components=2, random_state=42)
transformed_data = tsne.fit_transform(pixel_values_array[:3000])

说明

  • tsne = manifold.TSNE(n_components=2, random_state=42): 创建一个t-SNE对象。n_components=2表示将数据降维到二维空间,random_state=42表示设置随机种子,保证结果的可重复性。
  • transformed_data = tsne.fit_transform(pixel_values_array[:3000]): 使用fit_transform函数对数据进行降维。这里只使用了前3000个样本,因为t-SNE的计算复杂度较高。

        5. 创建DataFrame用于可视化

# 创建DataFrame用于可视化
tsne_df = pd.DataFrame(
    np.column_stack((transformed_data, targets[:3000])),
    columns=["x", "y", "targets"]
)
tsne_df.loc[:, "targets"] = tsne_df.targets.astype(int)

说明

  • tsne_df = pd.DataFrame(...): 创建一个pandas DataFrame对象,用于存储降维后的数据和标签。
  • np.column_stack((transformed_data, targets[:3000])): 将降维后的数据和标签按列拼接在一起。
  • columns=["x", "y", "targets"]: 设置DataFrame的列名。
  • tsne_df.loc[:, "targets"] = tsne_df.targets.astype(int): 将DataFrame中的标签转换为整数类型。

        6. 可视化

# 可视化
# 注意:在新版本的seaborn中,size参数已更改为height
try:
    grid = sns.FacetGrid(tsne_df, hue="targets", size=8)
except TypeError:
    grid = sns.FacetGrid(tsne_df, hue="targets", height=8)
    
grid.map(plt.scatter, "x", "y").add_legend()

说明

  • grid = sns.FacetGrid(tsne_df, hue="targets", size=8): 创建一个seaborn FacetGrid对象,用于可视化降维后的数据。hue="targets"表示使用标签作为颜色编码,size=8表示设置图像的大小。
  • grid.map(plt.scatter, "x", "y").add_legend(): 使用plt.scatter函数绘制散点图,"x""y"表示散点图的横坐标和纵坐标,add_legend()表示添加图例。

        7. 完整代码

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn import datasets
from sklearn import manifold
%matplotlib inline

# 加载数据
data = datasets.fetch_openml('mnist_784', version=1, return_X_y=True)
pixel_values, targets = data
targets = targets.astype(int)

# 将DataFrame转换为numpy数组以便更容易操作
# 如果pixel_values已经是numpy数组,这一步可以跳过
if isinstance(pixel_values, pd.DataFrame):
    pixel_values_array = pixel_values.values
else:
    pixel_values_array = pixel_values

# 显示单个图像
single_image = pixel_values_array[1].reshape(28, 28)
plt.imshow(single_image, cmap='gray')

# t-SNE降维
tsne = manifold.TSNE(n_components=2, random_state=42)
transformed_data = tsne.fit_transform(pixel_values_array[:3000])

# 创建DataFrame用于可视化
tsne_df = pd.DataFrame(
    np.column_stack((transformed_data, targets[:3000])),
    columns=["x", "y", "targets"]
)
tsne_df.loc[:, "targets"] = tsne_df.targets.astype(int)

# 可视化
# 注意:在新版本的seaborn中,size参数已更改为height
try:
    grid = sns.FacetGrid(tsne_df, hue="targets", size=8)
except TypeError:
    grid = sns.FacetGrid(tsne_df, hue="targets", height=8)
    
grid.map(plt.scatter, "x", "y").add_legend()


四、总结与思考

        本文以MNIST数据集为例,详细介绍了如何使用t-SNE进行高维数据可视化。通过t-SNE降维,我们可以将784维的图像数据降维到二维空间,并在散点图上清晰地看到不同数字之间的分布情况。

        t-SNE是一种强大的降维技术,但也有一些局限性:

  • 计算复杂度高: t-SNE的计算复杂度为O(n^2),对于大规模数据集,计算时间会非常长。
  • 参数敏感: t-SNE的性能受到参数的影响,例如困惑度 (perplexity) 和学习率 (learning_rate)。
  • 全局结构失真: t-SNE主要关注局部结构,可能会导致全局结构失真。

        在实际应用中,我们需要根据具体情况选择合适的降维技术。对于大规模数据集,可以考虑使用PCA或UMAP等更高效的算法。对于需要保留全局结构的场景,可以考虑使用Isomap或LLE等算法。


【作者声明】

        本文内容基于作者对基于t-SNE的MNIST数据集可视化探索实现过程的实验与总结,所有数据和代码均为原创。文章中的观点仅代表个人见解,供读者参考交流。若有任何问题或建议,欢迎在评论区留言讨论,共同促进技术进步。


 【关注我们】

        如果您对神经网络、群智能算法及人工智能技术感兴趣,欢迎点赞、收藏并转发,与更多朋友一起探讨与交流!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2316286.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Pycharm】Pycharm无法复制粘贴,提示系统剪贴板不可用

我也没有用vim的插件,检查了本地和ubutnu上都没有。区别是我是远程到ubutnu的pycharm,我本地直接控制windowes的pycharm是没问题的。现象是可以从外部复制到pycharm反之则不行。 ctl c ctlv 以及右键 都不行 参考:Pycharm无法复制粘贴&…

Flink-学习路线

最近想学习一下Flink,公司的实时需求还是不少的,因此结合ai整理了一份学习路线,记录一下。 当然,公司也有Scala版本Flink框架,也学习了一下。这里只说Java版本 1. Java基础 目标: 掌握Java编程语言的基础知识。 内容…

Atcoder ABC397-D 题解

https://atcoder.jp/contests/abc397/tasks/abc397_dhttps://atcoder.jp/contests/abc397/tasks/abc397_d 题目描述: 确定是否存在一对正整数,使得 思路: 首先对方程进行转化 设 即 接下来确定的范围 根据立方差公式 因此,我们可以从到来…

K8S学习之基础二十七:k8s中daemonset控制器

k8s中DaemonSet控制器 ​ DaemonSet控制器确保k8s集群中,所有节点都运行一个相同的pod,当node节点增加时,新节点也会自动创建一个pod,当node节点从集群移除,对应的pod也会自动删除。删除DaemonSet也会删除创建的pod。…

神经网络的基本知识

感知机 输入:来自其他 n 个神经元传递过来的输入信号 处理:输入信号通过带权重的连接进行传递, 神经元接受到总输入值将与神经元的阈值进行比较 输出:通过激活函数的处理以得到输出 感知机由两层神经元组成, 输入层接受外界输入信号传递给…

PostgreSQL技术内幕26:PG聚合算子实现分析

文章目录 0.简介1.概念说明2.朴素聚集3.Group by聚集3.1 哈希聚集3.2 分组聚集 0.简介 聚合算子在聚合函数在数据分析、报告生成和统计计算中扮演着重要角色,通过对多行数据进行计算,将多个输入值压缩为单一输出值,如求和、平均值、计数等。…

【RS】OneRec快手-生成式推荐模型

note 本文提出了一种名为 OneRec 的统一生成式推荐框架,旨在替代传统的多阶段排序策略,通过一个端到端的生成模型直接生成推荐结果。OneRec 的主要贡献包括: 编码器-解码器结构:采用稀疏混合专家(MoE)架构…

mac安装navicat及使用

0.删除旧的 sudo rm -Rf /Applications/Navicat\ Premium.app sudo rm -Rf /private/var/db/BootCaches/CB6F12B3-2C14-461E-B5A7-A8621B7FF130/app.com.prect.NavicatPremium.playlist sudo rm -Rf ~/Library/Caches/com.apple.helpd/SDMHelpData/Other/English/HelpSDMIndexF…

【HTML】二、列表、表格

文章目录 1、列表1.1 无序列表1.2 有序列表1.3 定义列表 2、表格2.1 定义2.2 表格结构标签2.3 合并单元格 1、列表 列表分为: 无序列表有序列表定义列表:一个标题下有多个小分类 1.1 无序列表 ul嵌套li,ul是无序列表,li是列表…

​​​​​​​大语言模型安全风险分析及相关解决方案

大语言模型的安全风险可以从多个维度进行分类。 从输入输出的角度来看,存在提示注入、不安全输出处理、恶意内容生成和幻觉错误等风险; 从数据层面来看,训练数据中毒、敏感信息泄露和模型反演攻击是主要威胁; 模型自身则面临拒绝服务和盗窃的风险; 供应链和插件的不安全引…

windows平台的ffmpeg编译使用

windows平台的ffmpeg编译使用 一、现状 本人使用libgdx开发galGame,发现扩展包gdx-video不支持mp4,不能忍,正好看到官网有支持自定义编译的文档,所以操作一下,自定义编译。本文重点在于操作windows平台,linux平台太简单了。 整个过程包括如下几个步骤。 二、代码下载…

FFMPEG录制远程监控摄像头MP4

手绘效果图 上图是录制功能的HTML前端页面,录制功能和解码视频放在一起。录制功能关键是录制(开始录制按钮)、停止录像按钮。当点击“录制”的时候则会开始录制MP4文件, 当点击停止的时候就会停止录制MP4。经过录制后,则会生成MP4,并放到我的RV1126的/tm…

centos操作系统上传和下载百度网盘内容

探序基因 整理 进入百度网盘官网百度网盘 客户端下载 下载linux的rpm格式的安装包 在linux命令行中输入:rpm -ivh baidunetdisk_4.17.7_x86_64.rpm 出现报错: 错误:依赖检测失败: libXScrnSaver 被 baidunetdisk-4.17.7-1.x8…

Rubick:基于 Electron 的开源插件化桌面效率工具箱

Rubick 是一款基于 Electron 构建的开源桌面工具箱,专为追求高效办公和个性化体验的用户设计。它通过自由集成丰富的插件,让用户能够根据自己的需求打造极致的桌面端效率工具。 软件命名由来Rubick 的名字来源于《DOTA2》中的英雄 Rubick(拉…

ruoyi-vue部署

ruoyi源码类型 Ruoyi源码 编译打包后,直接部署tomcat服务器 Ruoyi-vue 前后端分离版 前端部署到nginx 后端部署到tomcat RuoYi-Cloud 微服务版 RuoYi-app 移动端版 RuoYi-vue 前后端分离版 环境 JDK>=1.8 MySQL >= 5.7 Maven >= 3.0 Node >= 12 Redis…

LLM论文笔记 23: Meta Reasoning for Large Language Models

Arxiv日期:2024.6.17机构:THU / MSRA 关键词 meta-reasoning推理方法prompt engineering 核心结论 1. 提出Meta Reasoning prompting,MRP是一种系统提示方法,能够帮助LLM动态选择最合适的推理方法,从而提升其灵活性和…

【最后203篇系列】015 几种消息队列的思考

背景 队列还是非常重要的中间件,可以帮助我们:提高处理效率、完成更复杂的处理流程 最初,我觉得只要掌握一种消息队列就够了,现在想想挺好笑的。 过去的探索 因为我用python,而rabbitmq比较贴合快速和复杂的数据处…

学习springboot 的自动配置原理

前言 为什么要学习springboot 的自动配置原理? 1学习 自定义成starter 的前提 实际开发中,我们如果定义公共的组件给团队使用,为了让他们使用方便就自定义成starter。而想要学习starter ,就要先了解springboot 的自动配置原理 2 面试需要 了…

排错 -- FISCO BCOS区块链网络 -- 3. 编译智能合约

文章为FISCO BCOS2.0搭建区块链平台中发现的问题与总结,出错原因不唯一 ,解决办法不唯一 目前社区缺少完整,稳定的搭建平台和教程 ,欢迎各位及时补充,如有错误请及时评论纠正! 感谢各位搜索到这里&#…

ffmpeg 添加毫秒时间戳

网上有好多添加时间水印的,默认是到秒,而我需要到毫秒,查了一下,没有找到更好的方案,下面是自己实现的方案,可以显示到毫秒。如果有更好的方案,欢迎讨论 ffmpeg -i video.mp4 -vf "drawte…