Tensorflow2.0笔记 - tensor排序操作

news2025/1/4 19:46:15

        本笔记主要记录sort,argsort,以及top_k操作,加上一个求Top K准确度的例子。

import tensorflow as tf
import numpy as np

tf.__version__


#sort,argsort

#对1维的tensor进行排序
tensor = tf.random.shuffle(tf.range(10))
print(tensor)
#升序
print("======tf.sort(direction='ASCENDING'):", tf.sort(tensor, direction='ASCENDING'))
#降序
print("======tf.sort(direction='DESCENDING'):", tf.sort(tensor, direction='DESCENDING'))
#argsort,返回排序后元素对应原始数据元素的index
print("======tf.argsort(direction='DESCENDING'):", tf.argsort(tensor, direction='DESCENDING'))
args = tf.argsort(tensor, direction='DESCENDING')
print("======Max element:", tensor[args[0]])

#多维tensor排序
tensor = tf.random.uniform([3,3], maxval=10, dtype=tf.int32)
print(tensor)

#不带参数,默认升序
print("======tf.sort():", tf.sort(tensor))
#降序
print("======tf.sort(direction='DESCENDING'):", tf.sort(tensor, direction='DESCENDING'))
#argsort
print("======tf.argsort(direction='DESCENDING'):", tf.argsort(tensor, direction='DESCENDING'))

#top_k得到前最大/最小值
tensor = tf.random.uniform([3,3], maxval=10, dtype=tf.int32)
print(tensor)
#top_k返回值主要有indices和values
#indices返回top k个元素的下标数据
#values返回top k个元素的值
#得到最大的前两个元素
topN = tf.math.top_k(tensor, 2)
print("=====Top 2 indices:", topN.indices)
print("=====Top 2 values :", topN.values)

#top-k accuracy
#假设下面的tensor表示各个类别的预测概率信息,真实的标签类别是2(下标)
# tensor = tf.convert_to_tensor([0.1, 0.2, 0.3, 0.4])
# 那么top-1是0.4,对应标签是3,真实标签类别是2,预测错误,top-1预测准确率是0%
# top-2表示返回前两个最有可能的值[0.4,0.3],对应标签是[3,2],top-2预测准确率100%
# 同理,top-3预测准确率100%

#举例说明
#假设下面的tensor为两个样本的预测结果
prob = tf.constant([[0.1, 0.2, 0.7], [0.2, 0.65, 0.15]])
print("=====Probabilities:", prob)
#标签信息,第一个样本真实类别是2, 第二个样本真实类别是0
target = tf.constant([2, 0])

#使用top_k获得预测结果的indices,这个结果就是对应的类别信息
predictedClasses = tf.math.top_k(prob, 3).indices
predictedClasses = tf.transpose(predictedClasses, perm=[1, 0])
#转置后的矩阵,第一行表示两个个样本top 1的预测值(最有可能的类别),第二行表示top 2的预测值(第二可能的类别)
print(predictedClasses)
#将真实值broadcast_to一个3*2的矩阵(1x2 => 3x2)
target = tf.broadcast_to(target, [3,2])
print(target)

#接下来就可以对比preditecdClasses和target
#Predicted       Actual
#[2, 1]          [2,0]   => top1准确度: 1/2 = 50%
#[1, 0]          [2,0]   => top2准确度: 
#                           样本1(第一列前两个元素)和真实的target里有一个能对上,预测正确,计数1
#                           样本2(第2列前两个元素)和真实target的类别有一个能对上,预测正确,计数1
#                           最终结果是: 1+1 / 2(总样本数) = 100%
#[2, 1]          [2,0]   => top3准确度: 100%


#实例,返回topk的准确率函数
#output: 网络输出的预测概率结果,[b, N],batchsize个预测值
#target: 真实的类别,[b]
#topk: 表示要返回哪些topk结果,假设topk = [1, 2, 3],表示要返回top1, top2和top3三个准确度结果
def topKAccuracy(output, target, topk=(1,)):
    maxk = max(topk)
    batchSize = target.shape[0]

    pred = tf.math.top_k(output, maxk).indices
    pred = tf.transpose(pred, perm=[1, 0])
    real = tf.broadcast_to(target, pred.shape)
    #转换为0(False),1(True)二值表示的结果
    correct = tf.equal(pred, real) #correct是一个[k, b]大小的tensor

    result = []
    for k in topk:
        #取出前k行求和除以样本数量
        #取出前k行用reshape进行flatten
        correct_k = tf.cast(tf.reshape(correct[:k], [-1]), dtype=tf.float32)
        #求和
        correct_k = tf.reduce_sum(correct_k)
        accuracy = float(correct_k / batchSize)
        result.append(accuracy)
    return result

