详解:Tensorflow、Pytorch、Keras

news2025/1/18 3:30:29

这是一个专门对Tensorflow、Pytorch、Keras三个主流DL框架的一个详解和对比分析

一、何为深度学习框架?

你可以理解为一个工具帮你构建一个深度学习网络,调用里面的各种方法就能自行构建任意层,diy你想要的DNN,而且任意指定学习器和优化器等,非常的方便!

二、Tensorflow

1.发展历史

TensorFlow由Google智能机器研究部门Google Brain团队研发的;TensorFlow编程接口支持Python和C++。随着1.0版本的公布,相继支持了Java、Go、R和Haskell API的alpha版本。

在2017年,Tensorflow独占鳌头,处于深度学习框架的领先地位;但截至目前已经和Pytorch不争上下。

注意,Tensorflow目前主要在工业级领域处于领先地位。参考至博客(38 封私信 / 16 条消息) 为什么说学术上用pytorch,工业上用tensorflow? - 知乎 (zhihu.com)

但说句实话,这个问题过于宏观,每个人都有自己的观点,最好还是自己实际两者都使用之后,再来说最适合自己的是哪一个吧。(并且tensoeflow和pytorch两者都一直在发展,后期有可能就不分伯仲啦!)

三、Pytorch

Pytorch目前是由Facebook人工智能学院提供支持服务的。

Pytorch目前主要在学术研究方向领域处于领先地位

 

其优点在于:PyTorch可以使用强大的GPU加速的Tensor计算(比如:Numpy的使用)以及可以构建带有autograd的深度神经网络。

同时,PyTorch 的代码很简洁、易于使用、支持计算过程中的动态图而且内存使用很高效

四、Keras

本来是一个独立的高级API,现在已经成为Tensorflow的一部分

接口简单友好,使用tensorflow作为后端,适合快速实验和原型开发。

五、区别

主要区别:

  1. 计算图:

    • TensorFlow使用静态计算图,需要先定义后运行
    • PyTorch使用动态计算图,更灵活,可以边定义边运行
  2. 易用性:

    • Keras通常被认为是最容易上手的
    • PyTorch的API设计更加直观
    • TensorFlow相对复杂一些,但提供更多底层控制
  3. 性能和部署:

    • TensorFlow在大规模部署和性能优化方面较为成熟
    • PyTorch在研究和实验阶段更受欢迎
    • Keras作为高级API,性能可能略低,但开发速度快
  4. 社区和生态系统:

    • TensorFlow拥有最大的社区和最广泛的工具支持
    • PyTorch在学术界更受欢迎,增长迅速
    • Keras作为TensorFlow的一部分,也有很好的社区支持

六、总结

  1. TensorFlow:

    • 由Google开发
    • 使用静态计算图
    • 广泛应用于生产环境
    • 有较为完善的部署工具
  2. PyTorch:

    • 由Facebook开发
    • 使用动态计算图
    • 更加灵活,适合研究和快速原型开发
    • 相对更加直观和易于调试
  3. Keras:

    • 最初是一个独立的高级API,现已成为TensorFlow的一部分
    • 提供更简单、更用户友好的接口
    • 可以使用TensorFlow或Theano作为后端
    • 适合快速实验和原型开发

六、代码实现以及对比(key😍)

选择哪个框架通常取决于项目需求、个人偏好和团队经验。对于初学者,Keras可能是最好的起点对于需要更多控制和自定义的高级用户,PyTorch或TensorFlow的低级API可能更合适。

使用TensorFlow、PyTorch和Keras分别搭建一个简单的深度神经网络的例子。这些例子都将创建一个简单的前馈神经网络用于MNIST手写数字分类任务。

       1.TensorFlow 2.x 示例:(这里用的是低级API,没用tf.keras的高级API)

  • 丰富的低级操作:允许对计算过程进行精细控制。
  • 强大的性能优化:特别是在分布式和大规模部署方面。
  • 全面的工具生态系统:包括TensorBoard、TFLite等工具。
  • 灵活的模型部署:支持多种平台和设备。
  • 静态图支持:虽然2.x版本默认使用动态图,但仍支持静态图,有利于某些优化。

同时下面这个例子展示了TensorFlow低级API的几个关键特性:

  1. 手动定义模型参数(W1, b1, W2, b2)作为tf.Variable。

  2. 使用函数定义模型结构,而不是使用Keras的Sequential或Functional API。

  3. 自定义损失函数。

  4. 使用tf.GradientTape来计算梯度。

  5. 手动应用梯度到优化器。

  6. 使用@tf.function装饰器来将Python函数转换为TensorFlow图,以提高性能。

  7. 手动实现训练循环和评估过程。

  8. 使用tf.data.Dataset API来处理数据。

