Tensorflow2.0笔记 - metrics做损失和准确度信息度量

news2024/11/23 19:09:12

        本笔记主要记录metrics相关的内容,详细内容请参考代码注释,代码本身只使用了Accuracy和Mean。本节的代码基于上篇笔记FashionMnist的代码经过简单修改而来,上篇笔记链接如下:

Tensorflow2.0笔记 - FashionMnist数据集训练-CSDN博客文章浏览阅读339次。本笔记使用FashionMnist数据集,搭建一个5层的神经网络进行训练,并统计测试集的精度。本笔记中FashionMnist数据集是直接下载到本地加载的方式,不涉及用梯子。关于FashionMnist的介绍,请自行百度。https://blog.csdn.net/vivo01/article/details/136921592?spm=1001.2014.3001.5502

#Fashion Mnist数据集本地下载和加载(不用梯子)
#https://blog.csdn.net/scar2016/article/details/115361245 (百度网盘)
#https://blog.csdn.net/weixin_43272781/article/details/110006990 (github)
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics

tf.__version__


#加载fashion mnist数据集
def load_mnist(path, kind='train'):
    import os
    import gzip
    import numpy as np

    """Load MNIST data from `path`"""
    labels_path = os.path.join(path,
                               '%s-labels-idx1-ubyte.gz'
                               % kind)
    images_path = os.path.join(path,
                               '%s-images-idx3-ubyte.gz'
                               % kind)
    with gzip.open(labels_path, 'rb') as lbpath:
        labels = np.frombuffer(lbpath.read(), dtype=np.uint8,
                               offset=8)
    with gzip.open(images_path, 'rb') as imgpath:
        
        images = np.frombuffer(imgpath.read(), dtype=np.uint8,
                               offset=16).reshape(len(labels), 784)
    return images, labels

#预处理数据
def preprocess(x, y):
    x = tf.cast(x, dtype=tf.float32)
    x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.
    y = tf.cast(y, dtype=tf.int32)
    y = tf.convert_to_tensor(y, dtype=tf.int32)
    return x, y
#训练数据
train_data, train_labels = load_mnist("./datasets")
print(train_data.shape, train_labels.shape)
#测试数据
test_data, test_labels = load_mnist("./datasets", "t10k")
print(test_data.shape, test_labels.shape)

batch_size = 128

train_db = tf.data.Dataset.from_tensor_slices((train_data, train_labels))
train_db = train_db.map(preprocess).shuffle(10000).batch(batch_size)

test_db = tf.data.Dataset.from_tensor_slices((test_data, test_labels))
test_db = test_db.map(preprocess).batch(batch_size)

train_db_iter = iter(train_db)
sample = next(train_db_iter)
print('Batch:', sample[0].shape, sample[1].shape)

#定义网络模型
model = Sequential([
    #Layer 1: [b, 784] => [b, 256]
    layers.Dense(256, activation=tf.nn.relu),
    #Layer 2: [b, 256] => [b, 128]
    layers.Dense(128, activation=tf.nn.relu),
    #Layer 3: [b, 128] => [b, 64]
    layers.Dense(64, activation=tf.nn.relu),
    #Layer 4: [b, 64] => [b, 32]
    layers.Dense(32, activation=tf.nn.relu),
    #Layer 5: [b, 32] => [b, 10], 输出类别结果
    layers.Dense(10)
])

#编译网络
model.build(input_shape=[None, 28*28])
model.summary()

#进行训练
total_epoches = 5
learn_rate = 0.01

#Metrics统计
#参考资料:https://zhuanlan.zhihu.com/p/42438077
#1. 新建meter
#acc_meter = metrics.Accuracy()
#loss_meter = metrics.Mean()
#2. 更新状态, update_state()
#loss_meter.update_state(loss)
#acc_meter.update_state(y, pred)
#3.获取结果, result()
#print(step, 'loss:', loss_meter.result().numpy())
#print(step, 'Evaluate Acc:', total_correct/total, acc_meter.result().numpy())
#4.清除度量信息,reset_states()
#loss_meter.reset_states()
#acc_meter.reset_states()


#新建准确度和loss度量对象
acc_meter = metrics.Accuracy()
loss_meter = metrics.Mean()

