t-SNE进行分类可视化

news2025/1/18 17:15:05

0、引入

我们在论文中通常可以看到下图这样的可视化效果,这就是使用t-SNE降维方法进行的可视化,当然除了t-SNE还有其他的比如PCA等降维等方法,关于这些算法的原理有很多文章可以借阅,这里不展开阐释,重点讲讲如何进行可视化。
在这里插入图片描述

1、基本原理

上面的图中一个点就是一个样本,我们需要明白的是一个样本用两个数值表示(x和y坐标),意味着原来高维的样本被降维到低维(2维)的空间中了。

比如在将一个样本图片输入到VGG网络中,在倒数第二了全连接层有4096个神经元,也就是该样本使用了4096维的向量表示。我们获取到这个向量表示后通过t-SNE进行降维,得到2维的向量表示,我们就可以在平面图中画出该点的位置。

我们清楚同一类的样本,它们的4096维向量是有相似性的,并且降维到2维后也是具有相似性的,所以在2维平面上面它们会倾向聚拢在一起。

可视化的过程中,大概步骤就是:

  • 对于500张图片(样本)进行模型推理,获取倒数第二层的4096特征向量,同时获取到标签,至此我们有两个值data_embed=[500,4096],label=[500,]
  • 使用t-SNE降维,将data_embed降维后,每个样本用2维表示,即data_embed=[500,2]
  • 使用plt将data_embed共500个样本绘制,并且同一类的颜色一致。

2、收集模型中的高维向量表示和标签

这一步主要是收集每个样本的高维特征及标签,高维特征是全连接层的输出,所以需要通过推理模型获取到(需要修改模型使同时输出全连接层输出)。

data_embed_collect=[]
label_collect=[]

for ......
	# inputs.shape=[BS,C,H,W]
	# embed_4096.shape=[BS,4096]
	# output.shape=[BS,1000]
	output,embed_4096=model(inputs)
	
	
	data_embed_collect.append(embed_4096)
	label_collect.append(label)
	......


# data_embed_collect.shape=[iters,BS,4096]
# label_collect.shape=[iters,BS,]

# 在这里,所有样本的4096特征都收集了,并且每个样本的标签也收集了
# data_embed_npy.shape=[samples,4096]
# label_npu.shape=[samples,]
data_embed_npy=torch.cat(data_embed_collect,axis=0).cpu().numpy()
label_npu=torch.cat(label_collect,axis=0).cpu().numpy()

np.save("data_embed_npy.npy",data_embed_npy)
np.save("label_npu.npy",label_npu).

3、进行t-SNE降维并可视化

代码也简单,首先调用t-SNE进行降维降到2维,然后使用plt将2维定位坐标进行绘制,代码如下:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE


def get_fer_data(data_path="vis_fer_data.npy",
                 label_path="vis_fer_label.npy"):
    """
	该函数读取上一步保存的两个npy文件,返回data和label数据
    Args:
        data_path:
        label_path:

    Returns:
        data: 样本特征数据,shape=(BS,embed)
        label: 样本标签数据,shape=(BS,)
        n_samples :样本个数
        n_features:样本的特征维度

    """
    data = np.load(data_path)
    label = np.load(label_path)
    n_samples, n_features = data.shape

    return data, label, n_samples, n_features

color_map = ['r','y','k','g','b','m','c'] # 7个类,准备7种颜色
def plot_embedding_2D(data, label, title):
	"""
	
	"""
    x_min, x_max = np.min(data, 0), np.max(data, 0)
    data = (data - x_min) / (x_max - x_min)
    fig = plt.figure()
    for i in range(data.shape[0]):
        plt.plot(data[i, 0], data[i, 1],marker='o',markersize=1,color=color_map[label[i]])
    plt.xticks([])
    plt.yticks([])
    plt.title(title)
    return fig