#模拟一个10个样本,6个类别的预测结果
output = tf.random.normal([10, 6])
print("=====>Original Output:\n", output.numpy())
#softmax处理,让指定axis的数据转换成元素相加结果为1的数据(概率)
output = tf.math.softmax(output, axis=1)
print("=====>Probability(Softmax Output):\n", output.numpy())
print("=====>Argmax:\n", tf.argmax(output, axis=1).numpy())
#模拟一个真实类别信息,10个样本的真实标签
target = tf.random.uniform([10], maxval=6, dtype=tf.int32)
print("=====>Labels:\n", target.numpy())

accuracies = topKAccuracy(output, target, topk=(1,2,3,4,5,6))
print("Top1 - Top6 Accuracy:\n", accuracies)

        运行结果:

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1412873.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

晨曦记账本,你的账目整理专家

传统的账目明细记录和整理方式不仅费时费力,还容易出错。在数字化时代,【晨曦记账本】为你提供了一个全新的账目明细整理模式。它不仅能让你的账目明细变得井井有条,更能让你在管理财务的过程中得心应手。 所需工具: 一个【晨曦…

Ubuntu18编译jdk8源码

环境 系统 ubuntu18 Linux ubuntu 5.4.0-150-generic #167~18.04.1-Ubuntu SMP Wed May 24 00:51:42 UTC 2023 x86_64 x86_64 x86_64 GNU/Linux jdk源码openjdk-8u41-src-b04-14_jan_2020.zip bootJdk jdk-8u391-linux-x64.tar.gz ps -e|grep ssh sudo apt-get install ssh…

BabylonJS 6.0文档 Deep Dive 摄像机(六):遮罩层和多相机纹理

1. 使用遮罩层来处理多个摄影机和多网格物体 LayerMask是分配给每个网格(Mesh)和摄像机(Camera)的一个数。它用于位(bit)级别用来指示灯光和摄影机是否应照射或显示网格物体。默认值为0x0FFFFFFF&#xff…

语义分割(2) :自定义Dataset和Dataloader

文章目录 1. 数据处理1.1 标签转换(json2mask和json2yolo)1.1.1 json2mask1.1.2 json2yolo 1.2 划分数据集1.2 不规范的标签图片处理1.3 批量修改图片后缀 2 自定义Dataset 和 Dataloader2.1 自定义Dataset2.1.1 数据增强(1) 对图像进行缩放并且进行长和宽的扭曲(2) 随机翻转图…

【C++中STL】map/multimap容器

map/multimap容器 map基本概念map构造和赋值map的大小和交换map插入和删除map的查找和统计 map排序 map基本概念 map中的所有元素都是pair对组,高效率,pair中的第一个元素为key(键值),起到索引作用,第二个…

仅使用 Python 创建的 Web 应用程序(前端版本)第07章_商品列表

在本章中,我们将实现一个产品列表页面。 完成后的图像如下 创建过程与User相同,流程如下。 No分类内容1Model创建继承BaseDataModel的数据类Item2MockDB创建产品表并生成/添加虚拟数据3Service创建一个 ItemAPIClient4Page定义PageId并创建继承自BasePage的页面类5Applicati…

K8s-持久化(持久卷,卷申明,StorageClass,StatefulSet持久化)

POD 卷挂载 apiVersion: v1 kind: Pod metadata:name: random-number spec:containers:- image: alpinename: alpinecommand: ["/bin/sh","-c"]args: ["shuf -i 0-100 -n 1 >> /opt/number.out;"]volumeMounts:- mountPath: /optname: da…

Halcon指定区域的形状匹配