optimizer = optimizers.Adam(learning_rate = learn_rate)
for epoch in range(total_epoches):
    for step, (x,y) in enumerate(train_db):
        with tf.GradientTape() as tape:
            logits = model(x)
            y_onehot = tf.one_hot(y, depth=10)
            #使用交叉熵作为loss
            loss_ce = tf.reduce_mean(tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True))
            #调用update_state更新loss度量信息
            loss_meter.update_state(loss_ce)
        #计算梯度
        grads = tape.gradient(loss_ce, model.trainable_variables)
        #更新梯度
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        if step % 100 == 0:
            print("Epoch[", epoch, "]: step-", step, "\tloss: ", loss_meter.result().numpy())
            loss_meter.reset_states()

    #使用测试集进行验证
    total_correct = 0
    total_num = 0
    #清除准确度的统计信息
    acc_meter.reset_states()
    for x,y in test_db:
        logits = model(x)
        #使用softmax得到各个类别的概率
        prob = tf.nn.softmax(logits, axis=1)
        #求出概率最大的结果参数位置,作为预测的分类结果
        pred = tf.cast(tf.argmax(prob, axis=1), dtype=tf.int32)
        #比较结果
        correct = tf.equal(pred, y)
        correct = tf.reduce_sum(tf.cast(correct, dtype=tf.int32))
        #计算精度
        total_correct += int(correct)
        total_num += x.shape[0]
        #使用metircs的update_state进行更新
        acc_meter.update_state(y, pred)
    
    acc = total_correct / total_num
    print("Epoch[", epoch, "] Manual Accuracy:", acc, " Metrics Accuracy:", acc_meter.result().numpy())

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1553446.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Codeforces Round 937 (Div. 4)

A. Stair, Peak, or Neither?&#xff08;模拟&#xff09; #include<iostream> using namespace std;int main(){int t;scanf("%d", &t);int a, b, c;while(t--){scanf("%d%d%d", &a, &b, &c);if(a < b && b < c) p…

(免费分享)基于微信小程序电影院购票系统带论文

基于小程序的电影院购票管理系统【含报告】&#xff1a;前端 vue、elementui、小程序&#xff0c;后端 maven、springboot、springmvc、spring、mybatis&#xff0c;角色分为管理员、用户&#xff1b;集成小程序浏览电影&#xff0c;购票等功能于一体的系统。 目录 摘要 I Ab…

fpga 通过axi master读写PS侧DDR的仿真和上板测试

FPGA和ARM数据交互是ZYNQ系统中非常重要的内容。PS提供了供FPGA读写的AXI-HP接口用于两者的高速通信和数据交互。一般的&#xff0c;我们会采用AXI DMA的方式去传输数据&#xff0c;DMA代码基本是是C编写&#xff0c;对于FPGA开发者来说不利于维护和debug。本文提供一种手写AXI…

系统分析师-数学与经济管理

系统架构设计师 系统架构设计师-软件开发模型总结 文章目录 系统架构设计师前言一、最小生成树二、最短路径三、网络与最大流量四、不确定型决策 前言 数学是一种严谨、缜密的科学&#xff0c;学习应用数学知识&#xff0c;可以培养系统架构设计师的抽象思维能力和逻辑推理能…

【WebJs 爬虫】逆向进阶技术必知必会

前言 在数字化时代&#xff0c;网络爬虫已成为一种强大的数据获取工具&#xff0c;广泛应用于市场分析、竞争对手研究、舆情监测等众多领域。爬虫技术能够帮助我们快速、准确地获取网络上的海量信息&#xff0c;为决策提供有力支持。然而&#xff0c;随着网络环境的日益复杂和…

基于vue的MOBA类游戏攻略分享平台的设计与实现|Springboot+Vue+ Mysql+Java+ B/S结构(可运行源码+数据库+设计文档)

本项目包含可运行源码数据库LW&#xff0c;文末可获取本项目的所有资料。 推荐阅读100套最新项目持续更新中..... 2024年计算机毕业论文&#xff08;设计&#xff09;学生选题参考合集推荐收藏&#xff08;包含Springboot、jsp、ssmvue等技术项目合集&#xff09; 目录 1. …

istio 设置 istio-proxy sidecar 的 resource 的 limit 和 request

方式一 修改 configmap 查看当前 sidecar 的 cpu 和 memory 的配额 在 istio-sidecar-injector 中查找&#xff0c;修改后重启 pod 可以生效&#xff08;下面那个 proxy_init 配置不管&#xff0c;不知道是干嘛的&#xff09; 方式二 如果是通过 iop 安装的 istio&#xf…

连同网络的操作次数【并查集】

用以太网线缆将 n 台计算机连接成一个网络&#xff0c;计算机的编号从 0 到 n-1。线缆用 connections 表示&#xff0c;其中 connections[i] [a, b] 连接了计算机 a 和 b。 网络中的任何一台计算机都可以通过网络直接或者间接访问同一个网络中其他任意一台计算机。 给你这个…

那些激励你深入研究技术的语录

