17- TensorFlow中使用Keras创建模型 (TensorFlow系列) (深度学习)

news2024/11/17 11:52:40

知识要点

  • Keras 是一个用 Python 编写的高级神经网络 API
  • 数据的开方:  np.sqrt(784)       # 28
  • 代码运行调整到 CPU 或者 GPU:
import tensorflow as tf
cpu=tf.config.list_physical_devices("CPU")
tf.config.set_visible_devices(cpu)
  • 模型显示: model.summary()

创建模型:

  • 模型创建: model = Sequential()
  • 添加卷积层: model.add(Dense(32, activation='relu', input_dim=100))  # 第一层需要 input_dim
  • 添加dropout: model.add(Dropout(0.2))
  • 添加第二次网络: model.add(Dense(512, activation='relu'))   # 除了first, 其他层不要输入shape
  • 添加输出层: model.add(Dense(num_classes, activation='softmax') # last 通常使用softmax
  • TensorFlow 中,使用 model.compile 方法来选择优化器和损失函数:
    • optimizer: 优化器: 主要有: tf.train.AdamOptimizer , tf.train.RMSPropOptimizer , or tf.train.GradientDescentOptimizer .

    • loss: 损失函数: 主要有:mean square error (mse, 回归), categorical_crossentropy (多分类) , and binary_crossentropy (二分类).

    • metrics: 算法的评估标准, 一般分类用accuracy.

model.compile(loss='categorical_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])
  • model.fit(x_train, y_train, batch_size = 64, epochs = 20, validation_data = (x_test, y_test))    # 模型训练

模型评估:

  • score = model.evaluate(x_test, y_test, verbose=0)    两个返回值: [ 损失率 , 准确率 ]

大数据集处理:

  • 大数据集数据变成dataset: dataset = tf.data.Dataset.from_tensor_slices((data, labels))
    • 指定每批数据大小: dataset = dataset.batch(32).repeat()
    • dataset 数据训练: model.fit(dataset, epochs=10, steps_per_epoch=30)
  • 保存模型: model.save('my_model.h5')
  • 加载模型: model = tf.keras.models.load_model('my_model.h5')


1 Keras 基础

1.1 简介

Keras 是一个用 Python 编写的高级神经网络 API,它能够以 TensorFlow , CNTK 或者 Theano 作为后端运行。在Keras的官方github上写着"Deep Learning for humans", 主要是因为它能简单快速的创建神经网络,而不需要像Tensorflow一样考虑很多中间过程.

Keras说白了就是一个壳子, 需要结合TensorFlow, CNTK或者Theano等后端框架来运行.基于这些特点, Keras的入门非常简单. TensorFlow已经把keras集成了.

中文官方文档: 主页 - Keras 中文文档

1.2 Keras 使用

1.2.1 创建模型

  • 最简单的使用方式是使用Keras Sequential, 中文叫做顺序模型, 可以用来表示多个网络层的线性堆叠. 使用的时候可以通过讲网络层实例的列表传递给Sequential, 比如:
from keras.models import Sequential
from keras.layers import Dense, Activation

model = Sequential([
    Dense(32, input_shape=(784,)),
    Activation('relu'),
    Dense(10),
    Activation('softmax'),
])
model.summary()

 

  • 也可以简单的使用.add()方法将各层添加到模型中:
model = Sequential()
model.add(Dense(32, input_dim=784))
model.add(Activation('relu'))
model.summary()

 1.2.2 模型参数设置

对于输入层,需要指定输入数据的尺寸,通过Dense对象中的input_shape属性.注意无需写batch的大小. input_shape=(784,)  等价于我们在神经网络中定义的shape为(None, 784)的Tensor.

模型创建成功之后,需要进行编译.使用.compile()方法对创建的模型进行编译.compile()方法主要需要指定一下几个参数:

  • 优化器 optimizer: 可以是Keras定义好的优化器的字符串名字,比如'rmsprop'也可以是Optimizer类的实例对象.常见的优化器有: SGD, RMSprop, Adagrad, Adadelta等.

  • 损失函数 loss: 模型视图最小化的目标函数, 它可以是现有损失函数的字符串形式, 比如:categorical_crossentropy, 也可以是一个目标函数.

  • 评估标准 metrics. 评估算法性能的衡量指标.对于分类问题, 建议设置为metrics = ['accuracy'].评估标准可以是现有的标准的字符串标识符,也可以是自定义的评估标准函数。

以下为compile的常见写法:

# 多分类问题
model.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# 二分类问题
model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy'])

# 均方误差回归问题
model.compile(optimizer='rmsprop',
              loss='mse')

# 自定义评估标准函数
import keras.backend as K

def mean_pred(y_true, y_pred):
    return K.mean(y_pred)

model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy', mean_pred])
model.summary()

  • 训练模型: 使用.fit()方法,将训练数据,训练次数(epoch), 批次尺寸(batch_size)传递给fit()方法.
# 训练模型,以 32 个样本为一个 batch 进行迭代
model.fit(data, labels, epochs=10, batch_size=32)

1.3 Keras函数式API

Sequential 顺序模型封装了太多东西,不够灵活,如果你想定义复杂模型可以使用Keras的函数式API.

以下是一个全连接网络的例子:    # 手写数字识别

from keras.layers import Input, Dense
from keras.models import Model
from keras.datasets import mnist
import tensorflow as tf

# 导入手写数字数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 对数据进行初步处理
x_train = x_train.reshape(60000, 784)
x_train = x_train.astype('float32')
x_train /= 255
# 将标记结果转化为独热编码
y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)

# 这部分返回一个张量
inputs = Input(shape=(784))
# 层的实例是可调用的,它以张量为参数,并且返回一个张量
output_1 = Dense(64, activation='relu')(inputs)
output_2 = Dense(64, activation='relu')(output_1)
predictions = Dense(10, activation='softmax')(output_2)

# 这部分创建了一个包含输入层和三个全连接层的模型
model = Model(inputs=inputs, outputs=predictions)
model.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x_train, y_train)  # 开始训练

