TensorFlow2实战-系列教程2:神经网络分类任务

news2025/1/13 10:21:06

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

1、Mnist数据集

下载mnist数据集:

%matplotlib inline
from pathlib import Path
import requests

DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"
PATH.mkdir(parents=True, exist_ok=True)

URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"

if not (PATH / FILENAME).exists():
        content = requests.get(URL + FILENAME).content
        (PATH / FILENAME).open("wb").write(content)

制作数据:

import pickle
import gzip

with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

简单展示数据:

from matplotlib import pyplot
import numpy as np

pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
print(x_train.shape)
print(y_train[0])

打印结果:

(50000, 784)
5

在这里插入图片描述

2、模型构建

在这里插入图片描述
在这里插入图片描述
输入为784神经元,经过隐层提取特征后为10个神经元,10个神经元的输出值经过softmax得到10个概率值,取出10个概率值中最高的一个就是神经网络的最后预测值

构建模型代码:

import tensorflow as tf
from tensorflow.keras import layers
model = tf.keras.Sequential()
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))

选择损失函数,损失函数是机器学习一个非常重要的部分,基本直接决定了这个算法的效果,这里是多分类任务,一般我们就直接选用多元交叉熵函数就好了:
TensorFlow损失函数API

编译模型:

model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
             loss=tf.keras.losses.SparseCategoricalCrossentropy(),
             metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  1. adam优化器,学习率为0.001
  2. 多元交叉熵损失函数
  3. 评价指标

模型训练:

model.fit(x_train, y_train, epochs=5, batch_size=64, validation_data=(x_valid, y_valid))

训练数据,训练标签,训练轮次,batch_size,验证集

打印结果:

Train on 50000 samples, validate on 10000 samples
Epoch 1/5 50000/50000  1s 29us
sample-loss: 115566 - sparse_categorical_accuracy: 0.1122 - val_loss: 364928.5786 - val_sparse_categorical_accuracy: 0.1064
Epoch 2/5 50000/50000 1s 21us
sample - loss: 837104 - sparse_categorical_accuracy: 0.1136 - val_loss: 1323287.7028 - val_sparse_categorical_accuracy: 0.1064
Epoch 3/5 50000/50000 1s 20us
sample - loss: 1892431 - sparse_categorical_accuracy: 0.1136 - val_loss: 2448062.2680 - val_sparse_categorical_accuracy: 0.1064
Epoch 4/5 50000/50000 1s 20us
sample - loss: 3131130 - sparse_categorical_accuracy: 0.1136 - val_loss: 3773744.5348 - val_sparse_categorical_accuracy: 0.1064
Epoch 5/5 50000/50000 1s 20us
sample - loss: 4527781 - sparse_categorical_accuracy: 0.1136 - val_loss: 5207194.3728 - val_sparse_categorical_accuracy: 0.1064
<tensorflow.python.keras.callbacks.History at 0x1d3eb9015f8>

模型保存:

model.save('Mnist_model.h5')

3、TensorFlow常用模块

3.1 Tensor格式转换

创建一组数据

import numpy as np
input_data = np.arange(16)
input_data

打印结果:
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])

转换成TensorFlow格式的数据:

dataset = tf.data.Dataset.from_tensor_slices(input_data)
for data in dataset:
    print (data)

将一个ndarray转换成
打印结果:
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)

tf.Tensor(14, shape=(), dtype=int32)
tf.Tensor(15, shape=(), dtype=int32)

3.2repeat操作

dataset = tf.data.Dataset.from_tensor_slices(input_data)
dataset = dataset.repeat(2)
for data in dataset:
    print (data)

tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)

tf.Tensor(14, shape=(), dtype=int32)
tf.Tensor(15, shape=(), dtype=int32)
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)

tf.Tensor(14, shape=(), dtype=int32)
tf.Tensor(15, shape=(), dtype=int32)

会将当前的输出重复一遍

3.3 batch操作

dataset = tf.data.Dataset.from_tensor_slices(input_data)
dataset = dataset.repeat(2).batch(4)
for data in dataset:
    print (data)

tf.Tensor([0 1 2 3], shape=(4,), dtype=int32)
tf.Tensor([4 5 6 7], shape=(4,), dtype=int32)
tf.Tensor([ 8 9 10 11], shape=(4,), dtype=int32)
tf.Tensor([12 13 14 15], shape=(4,), dtype=int32)
tf.Tensor([0 1 2 3], shape=(4,), dtype=int32)
tf.Tensor([4 5 6 7], shape=(4,), dtype=int32)
tf.Tensor([ 8 9 10 11], shape=(4,), dtype=int32)
tf.Tensor([12 13 14 15], shape=(4,), dtype=int32)

