深度学习框架探秘|Keras:深度学习的魔法钥匙

news2025/2/19 14:52:41

一、引言:深度学习浪潮中的 Keras

前面的文章我们探秘了深度学习框架中的两大明星框架 —— TensorFlow 和 PyTorch 以及 两大框架的对比
在深度学习的众多框架中,还有一款框架备受开发者们的喜爱 —— Keras 。它就像是一位贴心的助手,为我们搭建起了通往深度学习世界的便捷桥梁。无论你是初涉人工智能领域的小白,还是经验丰富的技术大咖,Keras 都能凭借其独特的魅力,满足你的各种需求,助力你在深度学习的海洋中畅快遨游。
在这里插入图片描述

二、Keras 是什么?

Keras一个基于 Python 编写的开源神经网络库,它就像是一个便捷的工具箱,为我们打造深度学习模型提供了各种趁手的 “工具”。它的出现,让深度学习模型的搭建变得更加轻松和高效。它能够以 TensorFlow , Microsoft-CNTK 或者 Theano 作为后端运行。Keras 的开发重点是支持快速的实验。能够以最小的时延把你的想法转换为实验结果,是做好研究的关键。

Keras 具有诸多显著优势:一是用户友好,API 设计简洁直观,新手也能快速上手,将模型构想转化为代码,如构建简单图像分类模型,用 Keras 十几行代码即可搭建基本框架,比其他复杂框架更高效。二是高度模块化,把神经网络各组件设计为独立模块,像搭积木般可按需组合搭建不同结构模型,如构建自然语言处理模型时可选择合适模块组合。三是易扩展性出色,有新需求添加模块时,仿照现有模块编写新类或函数就能轻松实现,在先进研究工作中也能发挥重要作用。下面我们来看看它的这些优势

三、Keras 的优势

(一)简洁易用的 API

Keras 的 API 设计堪称一绝,它就像是为开发者量身定制的贴心工具。以构建一个简单的手写数字识别模型为例,使用 Keras 只需短短数十行代码,就能轻松搭建起模型框架。如下是使用 Keras 搭建简单手写数字识别模型的代码示例:

from keras.models import Sequential
from keras.layers import Dense
from keras.datasets import mnist
from keras.utils import np_utils

# 加载数据
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# 数据预处理
X_train = X_train.reshape(X_train.shape[0], -1) / 255.0
X_test = X_test.reshape(X_test.shape[0], -1) / 255.0
y_train = np_utils.to_categorical(y_train, 10)
y_test = np_utils.to_categorical(y_test, 10)

# 构建模型
model = Sequential()
model.add(Dense(128, activation='relu', input_shape=(784,)))
model.add(Dense(10, activation='softmax'))

# 编译模型
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# 训练模型
model.fit(X_train, y_train, batch_size=128, epochs=10, validation_data=(X_test, y_test))

# 评估模型
score = model.evaluate(X_test, y_test)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

在这个示例中,从数据的加载与预处理,到模型的构建、编译、训练以及最后的评估,整个过程代码简洁明了,逻辑清晰易懂 。相比之下,使用其他一些深度学习框架,如 TensorFlow 原生 API,可能需要编写大量复杂的代码来处理变量初始化、计算图构建等底层细节。

就好比同样是建造一座房子Keras 提供的是已经预制好的各种模块,你只需按照图纸进行拼接即可;而其他框架可能给你的是一堆原始材料,你需要从打地基开始,一步步地搭建,过程繁琐且容易出错。Keras 的这种简洁易用性,大大降低了深度学习的开发门槛,让更多人能够快速上手,将更多的时间和精力投入到模型的优化和业务逻辑的实现上。

(二)强大的兼容性

Keras 的兼容性令人赞叹,它支持多种后端,如 TensorFlow、CNTK、Theano 等 。这意味着开发者可以根据自己的需求和实际情况,灵活选择最适合的后端。例如,如果你的团队在 TensorFlow 生态积累了经验和资源,可选用 TensorFlow 作 Keras 后端,以利用其强大的分布式训练能力和广泛工具支持;若对微软认知工具包 CNTK 感兴趣或项目已用其功能,Keras 能无缝对接;若偏好 Theano 特性,Keras 也能满足需求。

