Tensorflow2.0笔记 - 使用compile,fit,evaluate,predict简化流程

news2024/9/30 23:19:36

        本笔记主要用compile, fit, evalutate和predict来简化整体代码,使用这些高层API可以减少很多重复代码。具体内容请自行百度,本笔记基于FashionMnist的训练笔记,原始笔记如下:

Tensorflow2.0笔记 - FashionMnist数据集训练-CSDN博客文章浏览阅读347次。本笔记使用FashionMnist数据集,搭建一个5层的神经网络进行训练,并统计测试集的精度。本笔记中FashionMnist数据集是直接下载到本地加载的方式,不涉及用梯子。关于FashionMnist的介绍,请自行百度。https://blog.csdn.net/vivo01/article/details/136921592?spm=1001.2014.3001.5502        

#Fashion Mnist数据集本地下载和加载(不用梯子)
#https://blog.csdn.net/scar2016/article/details/115361245 (百度网盘)
#https://blog.csdn.net/weixin_43272781/article/details/110006990 (github)
import time
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics

tf.__version__

#加载fashion mnist数据集
def load_mnist(path, kind='train'):
    import os
    import gzip
    import numpy as np

    """Load MNIST data from `path`"""
    labels_path = os.path.join(path,
                               '%s-labels-idx1-ubyte.gz'
                               % kind)
    images_path = os.path.join(path,
                               '%s-images-idx3-ubyte.gz'
                               % kind)
    with gzip.open(labels_path, 'rb') as lbpath:
        labels = np.frombuffer(lbpath.read(), dtype=np.uint8,
                               offset=8)
    with gzip.open(images_path, 'rb') as imgpath:
        
        images = np.frombuffer(imgpath.read(), dtype=np.uint8,
                               offset=16).reshape(len(labels), 784)
    return images, labels

#预处理数据
def preprocess(x, y):
    x = tf.cast(x, dtype=tf.float32)
    x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.
    y = tf.cast(y, dtype=tf.int32)
    y = tf.convert_to_tensor(y, dtype=tf.int32)
    return x, y
#训练数据
train_data, train_labels = load_mnist("./datasets")
print(train_data.shape, train_labels.shape)
#测试数据
test_data, test_labels = load_mnist("./datasets", "t10k")
print(test_data.shape, test_labels.shape)

batch_size = 128

train_db = tf.data.Dataset.from_tensor_slices((train_data, train_labels))
train_db = train_db.map(preprocess).shuffle(10000).batch(batch_size)

test_db = tf.data.Dataset.from_tensor_slices((test_data, test_labels))
test_db = test_db.map(preprocess).batch(batch_size)

train_db_iter = iter(train_db)
sample = next(train_db_iter)
print('Batch:', sample[0].shape, sample[1].shape)

test_db_iter = iter(test_db)
sample = next(test_db_iter)
print(sample[0].shape)
print(sample[1].shape)


#定义网络模型
model = Sequential([
    #Layer 1: [b, 784] => [b, 256]
    layers.Dense(256, activation=tf.nn.relu),
    #Layer 2: [b, 256] => [b, 128]
    layers.Dense(128, activation=tf.nn.relu),
    #Layer 3: [b, 128] => [b, 64]
    layers.Dense(64, activation=tf.nn.relu),
    #Layer 4: [b, 64] => [b, 32]
    layers.Dense(32, activation=tf.nn.relu),
    #Layer 5: [b, 32] => [b, 10], 输出类别结果
    layers.Dense(10)
])
print("\n=====Building Model=========\n")
model.build(input_shape=(None, 28*28))
model.summary()

total_epoches = 10
learn_rate = 0.01
#编译网络,使用compile(),传入优化器,损失函数和metrics度量
#https://blog.csdn.net/weixin_48169169/article/details/120793534
print("\n=====Compiling Model=========\n")
model.compile(optimizer=optimizers.Adam(learning_rate=learn_rate),
             loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True),
             metrics=['Accuracy'])
#使用fit()进行训练
print("\n=====Fitting Model=========\n")
model.fit(train_db, epochs=total_epoches, validation_data=test_db, validation_freq=2)
time.sleep(1)
#使用evaluate()进行模型验证,这里使用了test_db,实际中可以使用另外的数据集
print("\n=====Evaluating Model=========\n")
model.evaluate(test_db)
#使用predict()做数据预测
sample = next(iter(test_db))
#获得数据和标签
real_data = sample[0]
real_label = sample[1]
pred = model.predict(real_data)
pred = tf.argmax(pred, axis=1)
print("Predicted Labels:", pred.numpy())
print("Actual Labels:", real_label.numpy())

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1549997.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

六、保持长期高效的七个法则(二)Rules for Staying Productive Long-Term(2)

Rule #5 - If your work changes, your system should too. 准则五:如果你的工作变了,你的系统也应该改变。 For some, work will be consistent enough to not need major changes.You simply stick to the same system and you’ll get the results y…

PL/SQL概述

oracle从入门到总裁:​​​​​​https://blog.csdn.net/weixin_67859959/article/details/135209645 PL/SQL概述 PL/SQL(Procedural Language extension to SQL)是 Oracle 对标准 SQL语言的扩充,是专门用于各种环境下对 Oracle 数据库进行访问和开发的语言。 由…

HWOD:对n个字符串按照字典序排序

一、知识点 1、pow函数 引用头文件math.h 求x的y次方 2、链接数学库 math.h头文件对应的库名称是libm sudo find / -name libm.so -print ls /usr/lib/x86_64-linux-gnu/ 链接命令:gcc xxx.c -L. -lm 3、52进制 A的ASCII码是65,Z的ASCII…