将原来的数据按照4个为一个批次

3.4 shuffle操作

dataset = tf.data.Dataset.from_tensor_slices(input_data).shuffle(buffer_size=10).batch(4)
for data in dataset:
    print (data)

tf.Tensor([ 9 8 11 3], shape=(4,), dtype=int32)
tf.Tensor([ 5 6 1 13], shape=(4,), dtype=int32)
tf.Tensor([14 15 4 2], shape=(4,), dtype=int32)
tf.Tensor([12 7 0 10], shape=(4,), dtype=int32)

shuffle操作,直接翻译过来就是洗牌,把当前的数据进行打乱操作
buffer_size=10,就是缓存10来进行打乱取数据

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1414300.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vs2019报错MSB4019 找不到导入的项目“BuildCustomizations\CUDA 9.2.props”

在VS中执行生成&#xff0c;报错如下&#xff1a;严重性 代码 说明 项目 文件 行 禁止显示状态 错误 MSB4019 找不到导入的项目“D:\Microsoft Visual Studio\2019\Community\MSBuild\Microsoft\VC\v160\BuildCustomizations\CUDA 9.2.props”。请确认 Import 声明“D:\Microso…

在autodl训练yolov8时卡在下载字体

1.问题 在autodl训练yolov8到这一步之后会卡住很久 2. 解决办法 Ctric中断后发现是下载Arial字体卡住了&#xff0c;这个字体需要从外网中下载 先手动从链接中下载https://ultralytics.com/assets/Arial.ttf &#xff0c;然后上传到autodl。然后将这个文件移动到/root/.config/…

机电制造ERP软件有哪些品牌?哪家的机电制造ERP系统比较好

机电制造过程比较复杂&#xff0c;涵盖零配件、采购、图纸设计、工艺派工、生产计划、物料需求计划、委外加工等诸多环节。而供应链涉及供应商的选择、材料采购价格波动分析、材料交货、品质检验等过程&#xff0c;其中某个环节出现问题都可能会影响产品交期和经营效益。 近些…

一文速通Python添加、修改和删除字典元素

添加、修改和删除字典元素是 Python 中使用字典时常见的操作。字典是一种无序、可变的数据结构&#xff0c;用于存储键值对。在 Python 中&#xff0c;对字典元素进行添加、修改和删除操作可以帮助我们动态地管理数据&#xff0c;方便地根据需求对字典进行更新和维护。 一、添…

详讲api网关之kong的基本概念及安装和使用(一)

什么是api网关 前面我们聊过sentinel&#xff0c;用来限流熔断和降级&#xff0c;如果你只有一个服务&#xff0c;用sentinel自然没有问题&#xff0c;但是如果是有多个服务&#xff0c;特别是微服务的兴起&#xff0c;那么每个服务都使用sentinel就给系统维护带来麻烦。那么网…

Making Large Language Models Perform Better in Knowledge Graph Completion论文阅读

文章目录 摘要1.问题的提出引出当前研究的不足与问题KGC方法LLM幻觉现象解决方案 2.数据集和模型构建数据集模型方法基线方法任务模型方法基于LLM的KGC的知识前缀适配器知识前缀适配器 与其他结构信息引入方法对比 3.实验结果与分析结果分析&#xff1a;可移植性实验&#xff1…

那些年与指针的爱恨情仇(一)---- 指针本质及其相关性质用法

关注小庄 顿顿解馋 (≧∇≦) 引言&#xff1a; 小伙伴们在学习c语言过程中是否因为指针而困扰&#xff0c;指针简直就像是小说女主&#xff0c;它逃咱追&#xff0c;我们插翅难飞…本篇文章让博主为你打理打理指针这个傲娇鬼吧~ 本节我们将认识到指针本质&#xff0c;何为指针和…

k8s 版本发布与回滚

一、实验环境准备&#xff1a; kubectl get pods -o wide kubectl get nodes -o wide kubectl get svc 准备两个nginx镜像&#xff0c;版本号一个是V3&#xff0c;一个是V4 二、准备一个nginx.yaml文件 apiVersion: apps/v1 kind: Deployment metadata:name: nginx-deploylab…

解释性人工智能(XAI)—— AI 决策的透明之道

在当今数字化时代&#xff0c;人工智能&#xff08;AI&#xff09;已经成为我们生活中不可或缺的一部分。AI 系统的决策和行为对我们的生活产生了深远的影响&#xff0c;从医疗保健到金融服务再到自动驾驶汽车。 然而&#xff0c;有时候 AI 的决策似乎像黑盒子一样难以理解&am…