Keras 多后端支持特性如万能钥匙,能融入各种开发环境,提供稳定高效服务,提高代码可移植性,让开发者依项目特点选合适后端,优化模型性能与开发效率。

(三)广泛的应用领域

Keras 凭借其强大的功能和出色的易用性,在众多领域都有着广泛的应用。

图像识别领域,Keras 可以助力我们实现各种复杂的任务。比如在医疗影像分析中,通过构建基于 Keras 的深度学习模型,能够帮助医生更准确地识别 X 光、CT 等影像中的病变,辅助疾病的诊断。以皮肤癌的诊断为例,研究人员利用 Keras 搭建卷积神经网络模型,对大量的皮肤病变图像进行训练,模型能够学习到不同病变的特征,从而准确地判断出病变是否为癌症,为早期诊断和治疗提供了有力的支持。在安防监控领域,Keras 也发挥着重要作用。通过训练图像识别模型,可以实现对监控视频中的人物、车辆等目标的实时检测和识别,提高安防系统的智能化水平,保障人们的生命财产安全。

自然语言处理领域,Keras 同样表现出色。在智能客服系统中,利用 Keras 构建的语言模型可以理解用户的提问,并给出准确的回答。例如,当用户咨询问题时,模型能够快速分析问题的语义,从知识库中检索相关信息,然后生成自然流畅的回复,大大提高了客服的工作效率和服务质量。在文本分类任务中,Keras 也能大显身手。比如对新闻文章进行分类,将其分为政治、经济、体育、娱乐等不同类别,方便用户快速获取感兴趣的信息。还可以用于情感分析,判断用户在社交媒体上发表的言论是积极、消极还是中性,帮助企业了解用户的情感倾向,优化产品和服务。

此外,在生物信息学领域,Keras 也开始崭露头角。研究人员可以利用 Keras 构建深度学习模型,对基因序列数据进行分析,预测基因的功能、疾病的发生风险等。比如通过分析大量的基因数据,预测某些基因突变与特定疾病之间的关联,为疾病的预防和治疗提供新的思路和方法。

四、Keras 的使用方法

(一)安装与配置

想要使用 Keras,首先得把它安装到你的电脑上。安装 Keras 其实并不复杂,不过在这之前,你需要确保你的系统里已经安装好了 Python,而且建议使用 Python 3.6 及以上的版本 ,就像搭建房子要先打好地基一样,Python 就是使用 Keras 的基础。

1、安装 Keras
安装 Keras 的依赖项是安装过程中的重要一步。如果你选择 TensorFlow 作为 Keras 的后端(这也是比较常见的选择),你可以在命令行中输入以下命令来安装:

pip install tensorflow

这条命令会自动帮你下载并安装 TensorFlow。安装完成后,就可以安装 Keras 了,同样在命令行中输入:

pip install keras

这样,Keras 就成功安装到你的系统中了。

2、配置 Keras
安装好之后,还需要对 Keras 进行配置。Keras 的配置主要是针对后端的选择和一些默认参数的设置。如果你想修改后端,比如从默认的 TensorFlow 后端切换到 Theano 后端,有两种方法可以实现。一种是修改 Keras 的配置文件keras.json ,这个文件通常位于$HOME/.keras/目录下($HOME表示你的用户主目录,在不同的操作系统中可能有所不同)。你可以使用文本编辑器打开这个文件,然后找到"backend"字段,将其值从"tensorflow"改为"theano" 。另一种方法是在 Python 代码中通过设置环境变量来修改后端,在导入 Keras 之前,添加以下代码:

import os

os.environ['KERAS_BACKEND'] = 'theano'

这样就可以在代码运行时临时将 Keras 的后端设置为 Theano。通过这些安装和配置步骤,你就为使用 Keras 做好了充分的准备,可以开启深度学习模型的搭建之旅啦。

(二)构建模型

在 Keras 中,构建模型主要有两种方式,分别是序贯模型和函数式模型,它们各有特点,适用于不同的场景。

1、序贯模型

序贯模型是一种最简单的模型构建方式,它就像是搭积木一样,按照顺序一层一层地堆叠网络层。每一层都只有一个输入和一个输出,前一层的输出会直接作为下一层的输入。我们以手写数字识别这个经典任务为例,来看看如何使用 Sequential 模型构建简单的神经网络。