def main():
    data, label, n_samples, n_features = get_fer_data()  # 根据自己的路径合理更改

    print('Begining......') 	

	# 调用t-SNE对高维的data进行降维,得到的2维的result_2D,shape=(samples,2)
    tsne_2D = TSNE(n_components=2, init='pca', random_state=0) 
    result_2D = tsne_2D.fit_transform(data)
    
    print('Finished......')
    fig1 = plot_embedding_2D(result_2D, label, 't-SNE')	# 将二维数据用plt绘制出来
    fig1.show()
    plt.pause(50)
    
if __name__ == '__main__':
    main()

7个类,共计3589个样本降维到2维后的效果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/415003.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

window10 更新提示 0x80073712错误

解决方法: 1、可以尝试重新配置一下 Windows 更新服务状态: 2、Win S打开搜索,输入 CMD 找到 “命令提示符”, 3、右键以管理员身份打开,依次输入以下代码,并按回车执行。注:是一条一条的执行…

vue基础知识

1、特点 1.采用组件化模式,提高代码复用率、且让代码更好维护。 2.声明式编码,让编码人员无需直接操作DOM,提高开发效率 命令式编码 3.使用虚拟DOM优秀的Diff 算法,尽量复用DOM节点。 2、hello vue vue的引入 就是写在引入c…

关键词词库制作-搜索词分析工具

关键词词库制作 关键词词库是一种帮助SEO和SEM优化的工具,它可以帮助您确定关键词的流行程度、竞争程度、搜索意图和其他相关信息等等。以下是一些关键词词库制作的方法: 收集关键词:首先需要收集相关的关键词,这可能涉及到您的业…

一文讲透产品经理如何用好ChatGPT

作者:京东零售 何雨航 “4.0版本的ChatGPT可以有效提升产品经理工作效率,但并无法替代产品经理的角色。” 一、引言 3月15日,OpenAI发布了最新的基于GPT-4的ChatGPT,关于其智能性的讨论热度在互联网上空前高涨。 我之前体验过3…

基于POSIX的消息队列的发送、接收demo的设计(linux)

本文介绍POSIX的消息队列的linux应用,新建两个进程(一个发送进程、一个接收进程)实现消息形式的数据传输。POSIX消息队列与SystemV消息队列存在相似的消息传输单位,但较SystemV消息队列更适合linux系统的使用。本文在ubuntu20.4上…

面试篇-深入理解 Java 中的 HashMap 实现原理

一、HashMap实现原理 HashMap 的实现主要包括两个部分:哈希函数和解决哈希冲突的方法。 1.哈希函数 当使用 put() 方法将键值对存储在 HashMap 中时,首先需要计算键的哈希值。HashMap 使用 hashCode() 方法获取键的哈希值,并将其转换为桶&…

Docker的常见命令

前言:使用Docker得学会的几个常见命令 常见命令前置学习: docker --help这个命令必须得会因为,很多命令是记不住的,得使用他们的官方help下面是一些实例 docker load --help常见命令集合: 一: docker images #查看全部镜像 docker rmi #删除某个镜像(例如:docker rmi redis…

Vue3——组件间通信的五种常用方式

Vue3组件间通信的五种常用方式 写在前面 本文采用<script setup>语法糖的编写方式&#xff0c;比options API更自由。 <script setup>语法糖详细内容看查看文档&#xff1a;setup语法糖官方文档 然后我们会讲以下五种常用的组件通信方式 propsemitv-modelrefs…

高速数字信号VS射频信号,到底哪个更难设计?

一博高速先生成员&#xff1a;黄刚熟悉高速先生的小伙伴们会知道&#xff0c;我们是以研究高速数字信号为主的团队&#xff0c;从不到1G到目前在研究的112G&#xff0c;高速先生就这样一直研究过来的&#xff0c;分享的案例也大多是以高速数字信号为主的案例。最近受到我们粉丝…

golang for range 令人抓狂的面试题

1.下面这段代码能否正常结束&#xff1f; func main() {v : []int{1, 2, 3}for i : range v {v append(v, i)} } 答案&#xff1a;正常结束。 可能我们会以为程序会陷入死循环。 但是我们要明白 for range 中的v其实就是复制了一份前面定义的v切片&#xff0c;不论前面定…

