深度学习框架Keras与Pytorch对比

news2025/1/23 2:12:25

对于许多科学家、工程师和开发人员来说,TensorFlow是他们的第一个深度学习框架。TensorFlow 1.0于2017年2月发布,可以说,它对用户不太友好。

在过去的几年里,两个主要的深度学习库KerasPytorch获得了大量关注,主要是因为它们的使用比较简单。

本文将介绍Keras与Pytorch的4个不同点以及为什么选择其中一个库的原因。

Keras

Keras本身并不是一个框架,而是一个位于其他深度学习框架之上的高级API。目前它支持TensorFlow、Theano和CNTK。

Keras的优点在于它的易用性。这是迄今为止最容易上手并快速运行的框架。定义神经网络是非常直观的,因为使用API可以将层定义为函数。

Pytorch

Pytorch是一个深度学习框架(类似于TensorFlow),由Facebook的人工智能研究小组开发。与Keras一样,它也抽象出了深层网络编程的许多混乱部分。

就高级和低级代码风格而言,Pytorch介于Keras和TensorFlow之间。比起Keras具有更大的灵活性和控制能力,但同时又不必进行任何复杂的声明式编程(declarative programming)。

深度学习的从业人员整天都在纠结应该使用哪个框架。一般来说,这取决于个人喜好。但是在选择Keras和Pytorch时,你应该记住它们的几个方面。

640?wx_fmt=jpeg

(1)定义模型的类与函数

为了定义深度学习模型,Keras提供了函数式API。使用函数式API,神经网络被定义为一系列顺序化的函数,一个接一个地被应用。例如,函数定义层1( function defining layer 1)的输出是函数定义层2的输入。

img_input = layers.Input(shape=input_shape)
x = layers.Conv2D(64, (3, 3), activation='relu')(img_input)    
x = layers.Conv2D(64, (3, 3), activation='relu')(x)    
x = layers.MaxPooling2D((2, 2), strides=(2, 2))(x)

在Pytorch中,你将网络设置为一个继承来自Torch库的torch.nn.Module的类。与Keras类似,Pytorch提供给你将层作为构建块的能力,但是由于它们在Python类中,所以它们在类的init_ ()方法中被引用,并由类的forward()方法执行。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3)
        self.conv2 = nn.Conv2d(64, 64, 3)
        self.pool = nn.MaxPool2d(2, 2)
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        return x
model = Net()

(2)张量和计算图模型与标准数组的比较

Keras API向普通程序员隐藏了许多混乱的细节。这使得定义网络层是直观的,并且默认的设置通常足以让你入门。

只有当你正在实现一个相当先进或“奇特”的模型时,你才真正需要深入了解底层,了解一些基本的TensorFlow。

棘手的部分是,当你真正深入到较低级别的TensorFlow代码时,所有的挑战就随之而来!你需要确保所有的矩阵乘法都对齐。不要试着想打印出你自己定义的层的输出,因为你只会得到一个打印在你的终端上的没有错误的张量定义。

Pytorch在这些方面更宽容一些。你需要知道每个层的输入和输出大小,但是这是一个比较容易的方面,你可以很快掌握它。你不需要构建一个抽象的计算图,避免了在实际调试时无法看到该抽象的计算图的细节。

Pytorch的另一个优点是平滑性,你可以在Torch张量和Numpy数组之间来回切换。如果你需要实现一些自定义的东西,那么在TF张量和Numpy数组之间来回切换可能会很麻烦,这要求开发人员对TensorFlow会话有一个较好的理解。

Pytorch的互操作实际上要简单得多。你只需要知道两种操作:一种是将Torch张量(一个可变对象)转换为Numpy,另一种是反向操作。

当然,如果你从来不需要实现任何奇特的东西,那么Keras就会做得很好,因为你不会遇到任何TensorFlow的障碍。但是如果你有这个需求,那么Pytorch将会是一个更加好的选择。

(3)训练模型

640?wx_fmt=jpeg

用Keras训练模特超级简单!只需一个简单的.fit(),你就可以直接去跑步了。

history = model.fit_generator(
    generator=train_generator,
    epochs=10,
    validation_data=validation_generator)

在Pytorch中训练模型包括以下几个步骤:

  1. 在每批训练开始时初始化梯度
  2. 前向传播
  3. 反向传播
  4. 计算损失并更新权重
# 在数据集上循环多次
for epoch in range(2):  
    for i, data in enumerate(trainloader, 0):
        # 获取输入; data是列表[inputs, labels]
        inputs, labels = data 
        # (1) 初始化梯度
        optimizer.zero_grad() 

        # (2) 前向传播
        outputs = net(inputs)
        loss = criterion(outputs, labels)

        # (3) 反向传播
        loss.backward()
        # (4) 计算损失并更新权重
        optimizer.step()