手写数字识别,就是让计算机能够识别出图片中的手写数字是 0 - 9 中的哪一个。在这个任务中,我们使用的是 MNIST 数据集,它包含了大量的手写数字图片和对应的标签。下面是使用 Keras 的 Sequential 模型构建手写数字识别神经网络的代码示例:

from keras.models import Sequential

from keras.layers import Dense

from keras.datasets import mnist

from keras.utils import np_utils

# 加载数据

(X_train, y_train), (X_test, y_test) = mnist.load_data()

# 数据预处理

X_train = X_train.reshape(X_train.shape[0], -1) / 255.0

X_test = X_test.reshape(X_test.shape[0], -1) / 255.0

y_train = np_utils.to_categorical(y_train, 10)

y_test = np_utils.to_categorical(y_test, 10)

# 构建模型

model = Sequential()

model.add(Dense(128, activation='relu', input_shape=(784,)))

model.add(Dense(10, activation='softmax'))

在这段代码中,首先导入了需要的库和模块。然后使用mnist.load_data()函数加载 MNIST 数据集,这个数据集会自动分为训练集(X_train, y_train)和测试集(X_test, y_test) 。接下来进行数据预处理,将图片数据进行归一化处理,使其取值范围在 0 - 1 之间,同时将标签数据转换为独热编码的形式,这样更适合模型的训练。

在构建模型部分,创建了一个 Sequential 模型对象model 。然后使用model.add()方法依次添加网络层。第一层是一个全连接层Dense(128, activation='relu', input_shape=(784,)) ,其中128表示这一层有 128 个神经元,activation='relu'表示使用 ReLU 激活函数,input_shape=(784,)表示输入数据的形状是 784 维,这是因为 MNIST 数据集中的每张图片是 28x28 像素的,展开成一维向量后就是 784 维。第二层也是一个全连接层Dense(10, activation='softmax') ,这里的10表示输出层有 10 个神经元,对应 0 - 9 这 10 个数字,activation='softmax'表示使用 softmax 激活函数,它会将输出转换为概率分布,每个神经元的输出值表示对应数字的概率。通过这样的网络结构,模型就可以学习到手写数字的特征,并进行分类预测。

2、函数式模型

函数式模型相比序贯模型更加灵活,它可以构建更复杂的模型结构,比如多输入多输出的模型。当你需要处理一些复杂的任务,序贯模型无法满足需求时,函数式模型就能派上用场。它就像是一个更高级的工具,让你可以根据具体需求自由地设计模型的连接方式。

下面通过一个多输入多输出的模型示例,来展示函数式模型的使用方法和灵活性。假设我们要构建一个模型,它有两个输入,一个输入是文本数据,另一个输入是图像数据,模型的输出有两个,一个是对文本内容的分类结果,另一个是对图像内容的分类结果。代码示例如下:

from keras.layers import Input, Dense, concatenate

from keras.models import Model

# 定义输入层

text_input = Input(shape=(100,))

image_input = Input(shape=(28, 28, 1))

# 对文本输入进行处理

text_features = Dense(64, activation='relu')(text_input)

# 对图像输入进行处理

image_features = Dense(64, activation='relu')(image_input)

# 将文本和图像特征进行合并

merged = concatenate([text_features, image_features])

# 定义输出层

text_output = Dense(10, activation='softmax')(merged)

image_output = Dense(5, activation='softmax')(merged)

# 构建模型

model = Model(inputs=[text_input, image_input], outputs=[text_output, image_output])

在这段代码中,首先使用Input()函数定义了两个输入层,分别是text_inputimage_input ,它们的形状分别是(100,)(28, 28, 1) ,表示文本输入是 100 维的向量,图像输入是 28x28 像素的单通道图像。然后分别对文本输入和图像输入进行处理,通过全连接层提取特征。接着使用concatenate()函数将文本特征和图像特征合并在一起。最后,根据合并后的特征定义了两个输出层,分别是text_outputimage_output ,用于对文本和图像进行分类。通过Model()函数将输入和输出连接起来,构建出了完整的多输入多输出模型。这种灵活的构建方式,使得函数式模型在处理复杂任务时具有很大的优势。