Halcon指定区域的形状匹配 文章目录 Halcon指定区域的形状匹配1.在参考图像中选择目标2.创建模板3.搜索目标 在这个实例中,会介绍如何根据选定的ROI选择合适的图像金字塔参数,创建包含这个区域的形状模板,并进行精确的基于形状模板的匹配。最…

08 BGP 华为官方文档 十一条选路原则

BGP 华为官方文档 十一条选路原则 丢弃下一跳不可达的路由 1)比较“协议首选值-pref-val”属性,数值越大越好,默认值是0,只在本设备生效,不在网络中传递 2)比较“本地优先级-local_pref”属性,…

初识数据库

数据库技术的基础术语 在学习数据库技术之前,我们先认识与该技术密切相关的基本术语,分别是数据库 (Database, DB)、数据库管理系统(Database Management System, DBMS)和数据库 系统(Database System, DBS),具体介绍如下。 1. 数据库 数…

win10安装redis并配置加自启动(采用官方推荐unix子系统)

记录,为啥有msi安装包,还这么麻烦的用linux版本redis的安装方式,是因为从github上下载别人制作的msi报毒,还不止一处,这种链接数据库的东西,用别人加工过的,都报毒了还用就是傻逼了。 所以采用…

13.while条件循环语句 (4)

while条件循环语句是一种让脚本根据某些条件来重复执行命令的语句,它的循环结构往往在执行前并不确定最终执行的次数,完全不同于for循环语句中有目标、有范围的使用场景。 while循环语句通过判断条件测试的真假来决定是否继续执行命令,若条件…

IP 地址如何进行动态分配?

由于 IP 地址资源的有限性,大部分用户上网都是使用动态 IP 地址,而不是静态 IP 地址。动态 IP 地址指的是在需要的时候才进行 IP 地址分配的方式,而静态 IP 地址是固定分配一个 IP 地址,每次都用这一个地址。因此,IP 地…

AI PC的引擎 – 英特尔第 14 代处理器Meteor Lake架构分析

英特尔从2023年12月开始在笔记本电脑中发售首款 Meteor Lake 第 14 代 Core Ultra 芯片,开启新的“AI PC”时代。这款芯片采用了全新的架构,将CPU分为四块“瓷砖”(tiles):CPU Tile,SoC Tile,Gr…

Python学习从0到1 day9 Python函数

苦难是花开的伏笔 ——24.1.25 函数 1.定义 函数:是组织好的,可重复使用的,用来实现特定功能的代码段 2.案例 在pycharm中完成一个案例需求:不使用内置函数len(),完成字符串长度的计算 #统计字…

海外云手机三大优势

在全球化潮流下,企业因业务需求对海外手机卡等设备的需求不断攀升,推动了海外云手机业务的蓬勃发展。相较于自行置备手机设备,海外云手机不仅能够降低成本,还具备诸多优势,让我们深入探讨其中的三大黄金优势。 经济实惠…

Vulnhub靶机:FunBox 9

一、介绍 运行环境:Virtualbox 攻击机:kali(10.0.2.15) 靶机:FunBox 9(10.0.2.37) 目标:获取靶机root权限和flag 靶机下载地址:https://www.vulnhub.com/entry/funb…

NGINX如何实现rtmp推流服务

最近直播大火,直播推流软件遍地开花,那么用NGINX如何进行推流呢?下面我们就简单的介绍一下用NGINX的rtmp模块如何实现视频推流,我们主要从一下几点介绍: 推流拉流推流认证拉流认证 package mainimport ("fmt&qu…

Elasticsearch基础篇(八):常用查询以及使用Java Api Client进行检索

ES常用查询以及使用Java Api Client进行检索 1. 检索需求 参照豆瓣阅读的列表页面 需求: 检索词需要在数据库中的题名、作者和摘要字段进行检索并进行高亮标红返回的检索结果需要根据综合、热度最高、最近更新、销量最高、好评最多进行排序分页数量为10&#xf…

vscode 代码格式化很短就换行,以及缩放设置

安装vetur 打开vscode设置settings.json { "editor.tabSize": 2,//缩进单位"vetur.format.defaultFormatter.html": "js-beautify-html","vetur.format.defaultFormatterOptions": {"js-beautify-html": {"wrap_line…