Tensorflow——第三讲神经网络八股

news2024/12/28 20:07:59

前两讲我们学习了使用tensorflow原生代码搭建神经网络,本讲主要学习使用Tensorflow API:tf.keras搭建神经网络

一、搭建网络八股Sequential

六步法:

1.import:import 相关模块,如 import tensorflow as tf

2.train, test:指定输入网络的训练集和测试集,如指定训练集的输入 x_train 和标签
y_train,测试集的输入 x_test 和标签 y_test。

3.model = tf.keras.models.Sequential:逐层搭建网络结构

4.model.compile:在 model.compile()中配置训练方法,选择训练时使用的优化器、损失
函数和最终评价指标。

5.model.fit:在 model.fit()中执行训练过程,告知训练集和测试集的输入值和标签、
每个 batch 的大小(batchsize)和数据集的迭代次数(epoch)

6.model.summary:使用 model.summary()打印网络结构,统计参数数目。

model = tf.keras.models.Sequential的使用:

 model.compile的使用

:from_logits=False:神经网络末端如果使用了softmax函数,输出为概率分布而不是原始输出,from_logits就为false,否则为True 

model.fit()的使用 

 model.summary()的使用 

二、搭建网络八股class 

Sequential能搭建上层输入就是下层输出的顺序网络结构,但是无法写出一些带有跳连的非顺序网络结构,这个时候我们可以选择用类class搭建神经网络结构。

class的使用 :

对比 Sequential和class搭建神经网络的过程:

以实现鸢尾花分类为例

Sequential

import tensorflow as tf
from sklearn import datasets
import numpy as np

x_train = datasets.load_iris().data
y_train = datasets.load_iris().target

np.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)

model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())
])

model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)

model.summary()

 class

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model
from sklearn import datasets
import numpy as np

x_train = datasets.load_iris().data
y_train = datasets.load_iris().target

np.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)

class IrisModel(Model):
    def __init__(self):
        super(IrisModel, self).__init__()
        self.d1 = Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())

    def call(self, x):
        y = self.d1(x)
        return y

model = IrisModel()

model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
model.summary()

三、MNIST数据集 —手写数字识别训练

1.数据集的介绍

(1)MNIST数据集:
提供 6万张 28*28 像素点的0~9手写数字图片和标签,用于训练。
提供 1万张 28*28 像素点的0~9手写数字图片和标签,用于测试。

(2)导入MNIST数据集:
mnist = tf.keras.datasets.mnist
(x_train, y_train) ,  (x_test, y_test) = mnist.load_data()

(3)作为输入特征,输入神经网络时,将数据拉伸为一维数组:
tf.keras.layers.Flatten( )
[ 0   0   0  48 238 252 252 …… …… …… 253 186  12   0   0   0   0   0]

注:不知道这里大家有没有这样一个疑问,为什么鸢尾花的数据集不需要拉伸:

原因:鸢尾花数据集不需要拉直为一维是因为它的特征已经是数值型的,可以直接用于机器学习模型的训练和预测。而手写数字数据需要拉直为一维是因为它们的原始数据是图像形式的,需要通过转换才能被机器学习算法处理。

(4)观察数据集

2.代码实现书写数字识别

import tensorflow as tf

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import Model

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0


class MnistModel(Model):
    def __init__(self):
        super(MnistModel, self).__init__()
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        x = self.flatten(x)
        x = self.d1(x)
        y = self.d2(x)
        return y


model = MnistModel()

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()

后面还有FASHION数据集数据集,与MNIST数据集处理方式类似,就不再赘述。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1983482.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2024年7月30日~2024年8月5日周报