(三)编译与训练

当我们使用 Keras 搭建好模型后,就需要对模型进行编译和训练,这是让模型学习数据特征、具备预测能力的关键步骤。

在模型编译阶段,我们需要设置一些重要的参数,这些参数就像是模型训练的 “指南针”,指引着模型朝着正确的方向学习。首先是优化器,它的作用是调整模型的权重,使得模型在训练过程中能够不断地降低损失函数的值。Keras 提供了多种优化器供我们选择,比如常见的随机梯度下降(SGD)、Adagrad、Adadelta、Adam 等。以 Adam 优化器为例,它结合了 Adagrad 和 Adadelta 的优点,能够自适应地调整学习率,在很多任务中都表现出色。在编译模型时,我们可以这样设置优化器:

from keras.optimizers import Adam

model.compile(optimizer=Adam(lr=0.001),...)

这里lr参数表示学习率,设置为0.001,它决定了模型在训练过程中权重更新的步长,学习率过大可能导致模型无法收敛,学习率过小则会使训练过程变得非常缓慢。

损失函数也是编译模型时必不可少的参数,它用于衡量模型预测值与真实值之间的差异。不同的任务需要选择不同的损失函数。在分类任务中,如果是二分类问题,常用的损失函数是二元交叉熵(binary_crossentropy);如果是多分类问题,比如前面提到的手写数字识别,通常使用分类交叉熵(categorical_crossentropy) 。以手写数字识别模型为例,编译时设置损失函数如下:

model.compile(..., loss='categorical_crossentropy',...)

评估指标用于评估模型在训练和测试过程中的性能表现。在分类任务中,我们通常会关注准确率(accuracy),它表示模型预测正确的样本数占总样本数的比例。在编译模型时,可以将准确率作为评估指标之一:

model.compile(..., metrics=['accuracy'])

完成模型编译后,就可以开始训练模型了。训练模型的过程就像是让模型在数据的海洋中不断学习和成长。在 Keras 中,使用fit()方法进行模型训练,下面是训练手写数字识别模型的代码示例:

model.fit(X_train, y_train, batch_size=128, epochs=10, validation_data=(X_test, y_test))

在这个示例中,X_trainy_train是训练数据和对应的标签,batch_size表示每次训练时使用的样本数量,设置为128,这意味着模型每次从训练数据中取出 128 个样本进行训练,这样可以减少内存的占用,同时也能加快训练速度。epochs表示训练的轮数,这里设置为10,即模型会对整个训练数据进行 10 次训练。validation_data参数用于指定验证数据,这里使用测试集(X_test, y_test)作为验证数据,模型在每一轮训练结束后,都会在验证数据上进行评估,以查看模型的泛化能力是否良好。通过不断地调整这些训练参数,我们可以让模型达到更好的训练效果。

(四)评估与预测

当模型训练完成后,我们就需要对模型的性能进行评估,看看它是否达到了我们的预期。在 Keras 中,使用evaluate()方法可以方便地对训练好的模型进行评估。继续以手写数字识别模型为例,评估代码如下:

score = model.evaluate(X_test, y_test)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

在这段代码中,model.evaluate(X_test, y_test)会返回一个包含损失值和评估指标值的列表。这里的score[0]表示测试集上的损失值,损失值越低,说明模型在测试集上的预测结果与真实值之间的差异越小;score[1]表示测试集上的准确率,准确率越高,说明模型在测试集上的预测效果越好。通过这些评估指标,我们可以直观地了解模型的性能表现。

除了评估模型,我们还可以使用训练好的模型进行预测,让模型对新的数据进行判断和分类。在 Keras 中,使用predict()方法进行预测。例如,我们想要预测测试集中前 10 个样本的数字类别,可以这样做:

predictions = model.predict(X_test[:10])

print(predictions)

model.predict(X_test[:10])会返回一个数组,数组的每一行表示对应样本属于 0 - 9 这 10 个数字类别的概率分布。例如,predictions[0]表示第一个样本属于各个数字类别的概率,我们可以通过argmax()函数找到概率最大的类别索引,从而得到模型的预测结果:

import numpy as np
predicted_classes = np.argmax(predictions, axis=1)
print('Predicted classes:', predicted_classes)

通过这样的方式,我们就可以根据模型的预测结果进行分析和应用。比如在实际的手写数字识别场景中,我们可以将模型集成到一个应用程序中,让用户输入手写数字图片,模型就能快速给出识别结果,为人们的生活和工作带来便利。

结语

今天这篇文章,我们初步学习了 Keras,包括它是什么、具备哪些优势(简洁易用的 API、强大的兼容性、广泛的应用领域),以及基本使用方法。在下一篇文章里,我会带大家了解 Keras 在图像处理与自然语言处理领域的应用案例。欢迎订阅专栏->[传送门],及时获取最新文章。


延伸阅读

深度学习框架探秘|PyTorch:AI 开发的灵动画笔

深度学习框架探秘|TensorFlow:AI 世界的万能钥匙

深度学习框架探秘|TensorFlow vs PyTorch:AI 框架的巅峰对决

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2298826.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2.11学习

misc buu-荷兰宽带泄露 下载附件得到了一个后缀为.bin的文件 是宽带数据文件,用RouterPassView工具进行查看。大多数现代路由器都可以让您备份一个文件路由器的配置文件,然后在需要的时候从文件中恢复配置。路由器的备份文件通常包含了像您的ISP的用户…

Python 调用 DeepSeek API 案例详细教程

本案例为以 Python 为例的调用 DeepSeek API 的小白入门级详细教程 步骤 先注册并登录 DeepSeek 官网:https://www.deepseek.com/ 手机号验证码注册或登录即可 创建 API KEY 注意保存,写代码时必须提供的 打开 Pycharm 创建工程 并安装 OpenAI 库编写代…

C++ Primer 函数基础

欢迎阅读我的 【CPrimer】专栏 专栏简介:本专栏主要面向C初学者,解释C的一些基本概念和基础语言特性,涉及C标准库的用法,面向对象特性,泛型特性高级用法。通过使用标准库中定义的抽象设施,使你更加适应高级…

【SVN基础】

软件:ToritoiseSVN 代码版本回退:回退到上一个版本 问题:SVN版本已经提交了版本1和版本2,现在发现不需要版本2的内容,需要回退到版本1然后继续开发。 如图SVN版本已经提交到了107版本,那么本地仓库也已经…

kron积计算mask类别矩阵

文章目录 1. 生成类别矩阵如下2. pytorch 代码3. 循环移动矩阵 1. 生成类别矩阵如下 2. pytorch 代码 import torch import torch.nn as nn import torch.nn.functional as Ftorch.set_printoptions(precision3, sci_modeFalse)if __name__ "__main__":run_code 0…

Stable Diffusion 安装教程(附安装包) 【SD三种安装方式,Win+Mac一篇文章讲明白】

“Stable Diffusion的门槛过高、不会安装?没关系,这篇文章教会你如何安装!” Stable Diffusion的安装部署其实并不困难,只需简单点击几下,几分钟就能安装好,不管是windows还是苹果mac电脑,关于…

网络安全用centos干嘛 网络安全需要学linux吗

网络安全为啥要学Linux系统,据不完全统计,Linux系统在数据中心操作系统上的份额高达70%。它一般运行于服务器和超级计算机上。 所以我们日常访问的网站后台和app后端都是部署在Linux服务器上的,如果你不会Linux系统操作,那么很多…

jupyter notebook中3种读图片的方法_与_图片翻转(上下翻转,左右翻转,上下左右翻转)

已有图片cat.jpg 相对于代码的位置,可以用./cat.jpg进行读取。 下面是3种读图片的方法。 1.python读图片-pillow 图片文件不适合用open去读取 用open读图片,易引发UnicodeDecodeError: gbk codec cant decode byte 0xff in position 0: illegal multib…

微软官方出品GPT大模型编排工具:7个开源项目

今天一起盘点下,12月份推荐的7个.Net开源项目(点击标题查看详情)。 1、一个浏览器自动化操作的.Net开源库 这是一个基于 Google 开源的 Node.js 库 Puppeteer 的 .NET 开源库,方便开发人员使用无头 Web 浏览器抓取 Web、检索 Ja…