Python从入门到精通第3天(循环结构的使用)

循环结构for-in循环while循环break和continue关键字练习在写程序的时候&#xff0c;一定会遇到需要重复执行某条或某些指令的场景&#xff0c;例如用程序控制机器人踢足球&#xff0c;如果机器人持球而且还没有进射门范围&#xff0c;那么我们就要一直发出让机器人向球门方向移…

免费ChatGPT接入-国内怎么玩chatGPT

免费ChatGPT中文版 OpenAI 的 GPT 模型目前并不提供中文版的免费使用&#xff0c;但是有许多机器学习平台和第三方服务提供商也提供了基于 GPT 技术的中文版模型和 API。下面是一些常见的免费中文版 ChatGPT&#xff1a; Hugging Face&#xff1a;Hugging Face 是一个开源社区…

JAVAWeb03-JavaScript

1. JavaScript 1.1 概述 1.1.1 官方文档 地址: https://www.w3school.com.cn/js/index.asp 1.1.2 基本说明 JavaScript 能改变 HTML 内容&#xff0c;能改变 HTML 属性&#xff0c;能改变 HTML 样式 (CSS)&#xff0c;能完成页面的数据验证。 js演示1.html 需要把图片拷贝…

一个注解实现WebSocket集群方案,别提有多优雅了

WebSocket大家应该是再熟悉不过了&#xff0c;如果是单体应用确实不会有什么问题&#xff0c;但是当我们的项目使用微服务架构时&#xff0c;就可能会存在问题 比如服务A有两个实例A1和A2&#xff0c;前端的WebSocket客户端C通过网关的负载均衡连到了A1&#xff0c;这个时候当…

【Java数据结构】线性表-队列

线性表-队列概念队列的使用队列模拟实现循环队列如何区分空与满双端队列 (Deque)概念 队列&#xff1a;只允许在一端进行插入数据操作&#xff0c;在另一端进行删除数据操作的特殊线性表&#xff0c;队列具有先进先出FIFO(FirstIn First Out) 入队列&#xff1a;进行插入操作的…

文章生成器写出来的原创文章

文章生成机器人 文章生成机器人是一种基于人工智能技术和自然语言处理算法的程序&#xff0c;可以自动地生成高质量、原创的文章。 文章生成机器人的优点如下&#xff1a; 提高工作效率&#xff1a;文章生成机器人能够在较短的时间内自动帮助用户生成大量的文章&#xff0c;提…

GaussDB工作级开发者认证—第三章开发设计建议

一. 数据库对象命名和设计建议 二. 表设计最佳实践 三. SQL查询最佳实践 SQL 最佳实践 - SELECT 避免对大字段执行order by&#xff0c;group by等引起排序的操作避免频繁使用count()获取大表行数慎用通配符字段 “*”避免在select目标列中使用子查询统计表中所有记录数时&…

设计模式之策略模式(C++)

作者&#xff1a;翟天保Steven 版权声明&#xff1a;著作权归作者所有&#xff0c;商业转载请联系作者获得授权&#xff0c;非商业转载请注明出处 一、策略模式是什么&#xff1f; 策略模式是一种行为型的软件设计模式&#xff0c;针对某个行为&#xff0c;在不同的应用场景下&…

win下配置pytorch3d

一、配置好的环境&#xff1a;py 3.9 pytorch 1.8.0 cuda 11.1_cudnn 8_0 pytorch3d 0.6.0 CUB 1.11.0 你可能觉得pytorch3d 0.6.0版本有点低&#xff0c;但是折腾不如先配上用了&#xff0c;以后有需要再说。 &#xff08;后话&#xff1a;py 3.9 pytorch 1.12.1 cuda …

Log4j日志

log4j日志简介组成Logger 日志记录器Appender 日志目的地&#xff08;Windows下的路径分隔符&#xff09;※Layout 日志信息布局layout 指定输出的样式模板&#xff1f;layout.ConversionPattern 指定输出的每项内容及其格式顺序日志信息等级/优先级使用的Log4j的jar包代码示例…