Ubuntu 配置 kubernetes 学习环境,让外部访问 dashboard

Ubuntu 配置 kubernetes 学习环境 一、安装 1. minikube 首先下载一下 minikube,这是一个单机版的 k8s,只需要有容器环境就可以轻松启动和学习 k8s。 首先你需要有Docker、QEMU、Hyperkit等其中之一的容器环境,以下使用 docker 进行。 对…

OpenCV模块熟悉:点云处理相关

1. 显示--VIZ 曾经基于PCL 做过不少点云相关的开发,采样VTK进行有点云显示。后来基于OpenCV做了不少三维重建工作,总是将点云保存下来,然后借助CloudCompare等查看结果。如果能够将VIZ编译进来,预计会提升开发速度。 …

aws 入门篇 02.区域和可用区

aws入门篇 02.Region和AZ 02.区域和可用区 区域(Region):us-east-1:美东1区可用区(Availability Zones) AWS的区域遍布世界各地 一个区域(Region)是由多个可用区(AZ&am…

从根本上优雅地解决 VSCode 中的 Python 模块导入问题

整体概述: 在我尝试运行 test_deal_file.py 时,我遇到了一个 ModuleNotFoundError 错误,Python告诉我找不到名为 controllers 的模块。这意味着我无法从 deal_file.py 中导入 read_excel 函数。 为了解决这个问题,我尝试了几种方法…

Sublime for Mac 使用插件Terminus

1. 快捷键打开命令面板 commandshiftp2. 选择 Package Control: Install Package,然后会出现安装包的列表 3. 在安装终端插件前,我们先装个汉化包,ChineseLocallization,安装完重启 4. 输入 terminus,选择第一个&am…

Pillow教程06:将图片中出现的黄色和红色,改成绿色

---------------Pillow教程集合--------------- Python项目18:使用Pillow模块,随机生成4位数的图片验证码 Python教程93:初识Pillow模块(创建Image对象查看属性图片的保存与缩放) Pillow教程02:图片的裁…

nvm安装以后,node -v npm 等命令提示不是内部或外部命令

因为有vue2和vue3项目多种,所以为了适应各类版本node,使用nvm管理多种node版本,但是当我按教程安装nvm以后,nvm安装以后,node -v npm 等命令提示不是内部或外部命令 首先nvm官网网址:https://github.com/coreybutler/…

iOS - Runtime-消息机制-objc_msgSend()

iOS - Runtime-消息机制-objc_msgSend() 前言 本章主要介绍消息机制-objc_msgSend的执行流程,分为消息发送、动态方法解析、消息转发三个阶段,每个阶段可以做什么。还介绍了super的本质是什么,如何调用的 1. objc_msgSend执行流程 OC中的…

接口自动化之 + Jenkins + Allure报告生成 + 企微消息通知推送

接口自动化之 Jenkins Allure报告生成 企微消息通知推送 在jenkins上部署好项目,构建成功后,希望可以把生成的报告,以及结果统计发送至企微。 效果图: 实现如下。 1、生成allure报告 a. 首先在Jenkins插件管理中&#x…

sqlite跨数据库复制表

1.方法1 要将 SQLite 数据库中的一个表复制到另一个数据库,您可以按照以下步骤操作: 备份原始表的SQL定义和数据: 使用 sqlite3 命令行工具或任何SQLite图形界面工具,您可以执行以下SQL命令来导出表的SQL定义和数据&#xff1a…

libVLC 视频抓图

Windows操作系统提供了多种便捷的截图方式,常见的有以下几种: 全屏截图:通过按下PrtSc键(Print Screen),可以截取整个屏幕的内容。截取的图像会保存在剪贴板中,可以通过CtrlV粘贴到图片编辑工具…

Machine Learning机器学习之K近邻算法(K-Nearest Neighbors,KNN)

目录 前言 背景介绍: 思想: 原理: KNN算法关键问题 一、构建KNN算法 总结: 博主介绍:✌专注于前后端、机器学习、人工智能应用领域开发的优质创作者、秉着互联网精神开源贡献精神,答疑解惑、坚持优质作品共…

探索MongoDB:发展历程、优势与应用场景

一、MongoDB简介 MongoDB 始于 2007 年,由 Dwight Merriman、Eliot Horowitz 和 Kevin Ryan –(DoubleClick 的主理团队)共同创立。 DoubleClick 是一家互联网广告公司(现隶属于 Google),公司团队开发并利…

前端埋点全解及埋点SDK实现方式

一、什么是埋点 所谓“埋点”,是数据采集领域(尤其是用户行为数据采集领域)的术语,指的是针对特定用户行为或事件进行捕获、处理和发送的相关技术及其实施过程。比如用户某个icon点击次数、观看某个视频的时长等等。 埋点…

python-pytorch获取FashionMNIST实际图片标签数据集

在查看pytorch官方文档的时候,在这里链接中https://pytorch.org/tutorials/beginner/basics/data_tutorial.html的Creating a Custom Dataset for your files章节,有提到要自定义数据集,需要用到实际的图片和标签。 在网上找了半天没找到&a…

Jenkins用户角色权限管理

Jenkins作为一款强大的自动化构建与持续集成工具,用户角色权限管理是其功能体系中不可或缺的一环。有效的权限管理能确保项目的安全稳定,避免敏感信息泄露。 1、安装插件:Role-based Authorization Strategy 系统管理 > 插件管理 > 可…

java入门学习Day01

本篇文章主要是学会如何使用IDEA,和运行第一个java文件。 java环境安装:Windows下Java环境配置教程_windows java环境配置-CSDN博客 IDEA安装:IDEA 2023.2.5 最新激活码,注册码(亲测好用) - 异常教程 以上两个链接…