这个例子能够帮助我们进行一些简单的理解。

  • 网络层的实例是可调用的,它以张量为参数,并且返回一个张量

  • 输入和输出均为张量,它们都可以用来定义一个模型 (Model)

  • 这样的模型同 Keras 的 Sequential 模型一样,都可以被训练.

2. TensorFlow中使用Keras

keras集成在tf.keras中.

2.1 创建模型

创建一个简单的模型,使用 tf.keras.sequential.

import tensorflow as tf
import tensorflow.keras.layers as layers
model = tf.keras.Sequential()
# 创建一层有64个神经元的网络:
model.add(layers.Dense(64, activation='relu'))
# 添加另一层网络:
model.add(layers.Dense(64, activation='relu'))
# 输出层:
model.add(layers.Dense(10, activation='softmax'))

2.2 配置layers

layers 包含以下三组重要参数:

  • activation: 激活函数, 'relu', 'sigmoid', 'tanh'.

  • kernel_initializer bias_initializer: 权重和偏差的初始化器. Glorot uniform是默认的初始化器.一般不用改.

  • kernel_regularizerbias_regularizer : 权重和偏差的正则化: L1, L2.

以下是配置模型的例子:

import tensorflow.keras.layers as layers
# 激活函数为sigmoid:
layers.Dense(64, activation='sigmoid')
# Or:
layers.Dense(64, activation=tf.sigmoid)

# 权重加了L1正则:
layers.Dense(64, kernel_regularizer=tf.keras.regularizers.l1(0.01))

# 给偏差加了L2正则
layers.Dense(64, bias_regularizer=tf.keras.regularizers.l2(0.01))

# 随机正交矩阵初始化器:
layers.Dense(64, kernel_initializer='orthogonal')