linux服务器ssh连接慢问题处理

一、 可能导致慢的几个原因 1、网络问题&#xff1a;网络延迟、带宽限制和包丢失等网络问题都有可能导致SSH连接变慢。 2、客户端设置&#xff1a;错误的客户端设置&#xff0c;如使用过高的加密算法或不适当的密钥设置&#xff0c;可能导致SSH连接变慢。 3、服务器负载过高…

element-ui 树形控件 实现点击某个节点获取本身节点和底下所有的子节点数据

1、需求&#xff1a;点击树形控件中的某个节点&#xff0c;需要拿到它本身和底下所有的子节点的id 1、树形控件代码 <el-tree:data"deptOptions"node-click"getVisitCheckedNodes"ref"target_tree_Speech"node-key"id":default-ex…

elasticsearch8的整体总结

es概述 elasticsearch简介 官网: https://www.elastic.co/ ElasticSearch是一个基于Lucene&#xff08;Apache开源全文检索工具包&#xff09;的搜索服务器。它提供了一个分布式多用户能力的全文搜索引擎&#xff0c;基于RESTful web接口。Elasticsearch是用Java开发的&…

MySQL:数据库索引详解

1、什么是索引&#xff1a; 索引是一种用于快速查询和检索数据的数据结构。常见的索引结构有: B 树&#xff0c; B树和 Hash。 索引的作用就相当于目录的作用。打个比方: 我们在查字典的时候&#xff0c;如果没有目录&#xff0c;那我们就只能一页一页的去找我们需要查的那个字…

基于comsol热黏性声学模块仿真声学超材料的声学特性

研究内容&#xff1a; 传统的声学吸收器被用于具有与工作波长相当的厚度的结构&#xff0c;这在低频范围的实际应用中造成了主要障碍。我们提出了一种基于超表面的完美吸收体&#xff0c;能够在极低频区域实现声波的完全吸收。具有深亚波长厚度至特征尺寸k&#xff1d;223的超…

基于Matlab/Simulink直驱式风电储能制氢仿真模型

接着还是以直驱式风电为DG中的研究对象&#xff0c;上篇博客考虑的风电并网惯性的问题&#xff0c;这边博客主要讨论功率消纳的问题。 考虑到风速是随机变化的&#xff0c;导致风电输出功率的波动性和间歇性问题突出&#xff1b;随着其应用规模的不断扩大以及风电在电网中渗透率…

【洛谷 P7072】[CSP-J2020] 直播获奖 题解(优先队列+对顶堆)

[CSP-J2020] 直播获奖 题目描述 NOI2130 即将举行。为了增加观赏性&#xff0c;CCF 决定逐一评出每个选手的成绩&#xff0c;并直播即时的获奖分数线。本次竞赛的获奖率为 w % w\% w%&#xff0c;即当前排名前 w % w\% w% 的选手的最低成绩就是即时的分数线。 更具体地&am…

Typora 无法导出 pdf 问题的解决

目录 问题描述 解决困难 解决方法 问题描述 Windows 下&#xff0c;以前&#xff08;Windows 11&#xff09; Typora 可以顺利较快地由 .md 导出 .pdf 文件&#xff0c;此功能当然非常实用与重要。 然而&#xff0c;有一次电脑因故重装了系统&#xff08;刷机&#xff09;…

【代码随想录15】110.平衡二叉树 257. 二叉树的所有路径 404.左叶子之和

目录 110. 平衡二叉树题目描述参考代码 257. 二叉树的所有路径题目描述参考代码 404.左叶子之和题目描述参考代码 110. 平衡二叉树 题目描述 给定一个二叉树&#xff0c;判断它是否是高度平衡的二叉树。 本题中&#xff0c;一棵高度平衡二叉树定义为&#xff1a; 一个二叉树…

亚马逊测评:卖家如何操作测评,安全高效(自养号测评)

亚马逊测评的作用在于让用户更真实、清晰、快捷地了解产品以及产品的使用方法和体验。通过买家对产品的测评&#xff0c;也可以帮助厂商和卖家优化产品缺陷&#xff0c;提高用户的使用体验。这进而帮助他们获得更好的销量&#xff0c;并更深入地了解市场需求。亚马逊测评在满足…

SAP同步异常4:删除合并特征数据的正确方案CXA01

测试环境VF02过帐报错。 原因&#xff0c;在处理测试环境异常数据ZZECCS时没有找到正确的方法&#xff0c;采用的是数据库直接删除。没有解决程序问题。 在SAP同步异常3&#xff1a;解决合并数据异常 只解决了一个程序问题。 最终解决方案&#xff1a; CXA01 删除ZZECCS表 …