import tensorflow as tf
import numpy as np

# 加载MNIST数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# 将数据转换为适当的形状和类型
x_train = x_train.reshape(-1, 784).astype(np.float32)
x_test = x_test.reshape(-1, 784).astype(np.float32)
y_train = y_train.astype(np.int32)
y_test = y_test.astype(np.int32)

# 创建数据集
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

# 定义模型参数
W1 = tf.Variable(tf.random.normal([784, 128], stddev=0.1))
b1 = tf.Variable(tf.zeros([128]))
W2 = tf.Variable(tf.random.normal([128, 10], stddev=0.1))
b2 = tf.Variable(tf.zeros([10]))

# 定义模型函数
def model(x):
    layer1 = tf.nn.relu(tf.matmul(x, W1) + b1)
    return tf.matmul(layer1, W2) + b2

# 定义损失函数
def loss_fn(predictions, labels):
    return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=predictions))

# 定义优化器
optimizer = tf.optimizers.Adam(learning_rate=0.001)

# 定义训练步骤
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        predictions = model(x)
        loss = loss_fn(predictions, y)
    gradients = tape.gradient(loss, [W1, b1, W2, b2])
    optimizer.apply_gradients(zip(gradients, [W1, b1, W2, b2]))
    return loss

# 定义测试步骤
@tf.function
def test_step(x, y):
    predictions = model(x)
    loss = loss_fn(predictions, y)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(predictions, axis=1), y), tf.float32))
    return loss, accuracy

# 训练循环
epochs = 5
for epoch in range(epochs):
    total_loss = 0.0
    num_batches = 0
    for x_batch, y_batch in train_dataset:
        loss = train_step(x_batch, y_batch)
        total_loss += loss
        num_batches += 1
    avg_loss = total_loss / num_batches
    
    # 在测试集上评估
    test_loss = 0.0
    test_accuracy = 0.0
    num_test_batches = 0
    for x_test_batch, y_test_batch in test_dataset:
        batch_loss, batch_accuracy = test_step(x_test_batch, y_test_batch)
        test_loss += batch_loss
        test_accuracy += batch_accuracy
        num_test_batches += 1
    avg_test_loss = test_loss / num_test_batches
    avg_test_accuracy = test_accuracy / num_test_batches
    
    print(f"Epoch {epoch+1}/{epochs}")
    print(f"Train Loss: {avg_loss:.4f}, Test Loss: {avg_test_loss:.4f}, Test Accuracy: {avg_test_accuracy:.4f}")

# 最终模型评估
final_test_loss = 0.0
final_test_accuracy = 0.0
num_final_batches = 0
for x_test_batch, y_test_batch in test_dataset:
    batch_loss, batch_accuracy = test_step(x_test_batch, y_test_batch)
    final_test_loss += batch_loss
    final_test_accuracy += batch_accuracy
    num_final_batches += 1
final_avg_test_loss = final_test_loss / num_final_batches
final_avg_test_accuracy = final_test_accuracy / num_final_batches

print("\nFinal Test Results:")
print(f"Test Loss: {final_avg_test_loss:.4f}, Test Accuracy: {final_avg_test_accuracy:.4f}")

     2. PyTorch示例

PyTorch的API设计更加直观

  • 类似Python的编程风格:使用动态计算图,编码感觉更像普通Python编程。
  • 面向对象的设计:模型定义为类,更符合Python用户的习惯。
  • 即时执行:可以立即看到每一步的结果,便于调试。
  • 灵活性:易于处理动态网络结构和复杂的研究型模型。
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# 加载数据
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

# 定义模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        x = self.flatten(x)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

model = Net()

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

# 训练模型
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

# 运行训练
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
for epoch in range(1, 6):
    train(model, device, train_loader, optimizer, epoch)

# 评估模型
model.eval()
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

print(f'Accuracy: {correct / len(test_loader.dataset)}')

   3. Keras示例

  • 高级API:Keras提供了非常简洁和直观的API,隐藏了许多底层复杂性。
  • 模块化设计:可以轻松堆叠层来构建模型,如model.add(layer)。
  • 内置常用模型:提供了许多预定义的架构,如Sequential模型。
  • 一致的接口:无论后端如何(TensorFlow、Theano等),接口保持一致。
  • 详细文档:有优秀的文档和大量教程。

比较简单,加载数据---构建模型(tf.keras自己叠就行)---编译模型(定义模型在训练过程中如何学习,使用什么优化器,使用什么损失函数评估模型性能,以及监控指标等)---训练模型---评估模型。

是不是非常的简单,结构清洗明了,确实在工程上是非常的适合的,搭建快速,便于部署

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist

# 加载数据
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# 构建模型
model = keras.Sequential([
    layers.Flatten(input_shape=(28, 28)),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.2),
    layers.Dense(10, activation='softmax')
])

# 编译模型
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# 训练模型
model.fit(x_train, y_train, epochs=5)

# 评估模型
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print(f'Test accuracy: {test_acc}')

很明显:

  • Keras适合快速原型设计和简单项目,学习曲线最平缓。
  • PyTorch在研究和复杂模型开发中很受欢迎,因为它的直观性和灵活性。
  • TensorFlow提供了从高级(Keras API)到低级的全方位控制,适合各种规模的项目,尤其是大规模部署。

这些示例展示了使用不同框架构建简单神经网络的基本步骤。每个框架都有其独特的语法和风格,但基本概念是相似的:

  1. 加载和预处理数据
  2. 定义模型结构
  3. 指定损失函数和优化器
  4. 训练模型
  5. 评估模型性能

需要注意的是,Keras现在是TensorFlow的一部分,所以TensorFlow和Keras的例子看起来非常相似。结合了keras的TensorFlow搭建DNN非常的简单(哥们当年用的LSTM就是用的TensorFlow搭建的),PyTorch的例子稍微复杂一些,因为它提供了更多的底层控制。(比如那个transformer的搭建就是依赖的PyTorch,比较复杂)

这些只是基本示例,实际应用中可能需要更复杂的模型结构、数据处理和训练过程。如果您需要更深入的解释或有任何问题,请随时询问。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2141762.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

用Qt 对接‌百度语音识别接口

一 、前期准备工作 1,搭建好开发环境; 2,注册百度云平台,获取语音相关东西, 短语音识别标准版_短语音识别-百度AI开放平台 (baidu.com) 3,涉及到的Qt 类有 QAudioFormat,QAudioDeviceInfo&a…

JDBC实现对单表数据增、删、改、查

文章目录 API介绍获取 Statement 对象Statement的API介绍使用步骤案例代码 JDBC实现对单表数据查询ResultSet的原理ResultSet获取数据的API使用JDBC查询数据库中的数据的步骤案例代码 API介绍 获取 Statement 对象 在java.sql.Connection接口中有如下方法获取到Statement对象…

线程池是啥有啥用,怎么用,如何自己实现一个

目录 一、线程池是啥,有啥用 二、线程池怎么用 1.构造方法 2.如何使用Java的线程池 三、简单实现一个线程池 假设我是一个(好看有才华) 的妹子,那么我就会有很多追求者,这些也叫备胎们,我们若把他…

71、哪吒开发板试用结合oak深度相机进行评测

基本思想:收到intel的开发板-小挪吒,正好手中也有oak相机,反正都是openvino一套玩意,进行评测一下,竟然默认是个window系统,哈哈

STL-vector练习题

118. 杨辉三角 思路: 杨辉三角有以下性质使我们要用到的: ● 每行数字左右对称,由 1 开始逐渐变大再变小,并最终回到 1。 ● 第 n 行(从 0 开始编号)的数字有 n1 项,前 n 行共有 2n(n1)个数。…

linux重要文件

/etc/sysconfig/network-scripts/ifcfg-eth1 网卡重启 /etc/init.d/network restart ifup ethname & ifdown ethname /etc/resolv.conf 设置Linux本地的客户端DNS的配置文件 linux客户端DNS可以在网卡配置文件(/etc/sysconfig/network/ifcfg-eth0 DNS2)里配置 也可以在/et…

SSY20240916提高组T1题解__贪心+大模拟