遇到耗时高的问题&#xff0c;90%能用重启解决&#xff0c;但是不找到原因&#xff0c;问题一定会再次出现。 代码里的Bug就像房间里的老鼠&#xff0c;你不找到它&#xff0c;它就会一直捣乱。 代码的质量决定了程序的稳定性&#xff0c;而程序的稳定性则决定了业务的成败。 不…

python程序如何工作

随着人工智能时代的来临&#xff0c;python成为了人们学习编程的首先语言。那么&#xff0c;python程序怎么运行的?我们下面来介绍下。 python程序执行原理 我们都知道&#xff0c;使用C&#xff0c;C之类的编译性语言编写的程序&#xff0c;是需要从源文件转换成计算机使用…

SAP FIORI App

什么是SAP FIORI App SAP FIORI是SAP公司推出的一套用户体验(UX)设计系统。 SAP Fiori不仅仅是一个产品&#xff0c;而是一个用户体验和用户交互的标准。它提供了一套基于HTML5的应用程序&#xff0c;这些应用程序可以运行在所有桌面和移动平台上&#xff0c;包括Windows、Ma…

HarmonyOS 应用开发之ExtensionAbility组件

ExtensionAbility组件是基于特定场景&#xff08;例如服务卡片、输入法等&#xff09;提供的应用组件&#xff0c;以便满足更多的使用场景。 每一个具体场景对应一个 ExtensionAbilityType&#xff0c;开发者只能使用&#xff08;包括实现和访问&#xff09;系统已定义的类型。…

「JavaSE」Lambda表达式

&#x1f387;个人主页&#xff1a;Ice_Sugar_7 &#x1f387;所属专栏&#xff1a;快来卷Java啦 &#x1f387;欢迎点赞收藏加关注哦&#xff01; Lambda表达式 &#x1f349;简介&#x1f349;函数式接口&#x1f34c;注解 &#x1f349;语法&#x1f349;Lambda表达式的基本…

基于springboot实现海滨体育馆管理系统项目【项目源码+论文说明】计算机毕业设计

基于springboot实现海滨体育馆管理系统演示 摘要 本基于Spring Boot的海滨体育馆管理系统设计目标是实现海滨体育馆的信息化管理&#xff0c;提高管理效率&#xff0c;使得海滨体育馆管理工作规范化、高效化。 本文重点阐述了海滨体育馆管理系统的开发过程&#xff0c;以实际…

AJAX(一):初识AJAX、http协议、配置环境、发送AJAX请求、请求时的问题

一、什么是AJAX 1.AJAX 就是异步的JS和XML。通过AJAX 可以在浏览器中向服务器发送异步请求&#xff0c;最大的优势&#xff1a;无刷新获取数据。AJAX 不是新的编程语言&#xff0c;而是一种将现有的标准组合在一起使用的新方式。 2.XML 可扩展标记语言。XML被设计用来传输和…

基于Redis实现短信登录

1.7.3、整体访问流程 当注册完成后&#xff0c;用户去登录会去校验用户提交的手机号和验证码&#xff0c;是否一致&#xff0c;如果一致&#xff0c;则根据手机号查询用户信息&#xff0c;不存在则新建&#xff0c;最后将用户数据保存到redis&#xff0c;并且生成token作为red…

牛客小白月赛89(A,B,C,D,E,F)

比赛链接 官方视频讲解&#xff08;个人觉得讲的还是不错的&#xff09; 这把BC偏难&#xff0c;差点就不想做了&#xff0c;对小白杀伤力比较大。后面的题还算正常点。 A 伊甸之花 思路&#xff1a; 发现如果这个序列中最大值不为 k k k&#xff0c;我们可以把序列所有数…

Vivado Lab Edition

Vivado Lab Edition 是完整版 Vivado Design Suite 的独立安装版本 &#xff0c; 包含在生成比特流后对赛灵思 FPGA 进行编程和 调试所需的所有功能。通常适用于在如下实验室环境内进行编程和调试&#xff1a; 实验室环境中的机器所含磁盘空间、内存和连 接资源较少。Vivad…

B2902A是德科技B2902A精密型电源

181/2461/8938产品概述&#xff1a; Agilent B2902A 精密源/测量单元 (SMU) 是一款 2 通道、紧凑且经济高效的台式 SMU&#xff0c;能够源和测量电压和电流。它用途广泛&#xff0c;可以轻松、高精度地执行 I/V&#xff08;电流与电压&#xff09;测量。4 象限源和测量功能的集…

Android room 在dao中不能使用挂起suspend 否则会报错

错误&#xff1a; Type of the parameter must be a class annotated with Entity or a collection/array of it. kotlin.coroutines.Continuation<? super kotlin.Unit> $completion); 首先大家检查一下几个点 一、kotlin-kapt 二、 是否引入了 room-ktx 我是2024年…