机器视觉--Halcon变量的创建与赋值

一、引言 在机器视觉领域,Halcon 作为一款强大且功能丰富的软件库,为开发者提供了广泛的工具和算子来处理各种复杂的视觉任务。而变量作为程序中存储和操作数据的基本单元,在 Halcon 编程中起着至关重要的作用。正确地创建和赋值变量是编写高…

03【FreeRTO队列-如何获取任务信息与队列的动静态创建】

一.利用 vTaskList()以及 vTaskGetRunTimeStats()来获取任务的信息 1.现象与开启启用宏 freeRTOSConfig.h //必须启用 #define configUSE_TRACE_FACILITY 1 #define configGENERATE_RUN_TIME_STATS 1 #define configUSE_STATS_FORMATTING_FUNCTIONS…

GBD研究——美国州级地图(附资源)

美国州级别地图 地图源很多,随便下载。不过我试了两个资源,发现有的资源会漏掉阿拉斯加和夏威夷。 就剩大的这块佩奇 出现这样的问题,要么跟数据源有关,要么就是要掉地名来看,是不是没匹配上。 亲自试过&#xff0c…

【微服务学习一】springboot微服务项目构建以及nacos服务注册

参考链接 3. SpringCloud - 快速通关 springboot微服务项目构建 教程中使用的springboot版本是3.x,因此需要使用jdk17,并且idea也需要高版本,我这里使用的是IDEA2024。 环境准备好后我们就可以创建springboot项目,最外层的项目…

第39周:猫狗识别 2(Tensorflow实战第九周)

目录 前言 一、前期工作 1.1 设置GPU 1.2 导入数据 输出 二、数据预处理 2.1 加载数据 2.2 再次检查数据 2.3 配置数据集 2.4 可视化数据 三、构建VGG-16网络 3.1 VGG-16网络介绍 3.2 搭建VGG-16模型 四、编译 五、训练模型 5.1 上次程序的主要Bug 5.2 修改版…

DeepSeek 概述与本地化部署【详细流程】

目录 一、引言 1.1 背景介绍 1.2 本地化部署的优势 二、deepseek概述 2.1 功能特点 2.2 核心优势 三、本地部署流程 3.1 版本选择 3.2 部署过程 3.2.1 下载Ollama 3.2.2 安装Ollama 3.2.3 选择 r1 模型 3.2.4 选择版本 3.2.5 本地运行deepseek模型 3.3.6 查看…

jenkins war Windows安装

Windows安装Jenkins 需求1.下载jenkins.war2.编写快速运行脚本3.启动Jenkins4.Jenkins使用 需求 1.支持在Windows下便捷运行Jenkins; 2.支持自定义启动参数; 3.有快速运行的脚步样板。 1.下载jenkins.war Jenkins下载地址:https://get.j…

3D打印技术:如何让古老文物重获新生?

如何让古老文物在现代社会中焕发新生是一个重要议题。传统文物保护方法虽然在一定程度上能够延缓文物的损坏,但在文物修复、展示和传播方面仍存在诸多局限。科技发展进步,3D打印技术为古老文物的保护和传承提供了全新的解决方案。我们来探讨3D打印技术如…

Vue h函数到底是个啥?

h 到底是个啥? 对于了解或学习Vue高阶组件(HOC)的同学来说,h() 函数无疑是一个经常遇到的概念。 那么,这个h() 函数究竟如何使用呢,又在什么场景下适合使用呢? 一、h 是什么 看到这个函数你可…

深入浅出 Python Logging:从基础到进阶日志管理

在 Python 开发过程中,日志(Logging)是不可或缺的调试和监控工具。合理的日志管理不仅能帮助开发者快速定位问题,还能提供丰富的数据支持,让应用更具可观测性。本文将带你全面了解 Python logging 模块,涵盖…

Android WindowContainer窗口结构

Android窗口是根据显示屏幕来管理,每个显示屏幕的窗口层级分为37层,0-36层。每层可以放置多个窗口,上层窗口覆盖下面的。 要理解窗口的结构,需要学习下WindowContainer、RootWindowContainer、DisplayContent、TaskDisplayArea、T…