题面 题面描述 fe和xt在玩一个游戏, 在 n m n\times m nm的网格图上进行. 定义 ( a , b ) , ( c , d ) (a,b)\;,\;(c,d) (a,b),(c,d)见距离为 ∣ a − c ∣ ∣ b − d ∣ |a-c||b-d| ∣a−c∣∣b−d∣ 现在游戏按照以下步骤进行: xt选择 k k k个格子fe选择一个格子(不能选…

QT + WebAssembly + Vue环境搭建

Qt6.7.2安装工具 emsdk安装 git clone https://github.com/emscripten-core/emsdk.git cd emsdk emsdk install 3.1.50 emsdk activate 3.1.50 Qt Creator配置emsdk 效果 参考 GitHub - BrockReece/vue-wasm: Vue web assembly loader Emscripten cmake多版本编译-CSDN博客 …

【数据结构】排序算法---希尔排序

文章目录 1. 定义2. 算法步骤3. 动图演示4. 性质5. 算法分析6. 代码实现C语言PythonJavaCGo 结语 1. 定义 希尔排序(英语:Shell sort),也称为缩小增量排序法,是[直接插入排序]的一种改进版本。希尔排序以它的发明者希…

优化最长上升子序列

前言&#xff1a;平时我们做的题目都是用动态规划做的&#xff0c;但是有没有能够优化一下呢&#xff1f; 有一个结论&#xff0c;长度为 i 的一个序列&#xff0c;最后一个元素一定是构成长度为 i 的序列中最小的 我们可以用二分来优化 题目地址 #include<bits/stdc.h>…

【设计模式】创建型模式(四):建造者模式

创建型模式&#xff08;四&#xff09;&#xff1a;建造者模式 1.概念2.案例3.优化 1.概念 建造者模式 是一种创建型设计模式&#xff0c;它允许你创建复杂对象的步骤与表示方式相分离。 建造者模式是一种创建型设计模式&#xff0c;它的主要目的是将一个复杂对象的 构建过程…

极速上云2.0范式:一键智连阿里云

在传统上云的现状与挑战&#xff1a; 专线上云太重&#xff0c;VPN上云不稳&#xff0c;云上VPC&#xff0c;云下物理网络&#xff0c;多段最后一公里...... 层层对接&#xff0c;跳跳延迟&#xff0c;好生复杂! 当你试图理解SD-WAN供应商和阿里云的文档&#xff0c;以协调路由…

7-ZIP工具的功能分享:合并分卷压缩文件

在日常工作中&#xff0c;有些大文件无法单独传输&#xff0c;我们通常会通过压缩拆分成多个分卷文件来完成传输。 当完成传输后&#xff0c;不想要这么多分卷文件的时候&#xff0c;就可以通过7-ZIP工具的合并功能来解决这个问题。下面一起来看看&#xff0c;具体如何操作。 …

【C++算法】位运算

位运算基础知识 1.基础运算符 << : 左移 >> : 右移 ~ : 取反 & : 按位与&#xff0c;有0就是0 I : 按位或&#xff0c;有1就是1 ^ : 按位异或&#xff0c;&#xff08;1&#xff09;相同为0&#xff0c;相异为1&#xff08;2&#xff09;无进位相加 2.…

【docker】阿里云使用docker,2024各种采坑

▒ 目录 ▒ &#x1f6eb; 导读需求开发环境 1️⃣ dial tcp: lookup on 8.8.8.8:53: no such host失败属于DNS问题 2️⃣ docker镜像配置配置最新镜像源 3️⃣ 【重点】阿里云专用获取自己的镜像加速器地址配置镜像地址 &#x1f6ec; 文章小结&#x1f4d6; 参考资料 &#x…

MySQL_SQLYog简介、下载及安装(超详细)

课 程 推 荐我 的 个 人 主 页&#xff1a;&#x1f449;&#x1f449; 失心疯的个人主页 &#x1f448;&#x1f448;入 门 教 程 推 荐 &#xff1a;&#x1f449;&#x1f449; Python零基础入门教程合集 &#x1f448;&#x1f448;虚 拟 环 境 搭 建 &#xff1a;&#x1…

如何设置xshell关闭最后一个选项卡标签时不退出软件?

不知道你是否遇到这个问题&#xff0c;就是在使用xshell的时候&#xff0c;每次关闭最后一个选项卡标签的时候&#xff0c;xshell软件默认就退出了&#xff0c;好多次我都只是想要关闭&#xff0c;而非退出&#xff0c;所以该如何设置&#xff0c;才能到我们的预期的效果呢&…

re题(23)BUUFCTF-[FlareOn4]login

BUUCTF在线评测 (buuoj.cn) 下载后打开看到是一个txt和一个html 分别打开看看&#xff0c;txt是提示&#xff0c;html应该就是要破解的网页 打开网页&#xff0c;查看源代码 找到程序&#xff0c;变灰的部分是关键&#xff0c;是指如果是前13个字母就加13&#xff0c;如果是…

手机端跑大模型:Ollma/llama.cpp/vLLM 实测对比

昨天给大家分享了&#xff1a;如何在手机端用 Ollama 跑大模型 有小伙伴问&#xff1a;为啥要选择 Ollama&#xff1f; 不用 Ollama&#xff0c;还能用啥&#xff1f;据猴哥所知&#xff0c;当前大模型加速的主流工具有&#xff1a;Ollama、vLLM、llama.cpp 等。 那我到底该…

鸿蒙版 React Native 正式开源,ohos_react_native 了解一下

距离鸿蒙 Next 宣布一年后&#xff0c;除了 Flutter 的鸿蒙支持之外&#xff0c;React Native 的社区支持的 ohos_react_native 也终于在 OpenHarmony-SIG 对外开源&#xff0c;并且和 Flutter 不同在于&#xff0c;本次开源的版本是基于 React Native 0.72.5 。 ohos_react_n…