一、前言 上周继续修改论文,并阅读了两篇论文。 本周主要修改论文、完成实验、参加一些组会与论文讨论会,并配置了torch环境。 二、完成情况 2.1 论文符号系统注意事项 数学符号应该有唯一性,不能与其他符号造成误解;W_{\mathr…

c++初阶-----适配器---priority_queue

作者前言 🎂 ✨✨✨✨✨✨🍧🍧🍧🍧🍧🍧🍧🎂 ​🎂 作者介绍: 🎂🎂 🎂 🎉🎉&#x1f389…

极狐GitLab CICD Catalog Beta 功能介绍

极狐GitLab 是 GitLab 在中国的发行版,专门面向中国程序员和企业提供企业级一体化 DevOps 平台,用来帮助用户实现需求管理、源代码托管、CI/CD、安全合规,而且所有的操作都是在一个平台上进行,省事省心省钱。可以一键安装极狐GitL…

【Python】数据类型之列表(下)

(6)清空列表 功能:clear() 代码示例: (7)根据值获取索引(从左到右找到第一个返回索引)【慎用,找不到报错】 功能:index(xyz),xyz为数据类型。 …

OpenAI gym player mode

题意:OpenAI gym 的自定义模式 问题背景: Does anyone know how to run one of the OpenAI gym environments as a player. As in letting a human player play a round of cart pole? I have seen that there is env.mode human but I have not been…

波兰表达式求值

from operator import add, sub, muldef div(x, y):# 使用整数除法的向零取整方式return int(x / y) if x * y > 0 else -(abs(x) // abs(y))class Solution(object):op_map {: add, -: sub, *: mul, /: div}def evalRPN(self, tokens: List[str]) -> int:stack []for …

【C基础-按要求找数】一个整数,它加上100后是一个完全平方数,再加上168又是一个完全平方数,请问该数是多少

一个整数,它加上100后是一个完全平方数,再加上168又是一个完全平方数,请问该数是多少 完全平方数是指一个整数能够表示为某个整数的平方。换句话说,如果存在一个整数 n,使得 n^2m,那么 m 就是一个完全平方数。 使用C…

第二十一天培训笔记

上午 1 、环境准备 2 、安装 mysql 绿包 3 、配置 mysql 工作环境 mysql -hip 地址 -p3306 -uroot -p (远程连接使用) 4 、 mysql 基础命令 ( 1 )修改密码 ( 2 )授权远程登录 ( 3 &#x…

程序员短视频上瘾综合症

一、是你疯了还是面试官疯了? ​ 最近有两个学员咨询问题,把我给整得苦笑不得。大家来看看,你有没有同样的症状。 ​ 第一个学员说去一家公司面试,第一轮面试聊得挺好的。第二轮面试自我感觉良好,但是被面试官给Diss…

模型优化学习笔记—对比各种梯度下降算法

import mathimport numpy as np from opt_utils import * import matplotlib.pyplot as plt# 标准梯度下降 def update_parameters_with_gd(parameters, grads, learning_rate):L len(parameters) // 2for l in range(1, L 1):parameters[f"W{l}"] parameters[f&q…

【uniapp】聊天记录列表长按消息计算弹出菜单方向

1. 效果图 1.1 消息靠上接近导航栏&#xff0c;菜单显在消息体下方弹出&#xff0c;箭头向上 1.2 消息体没有贴近上方导航栏&#xff0c;菜单在消息体上方弹出&#xff0c;箭头向下 1.3 长消息&#xff0c;菜单在手指按下的位置弹出&#xff0c;无箭头 2. 代码实现 <view …

sqli 1- 10

sql靶场 第一关 首先我们需要判断是否存在sql注入点&#xff0c;前端界面提示我使用ID作为参数,在url地址栏输入?id1 通过输入不同的id值查询数据库相对应的内容&#xff0c;之后判断为数字型还是字符型 根据查询内容判断为字符型且有注入点&#xff0c;再通过联合查询&…

Vitis AI 基本操作+模型检查(inspector)用法详解

目录 1. 简介 2. 代码详解 2.1 导入所需的库 2.2 创建 Inspector 2.3 下载模型 2.4 检查模型 3. 其他有用函数 3.1 查看 torchvision 中模型 3.2 保存模型 3.2.1 保存模型参数 3.2.2 保存完整模型 3.2.3 加载模型 4. 总结 1. 简介 在《Vitis AI 构建开发环境&…

GNSS相关知识

各定位系统的频段&#xff1a; SystemSignalFrequency(MHz)GPSL1C/A1575.42L1C1575.42L2C1227.6L2P1227.6L51176.45   GLONASSL1C/A1598.0625-1609.3125L2C1242.9375-1251.6875L2P1242.9375-1251.6875L3OC1202.025   GalileoE11575.42E5a1176.45E5b1207.14E5AltBOC1191.…

SpringBoot之外部化配置

前言 SpringBoot 版本 2.6.13&#xff0c;相关链接 Core Features Default properties (specified by setting SpringApplication.setDefaultProperties).PropertySource annotations on your Configuration classes. Please note that such property sources are not added …

如何在群晖NAS中搭建影音管理利器nastool并实现远程访问本地资源

文章目录 前言1. 本地搭建Nastool2. nastool基础设置3. 群晖NAS安装内网穿透工具4. 配置公网地址5. 配置固定公网地址 前言 Nastool是为群晖NAS玩家量身打造的一款智能化影音管理利器。它不仅能够满足电影发烧友、音乐爱好者和追剧达人的需求&#xff0c;更能让你在繁忙的生活…

疯狂的马达——Arduino

本次学习目标 1、了解马达的运用、以及马达内部的基本原理。 2、学会通过编程控制马达的速度、方向。 3、制作电位器换挡风扇。 马达 “马达”为英语motor的音译&#xff0c;我们称为电机&#xff0c;电机又可分为 发电机和电动机。前者是一种能够将动能转化电能的装置&am…

【知识】pytorch中的pinned memory和pageable memory

转载请注明出处&#xff1a;小锋学长生活大爆炸[xfxuezhagn.cn] 如果本文帮助到了你&#xff0c;欢迎[点赞、收藏、关注]哦~ 目录 概念简介 pytorch用法 速度测试 反直觉情况 概念简介 默认情况下&#xff0c;主机 &#xff08;CPU&#xff09; 数据分配是可分页的。GPU 无…

计算机系统的基本结构-CSP初赛知识点整理

真题练习 [2021-CSP-J-第3题] 目前主流的计算机储存数据最终都是转换成&#xff08; &#xff09;数据进行储存。 A.二进制 B.十进制 C.八进制 D.十六进制 [2020-CSP-J-第1题] 在内存储器中每个存储单元都被赋予一个唯一的序号&#xff0c;称为( ) A&#xff0e;地址 B&a…

探索 Electron 应用的本地存储:SQLite3 与 Knex.js 的协同工作

electron 简介 Electron 是一个使用 JavaScript, HTML 和 CSS 构建跨平台桌面应用程序的框架。 它允许开发者使用 Web 技术来创建桌面软件&#xff0c;而不需要学习特定于平台的编程语言。 Electron 应用程序实际上是一个包含 Web 内容的 Chromium 浏览器实例&#xff0c;并…