光是训练就需要很多步骤!

我想这种方式你就会知道实际上发生了什么。由于这些模型训练步骤对于训练不同的模型本质上保持不变,所以这些代码实际上完全不必要的。

(4)控制CPU与GPU模式的比较

640?wx_fmt=jpeg

如果你已经安装了tensorflow-gpu,那么在Keras中使用GPU是默认启用和完成的。如果希望将某些操作转移到CPU,可以使用以下代码。

with tf.device('/cpu:0'):
    y = apply_non_max_suppression(x)

对于Pytorch,你必须显式地为每个torch张量和numpy变量启用GPU。这将使代码变得混乱,如果你在CPU和GPU之间来回移动以执行不同的操作,则很容易出错。

例如,为了将我们之前的模型转移到GPU上运行,我们需要做以下工作:

#获取GPU设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#传送网络到GPU
net.to(device)

# 传送输入和标签到GPU
inputs, labels = data[0].to(device), data[1].to(device)

JAVASCRIPT 复制 全屏

Keras在这方面的优势在于它的简单性和良好的默认设置

选择框架的一般建议

我通常给出的建议是从Keras开始。

Keras绝对是最容易使用、理解和快速上手并运行的框架。你不需要担心GPU设置,处理抽象代码,或者做任何复杂的事情。你甚至可以在不接触TensorFlow的任何一行的情况下实现定制层和损失函数。

如果你确实开始深入到深度网络的更细粒度方面,或者正在实现一些非标准的东西,那么Pytorch就是你的首选库。在Keras上实现反而会有一些额外的工作量,虽然不多,但这会拖慢你的进度。使用pytorch能够快速地实现、训练和测试你的网络,并附带易于调试的额外好处!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1343855.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Java EE初阶五】wait及notify关键字

1. wait和notify的概念 所谓的wait和notify其实就是等待、通知机制;该机制的作用域join类似;由于多个线程之间是随机调度的,引入wait和notify就是为了能够从应用层面上,干预到多个不同线程代码的执行顺序,此处的干预&a…

C# WPF上位机开发(Web API联调)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 很多时候,客户需要开发的不仅仅是一个上位机系统,它还有其他很多配套的系统或设备,比如物流小车、立库、数字孪…

web前端开发html/css求职简介/个人简介小白网页设计

效果图展示&#xff1a; html界面展示&#xff1a; html/css代码&#xff1a; <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"> <html xmlns"http://www.w3.…

Java IDEA JUnit 单元测试

JUnit是一个开源的 Java 单元测试框架&#xff0c;它使得组织和运行测试代码变得非常简单&#xff0c;利用JUnit可以轻松地编写和执行单元测试&#xff0c;并且可以清楚地看到哪些测试成功&#xff0c;哪些失败 JUnit 还提供了生成测试报告的功能&#xff0c;报告不仅包含测试…

VSCode + vite + vue3断点调试配置

没想到这个配置我搞了一上午&#xff0c;网上很多的配置方案都没有效果。总算搞定了&#xff0c;特此记录一下。 首先需要在.vscode文件夹下面创建launch.json配置文件。然后输入如下配置&#xff1a; {// 使用 IntelliSense 了解相关属性。 // 悬停以查看现有属性的描述。//…

Java Swing GUI实现ATM机(涉及网络编程聊天功能)

一、序言 1.首先这是本人大二时期的编程&#xff0c;涉及到网络编程的聊天功能&#xff0c;大佬勿喷。 二、且看展示图片 1.首先启动服务端&#xff08;启动Fuwuduan代码&#xff09;&#xff0c;也就是客服聊天窗口 提供给用户申请银行卡号&#xff0c;客服界面如下&#x…

复试 || 就业day01(2023.12.29)项目一

文章目录 前言正规方程二元一次示例正规方程 : w ( X T X ) − 1 X T y w (X^TX)^{-1}X^Ty w(XTX)−1XTy三元一次方程示例八元一次方程示例sklearn带截距的线性方程总结 前言 &#x1f4ab;你好&#xff0c;我是辰chen&#xff0c;本文旨在准备考研复试或就业 &#x1f4ab;…

unity exe程序置顶和全屏