# 偏差加了常数初始化器
layers.Dense(64, bias_initializer=tf.keras.initializers.constant(2.0))

2.3 训练和评估

2.3.1 配置模型

使用compile配置模型, 主要有以下几组重要参数.

  • optimizer: 优化器: 主要有: tf.train.AdamOptimizer , tf.train.RMSPropOptimizer , or tf.train.GradientDescentOptimizer .

  • loss: 损失函数: 主要有:mean square error (mse, 回归), categorical_crossentropy (多分类) , and binary_crossentropy (二分类).

  • metrics: 算法的评估标准, 一般分类用accuracy.

以下是compile的 实例:

# 配置均方误差的回归.
model.compile(optimizer = tf.train.AdamOptimizer(0.01),
              loss = 'mse',       # mean squared error
              metrics = ['mae'])  # mean absolute error

# 配置多分类的模型.
model.compile(optimizer=tf.train.RMSPropOptimizer(0.01),
              loss=tf.keras.losses.categorical_crossentropy,
              metrics=[tf.keras.metrics.categorical_accuracy])

2.3.2 训练

使用model的fit方法进行训练, 主要有以下参数:

  • epochs: 训练次数

  • batch_size: 每批数据多少

  • validation_data: 测试数据

对于小数量级的数据,可以直接把训练数据传入fit.

import numpy as np

data = np.random.random((1000, 32))
labels = random_one_hot_labels((1000, 10))

val_data = np.random.random((100, 32))
val_labels = random_one_hot_labels((100, 10))

model.fit(data, labels, epochs=10, batch_size=32,
          validation_data=(val_data, val_labels))

对于大数量级的训练数据,使用tensorflow中dataset.

# 把数据变成dataset
dataset = tf.data.Dataset.from_tensor_slices((data, labels))
# 指定一批数据是32, 并且可以无限重复
dataset = dataset.batch(32).repeat()

val_dataset = tf.data.Dataset.from_tensor_slices((val_data, val_labels))
val_dataset = val_dataset.batch(32).repeat()

# 别忘了steps_per_epoch, 表示执行完全部数据的steps
model.fit(dataset, epochs=10, steps_per_epoch=30)

model.fit(dataset, epochs=10, steps_per_epoch=30,
          validation_data=val_dataset,
          validation_steps=3)

2.3.3 评估和预测

使用 tf.keras.Model.evaluate and tf.keras.Model.predict进行评估和预测. 评估会打印算法的损失和得分.

data = np.random.random((1000, 32))
labels = random_one_hot_labels((1000, 10))
#  普通numpy数据
model.evaluate(data, labels, batch_size=32)
# tensorflow dataset数据
model.evaluate(dataset, steps=30)

预测:

result = model.predict(data, batch_size=32)
print(result.shape)

2.4 使用函数式API

函数式API, 主要是需要自己把各个组件的对象定义出来, 并且手动传递.

