Tensorflow2.0笔记 - FashionMnist数据集训练

news2025/1/11 17:58:28

        本笔记使用FashionMnist数据集,搭建一个5层的神经网络进行训练,并统计测试集的精度。

        本笔记中FashionMnist数据集是直接下载到本地加载的方式,不涉及用梯子。

        关于FashionMnist的介绍,请自行百度。

        

#Fashion Mnist数据集本地下载和加载(不用梯子)
#https://blog.csdn.net/scar2016/article/details/115361245 (百度网盘)
#https://blog.csdn.net/weixin_43272781/article/details/110006990 (github)
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics

tf.__version__

#加载fashion mnist数据集
def load_mnist(path, kind='train'):
    import os
    import gzip
    import numpy as np

    """Load MNIST data from `path`"""
    labels_path = os.path.join(path,
                               '%s-labels-idx1-ubyte.gz'
                               % kind)
    images_path = os.path.join(path,
                               '%s-images-idx3-ubyte.gz'
                               % kind)
    with gzip.open(labels_path, 'rb') as lbpath:
        labels = np.frombuffer(lbpath.read(), dtype=np.uint8,
                               offset=8)
    with gzip.open(images_path, 'rb') as imgpath:
        
        images = np.frombuffer(imgpath.read(), dtype=np.uint8,
                               offset=16).reshape(len(labels), 784)
    return images, labels

#预处理数据
def preprocess(x, y):
    x = tf.cast(x, dtype=tf.float32)
    x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.
    y = tf.cast(y, dtype=tf.int32)
    y = tf.convert_to_tensor(y, dtype=tf.int32)
    return x, y
#训练数据
train_data, train_labels = load_mnist("./datasets")
print(train_data.shape, train_labels.shape)
#测试数据
test_data, test_labels = load_mnist("./datasets", "t10k")
print(test_data.shape, test_labels.shape)

batch_size = 128

train_db = tf.data.Dataset.from_tensor_slices((train_data, train_labels))
train_db = train_db.map(preprocess).shuffle(10000).batch(batch_size)

test_db = tf.data.Dataset.from_tensor_slices((test_data, test_labels))
test_db = test_db.map(preprocess).batch(batch_size)

train_db_iter = iter(train_db)
sample = next(train_db_iter)
print('Batch:', sample[0].shape, sample[1].shape)


#定义网络模型
model = Sequential([
    #Layer 1: [b, 784] => [b, 256]
    layers.Dense(256, activation=tf.nn.relu),
    #Layer 2: [b, 256] => [b, 128]
    layers.Dense(128, activation=tf.nn.relu),
    #Layer 3: [b, 128] => [b, 64]
    layers.Dense(64, activation=tf.nn.relu),
    #Layer 4: [b, 64] => [b, 32]
    layers.Dense(32, activation=tf.nn.relu),
    #Layer 5: [b, 32] => [b, 10], 输出类别结果
    layers.Dense(10)
])

#编译网络
model.build(input_shape=[None, 28*28])
model.summary()

#进行训练
total_epoches = 30
learn_rate = 0.01

optimizer = optimizers.Adam(learning_rate = learn_rate)
for epoch in range(total_epoches):
    for step, (x,y) in enumerate(train_db):
        with tf.GradientTape() as tape:
            logits = model(x)
            y_onehot = tf.one_hot(y, depth=10)
            #使用交叉熵作为loss
            loss_ce = tf.reduce_mean(tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True))
        #计算梯度
        grads = tape.gradient(loss_ce, model.trainable_variables)
        #更新梯度
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        if step % 100 == 0:
            print("Epoch[", epoch, "]: step-", step, "\tloss: CrossEntropy-", loss_ce.numpy())