1.置顶和无边框 设置显示位置和范围 using System; using System.Runtime.InteropServices; using UnityEngine; public class WindowMod : MonoBehaviour {public enum appStyle{FullScreen,WindowedFullScreen,Windowed,WindowedWithoutBorder}public enum zDepth{Normal…

手写Spring与基本原理--简易版

文章目录 手写Spring与基本原理解析简介写一个简单的Bean加载容器定义一个抽象所有类的BeanDefinition定义一个工厂存储所有的类测试 实现Bean的注册定义和获取基于Cglib实现含构造函数的类实例化策略Bean对象注入属性和依赖Bean的功能Spring.xml解析和注册Bean对象实现应用上下…

STM32CubeMX学习(二) USB CDC 双向通信

STM32CubeMX学习&#xff08;二&#xff09; USB CDC 双向通信 简介CubeMX新建工程&#xff08;串口LED&#xff09;测试串口和LED串口接收测试USB CDC通信 简介 利用正点原子F407探索者开发板&#xff0c;测试基于USB CDC的双向数据通信。 CubeMX新建工程&#xff08;串口LE…

ES6+ 面试常问题

一、let const var 的区别 1. var&#xff1a; 没有块级作用域的概念&#xff0c;有函数作用域和全局作用域的概念全局作用域性下创建变量会被挂在到 windows 上存在变量提升同一作用域下&#xff0c;可以重复赋值创建未初始化&#xff0c;值为 undefined 2. let&#xff1a…

2023年末,软件测试面试题总结与分享

大家好&#xff0c;最近有不少小伙伴在后台留言&#xff0c;得准备年后面试了&#xff0c;又不知道从何下手&#xff01;为了帮大家节约时间&#xff0c;特意准备了一份面试相关的资料&#xff0c;内容非常的全面&#xff0c;真的可以好好补一补&#xff0c;希望大家在都能拿到…

天擎终端安全管理系统clientinfobymid存在SQL注入漏洞

产品简介 奇安信天擎终端安全管理系统是面向政企单位推出的一体化终端安全产品解决方案。该产品集防病毒、终端安全管控、终端准入、终端审计、外设管控、EDR等功能于一体&#xff0c;兼容不同操作系统和计算平台&#xff0c;帮助客户实现平台一体化、功能一体化、数据一体化的…

《PCI Express体系结构导读》随记 —— 第I篇 第1章 PCI总线的基本知识(16)

接前一篇文章&#xff1a;《PCI Express体系结构导读》随记 —— 第I篇 第1章 PCI总线的基本知识&#xff08;15&#xff09; 1.3 PCI总线的存储器读写总线事务 1.3.5 Delayed传送方式 如前文所述&#xff0c;当处理器使用Non-Posted总线周期对PCI设备进行操作、或者PCI设备使…

Android MVVM 写法

前言 Model&#xff1a;负责数据逻辑 View&#xff1a;负责视图逻辑 ViewModel&#xff1a;负责业务逻辑 持有关系&#xff1a; 1、ViewModel 持有 View 2、ViewModel 持有 Model 3、Model 持有 ViewModel 辅助工具&#xff1a;DataBinding 执行流程&#xff1a;View &g…

linux源码编译升级安装openssl3.0.1导致系统启动失败的问题解决

前两天在安装curl的时候&#xff0c;提示openssl版本太老了&#xff0c;原有的版本是openssl1.0的版本&#xff0c;需要将其升级到openssl3的版本。 直接使用命令行sudo apt install默认安装的还是openssl1.1.1版本&#xff0c;因此决定使用源码自行安装。 具体的安装过程就不赘…

webpack打包批量替换路径(string-replace-webpack-plugin插件)

string-replace-webpack-plugin 是一个用于在 webpack 打包后的文件中替换字符串的插件。它可以用于将特定字符串替换为其他字符串&#xff0c;例如将敏感信息从源代码中移除或对特定文本进行本地化处理。比如文件的html、css、js中的路径地址想批量更改一下 http://localhost:…

海德堡UV灯电源维修eta Plus Elc PE22-400-210

uv灯电源维修故障包括&#xff1a; 1、电压不稳&#xff1a;检查uv打印机的电压&#xff0c;设置一个稳压箱即可。 2、温度过高&#xff1a;uv打印机温度过高也会影响uv灯&#xff0c;可以更换为水冷式循环降温。 3、水箱里的信号线接触不好&#xff1a;将两边的信号线对调&…

leetcode刷题记录07(2023-04-30)【二叉树展开为链表 | 买卖股票的最佳时机 | 二叉树中的最大路径和(递归) | 最长连续序列(并查集)】

114. 二叉树展开为链表 给你二叉树的根结点 root &#xff0c;请你将它展开为一个单链表&#xff1a; 展开后的单链表应该同样使用 TreeNode &#xff0c;其中 right 子指针指向链表中下一个结点&#xff0c;而左子指针始终为 null 。 展开后的单链表应该与二叉树 先序遍历 顺…

ArcGIS批量计算shp面积并导出shp数据总面积(建模法)

在处理shp数据时&#xff0c; 又是我们需要知道许多个shp字段的批量计算&#xff0c;例如计算shp的总面积、面积平均值等&#xff0c;但是单个查看shp文件的属性进行汇总过于繁琐&#xff0c;因此可以借助建模批处理来计算。 首先准备数据&#xff1a;一个含有多个shp的文件夹。…