inputs = tf.keras.Input(shape=(32,))  # 返回placeholder
# A layer instance is callable on a tensor, and returns a tensor.
x = layers.Dense(64, activation='relu')(inputs)
x = layers.Dense(64, activation='relu')(x)
predictions = layers.Dense(10, activation='softmax')(x)
model = tf.keras.Model(inputs=inputs, outputs=predictions)
# The compile step specifies the training configuration.
model.compile(optimizer=tf.train.RMSPropOptimizer(0.001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
# Trains for 5 epochs
model.fit(data, labels, batch_size=32, epochs=5)

2.5 保存和恢复

  • 使用model.save把整个模型保存为HDF5文件
model.save('my_model.h5')
  • 恢复使用tf.keras.models.load_model即可.
model = tf.keras.models.load_model('my_model.h5')

注意: 如果使用的tensorflow的optimizer, 那么保存的model中没有model配置信息, 恢复以后需要重新配置.推荐用keras的optimizer.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/375316.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Tik Tok品牌营销,如何做好内容打法

TikTok 上做好品牌营销,并不能只关注品牌所获得的视频浏览量和点赞量,根据潜在客户需求生成的内容策略同样至关重要。通过建立营销漏斗模型,可以将 TikTok 策略分为三种不同类型的内容,从具有广泛吸引力的内容转变为具有高度针对性…

Vue组件是怎样挂载的

我们先来关注一下$mount是实现什么功能的吧&#xff1a; 我们打开源码路径core/instance/init.js: export function initMixin (Vue: Class<Component>) {......initLifecycle(vm)// 事件监听初始化initEvents(vm)initRender(vm)callHook(vm, beforeCreate)initInject…

FastDDS-3. DDS层

3. DDS层 eProsima Fast DDS公开了两个不同的API&#xff0c;以在不同级别与通信服务交互。主要API是数据分发服务&#xff08;DDS&#xff09;数据中心发布订阅&#xff08;DCPS&#xff09;平台独立模型&#xff08;PIM&#xff09;API&#xff0c;简称DDS DCPS PIM&#xf…

nacos集群模式+keepalived搭建高可用服务

实际工作中如果nacos这样的核心服务停掉了或者整个服务器宕机了&#xff0c;那整个系统也就gg了&#xff0c;所以像这样的核心服务我们必须要搞个3个或者3个以上的nacos集群部署&#xff0c;实现高可用&#xff1b; 部署高可用版本之前&#xff0c;首先你要会部署单机版的naco…

[2]MyBatis+Spring+SpringMVC+SSM整合一套通关

二、Spring 1、Spring简介 1.1、Spring概述 官网地址&#xff1a;https://spring.io/ Spring 是最受欢迎的企业级 Java 应用程序开发框架&#xff0c;数以百万的来自世界各地的开发人员使用 Spring 框架来创建性能好、易于测试、可重用的代码。 Spring 框架是一个开源的 Jav…

激活函数入门学习

本篇文章从外行工科的角度尽量详细剖析激活函数&#xff0c;希望不吝指教&#xff01; 学习过程如下&#xff0c;先知道这个东西是什么&#xff0c;有什么用处&#xff0c;以及怎么使用它&#xff1a; 1. 为什么使用激活函数 2. 激活函数总类及优缺点 3. 如何选择激活函数 …

一篇了解模块打包工具之 ——webpack(1)

本篇采用问题引导的方式来学习webpack&#xff0c;借此梳理一下自己对webpack的理解&#xff0c;将所有的知识点连成一条线&#xff0c;形成对webpack的记忆导图。 最终目标&#xff0c;手动构建一个vue项目&#xff0c;目录结构参考vue-cli创建出来的项目 一、问问题 1. 第…

Echarts 仪表盘倾斜一定角度显示,非中间对称

第024个点击查看专栏目录大多数的情况下&#xff0c;制作的仪表盘都是中规中矩&#xff0c;横向中间对称&#xff0c;但是生活中的汽车&#xff0c;摩托车等仪表盘确是要倾斜一定角度的&#xff0c;Echarts中我们就模拟一个带有倾斜角度的仪表盘。核心代码见示例源代码 文章目录…

搞明白redis的这些问题,你就是redis高手

什么是redis? Redis 本质上是一个 Key-Value 类型的内存数据库&#xff0c; 整个数据库加载在内存当中进行操作&#xff0c; 定期通过异步操作把数据库数据 flush 到硬盘上进行保存。 因为是纯内存操作&#xff0c; Redis 的性能非常出色&#xff0c; 每秒可以处理超过 10 万…

JS 快速创建二维数组 fill方法的坑点

JS 快速创建二维数组 坑 在算法中&#xff0c;创建二维数组遇到的一个坑 const arr new Array(5).fill(new Array(2).fill(1))我们如果想要修改其中一个元素的值 arr[0][1] 5我们可以发现所有数组中的第二个元素都发生了改变 查看MDN&#xff0c;我们会发现&#xff0c;当…

2023前端二面经典手写面试题

实现一个call call做了什么: 将函数设为对象的属性执行&删除这个函数指定this到函数并传入给定参数执行函数如果不传入参数&#xff0c;默认指向为 window // 模拟 call bar.mycall(null); //实现一个call方法&#xff1a; Function.prototype.myCall function(context…

一篇搞懂springboot多数据源

好文推荐 https://zhuanlan.zhihu.com/p/563949762 mybatis 配置多数据源 参考文章 https://blog.csdn.net/qq_38353700/article/details/118583828 使用mybatis配置多数据源我接触过的有两种方式&#xff0c;一种是通过java config的方式手动配置两个数据源&#xff0c;…

01、SVN 概述

SVN 概述1 概述2 功能3 工作原理4 基本操作1 概述 Apache下的一个开源的项目Subversion&#xff0c;通常缩写为 SVN&#xff0c;是一个版本控制系统版本控制系统是一个软件&#xff0c;它可以伴随我们软件开发人员一起工作&#xff0c;让编写代码的完整的历史保存下来目前它的…

数仓基础与hive入门

目录1、数仓数据仓库主流开发语言--SQL2、Apache Hive入门2.1 hive定义2.2 为什么使用Hive2.3 Hive和Hadoop关系2.4 场景设计&#xff1a;如何模拟实现Hive功能2.5 Apache Hive架构、组件3、Apache Hive安装部署3.1 metastore配置方式4、Hive SQL语言&#xff1a;DDL建库、建表…

内存保护_2:RTA-OS内存保护逻辑及配置说明

上一篇 | 返回主目录 | 下一篇 内存保护_2&#xff1a;RTA-OS内存保护逻辑及配置说明3 OS配置说明3.1 OS一些基本概念及相互关系3.1.1 基本概念3.1.2 相互关系3.2 内存保护基本逻辑&#xff08;RTA-OS&#xff09;3.2.1 应用集的基本分类3.2.2 内存保护与应用集的关系3.3 OS等级…

七大排序(Java)

目录 一、插入排序 1. 直接插入排序 2. 希尔排序 二、选择排序 1. 直接选择排序 2. 堆排序 三、交换排序 1. 冒泡排序 2. 快速排序 四、归并排序 五、总结 一、插入排序 1. 直接插入排序 抓一张牌&#xff0c;在有序的牌中&#xff0c;找到合适的位置并且插入。 时间…

三战阿里测试岗,成功上岸,面试才是测试员涨薪真正的拦路虎...

第一次面试阿里记得是挂在技术面上&#xff0c;当时也是技术不扎实&#xff0c;准备的不充分&#xff0c;面试官出的面试题确实把我问的一头雾水&#xff0c;还没结束我就已经知道我挂了这次面试。 第二次面试&#xff0c;我准备的特别充分&#xff0c;提前刷了半个月的面试题…

防止jar被反编译 不安装jdk运行jar

防止jar被反编译1.pom.xml<repositories><repository><id>jitpack</id><url>https://jitpack.io</url></repository> </repositories><dependencies><dependency><groupId>org.openjfx</groupId><…

RabbitMQ死信队列

目录 一、概念 二、出现死信的原因 三、实战 &#xff08;一&#xff09;代码架构图 &#xff08;二&#xff09;消息被拒 &#xff08;三&#xff09;消息TTL过期 &#xff08;四&#xff09;队列达到最大长度 一、概念 先从概念解释上搞清楚这个定义&#xff0c;死信&…

Spark 3.3.x 读取 HBase 2.x 异常(无法正常连接或读取数据)

无法连接 1. 先检查集群中的 HBase 服务、ZooKeeper 服务是否正常启动&#xff0c;有没有挂掉。 2. Spark 中的 HBase 版本是否与集群一致&#xff0c;代码中的相关包是否导入正确。 3. 连接参数&#xff08;地址、端口&#xff09;是否设置正确&#xff0c;如下所示&#x…