#使用测试集进行验证
total_correct = 0
total_num = 0
for x,y in test_db:
    logits = model(x)
    #使用softmax得到各个类别的概率
    prob = tf.nn.softmax(logits, axis=1)
    #求出概率最大的结果参数位置,作为预测的分类结果
    pred = tf.cast(tf.argmax(prob, axis=1), dtype=tf.int32)
    #比较结果
    correct = tf.equal(pred, y)
    correct = tf.reduce_sum(tf.cast(correct, dtype=tf.int32))
    #计算精度
    total_correct += int(correct)
    total_num += x.shape[0]

acc = total_correct / total_num
print("Accuracy:", acc)

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1535206.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2024开年首展,加速科技展台“热辣滚烫”

3月20日,备受瞩目的半导体行业盛会SEMICON China 2024在上海新国际博览中心盛大启幕,展会汇集了来自全球的半导体领域顶尖企业与专业人士。加速科技作为业界领先的半导体测试设备供应商携重磅测试设备及解决方案精彩亮相,展示了最新的半导体测…

陈巍:Sora大模型技术精要万字详解(上)——原理、关键技术、模型架构详解与应用

​目录 收起 1 Sora的技术特点与原理 1.1 技术特点概述 1.2 时间长度与时序一致性 1.3 真实世界物理状态模拟 1.4 Sora原理 1.4.1扩散模型与单帧图像的生成 1.4.2 Transformer模型与连续视频语义的生成 1.4.3 从文本输入到视频生成 2 Sora的关键技术 2.1 传统文生图技…

GitHub Copilot怎么取消付费?

0. 前言 GitHub Copilot非常好用,还没有使用过的同学可以参考教程白嫖一个月:【保姆级】VsCode 安装GitHub Copilot实操教程 GitHub Copilot每月10美元的费用对于一些用户来说可能是一笔不小的开销。如果你已经完成了GitHub Copilot的免费试用&#xf…

痛失offer的八股

java面试八股 mysql篇: 事物的性质: 事物的性质有acid四特性。 a:automic,原子性,要么全部成功,要么全部失败,mysql的undolog,事物在执行的时候,mysql会进行一个快照读…

C#多态性

文章目录 C#多态性静态多态性函数重载函数重载 动态多态性运行结果 C#多态性 静态多态性 在编译时,函数和对象的连接机制被称为早期绑定,也被称为静态绑定。C# 提供了两种技术来实现静态多态性。分别为: 函数重载 运算符重载 运算符重载将…

MINT: Detecting Fraudulent Behaviors from Time-series Relational Data论文阅读笔记

2. 问题定义 时间序列关系数据(Time Series Relation Data) 这个数据是存放在关系型数据库中,每一条记录都是泰永时间搓的行为。 更具体地,每条记录表示为 x ( v , t , x 1 , x 2 , … , x m − 2 ) x (v,t,x_1,x_2,\dots,x…

是德科技keysight N9917A微波分析仪

181/2461/8938产品概述: N9917A 是一款使用电池供电的便携式分析仪;基本功能是电缆和天线分析;配置还包括频谱和网络分析仪、可选的内置功率计和矢量电压表。 N9917A FieldFox 手持式微波分析仪 主要特性和功能 18 GHz 最大频率&#xfef…

docker 容器挂掉,无法exec 进入bash 怎么修改容器里的文件

在使用tdengine 数据库时出现了 TDengine.Driver.TDengineError:“code:[0x334],error:Out of dnodes” 查找文档发现需要修改一个配置文件 。 /etc/taos/taos.cfg 中的 supportVnodes 参数 于是修改 保存。然后,运行出错。 03/21 06:56:27.986498 00000064 …

UE5 C++增强输入

一.创建charactor,并且包含增强输入相关的头文件 1.项目名.build.cs。添加模块“EnhancedInput”,方便找到头文件和映射的一些文件。 PublicDependencyModuleNames.AddRange(new string[] { "Core", "CoreUObject", "Engine&q…

服务器总是宕机问题记录

博主介绍: 22届计科专业毕业,来自湖南,主要是在CSDN记录一些自己在Java开发过程中遇到的一些问题,欢迎大家一起讨论学习,也欢迎大家的批评指正。 文章目录 背景调整总结 背景 2核2G的服务器,服务器安装了t…

Spark Rebalance hint的倾斜的处理(OptimizeSkewInRebalancePartitions)

背景 本文基于Spark 3.5.0 目前公司在做小文件合并的时候用到了 Spark Rebalance 这个算子,这个算子的主要作用是在AQE阶段的最后写文件的阶段进行小文件的合并,使得最后落盘的文件不会太大也不会太小,从而达到小文件合并的作用,…

CSS弹性盒模型(学习笔记)

一、厂商前缀 1.1 作用 解决浏览器对C3新特性的兼容,不同的浏览器厂商,定义了自己的厂商前缀 1.2 语法 浏览器 厂商前缀内核(渲染引擎):解析htmlcssjs谷歌 -webkit-blink苹果-webkit-webkit欧朋-o-blink火狐 -moz-geckoIE-ms- trid…

sentinel流控规则详解(图形化界面)

1、QPS 每秒的请求数&#xff0c;超过这个请求数&#xff0c;直接流控。 2、BlockException统一异常处理 2.1、创建异常返回类 Result.class public class Result<T> {private Integer code;private String msg;private T data;public Result(Integer code, String ms…

算法设计与分析-贪心算法的应用动态规划算法的应用——沐雨先生

1&#xff0e; 理解贪心算法的概念&#xff1b; 2&#xff0e;掌握贪心算法的基本思想。 &#xff08;1&#xff09;删数问题 问题描述&#xff1a;给定 n 位正整数 a &#xff0c;去掉其中任意 k ≤ n 个数字后&#xff0c;剩下的数字按原次序排列组成一个新的正整数。对于…

Eclipse For ABAP:安装依赖报错

1.安装好Eclipse后需要添加依赖,这里的地址: https://tools.hana.ondemand.com/latest 全部勾选等待安装结束; 重启后报错:ABAP communication layer is not configured properly. This might be caused by missing Microsoft Visual C++ 2013 (x64) Runtime DLLs. Consu…

【面试经典150 | 数组】分发糖果

文章目录 写在前面Tag题目来源解题思路方法一&#xff1a;贪心两次遍历 写在最后 写在前面 本专栏专注于分析与讲解【面试经典150】算法&#xff0c;两到三天更新一篇文章&#xff0c;欢迎催更…… 专栏内容以分析题目为主&#xff0c;并附带一些对于本题涉及到的数据结构等内容…

STM32最小核心板使用HAL库实现CAN接口通讯(轮询方式)

这里使用了CAN1的接口&#xff0c;具体使用MX创建项目就不放了 需要注意的是&#xff0c;由于是最小核心没有CAN的收发模块需要外接一个 STM32核心板接CAN收发模块不需要交叉 /**CAN GPIO ConfigurationPA11 ------> CAN_RXPA12 ------> CAN_TX */ CAN收发模块…

数据结构知识总结

二叉树 满二叉树 特性 所有叶子结点都集中在二叉树的最下面一层上&#xff0c;而且结点总数为&#xff1a;2^n-1 (n为层数 / 高度&#xff09; 完全二叉树 特性 若设二叉树的高度为h&#xff0c;除第h层外&#xff0c;其他各层的节点数都达到最大个数&#xff0c;第h层有…

Windows 设置多显示器显示

Windows 设置多显示器显示 1. Windows 7 设置 HDMI 输出2. Windows 11 设置多显示器显示References 1. Windows 7 设置 HDMI 输出 2. Windows 11 设置多显示器显示 ​​​ References [1] Yongqiang Cheng, https://yongqiang.blog.csdn.net/

MAC本安装telnet

Linux运维工具-ywtool 目录 1.打开终端1.先安装brew命令2.写入环境变量4.安装telnet 1.打开终端 访达 - 应用程序(左侧) - 实用工具(右侧) - 终端 #注意:登入终端用普通用户,不要用MAC的root用户1.先安装brew命令 /bin/zsh -c "$(curl -fsSL https://gitee.com/cunkai/H…