深度学习框架:Pytorch与Keras的区别与使用方法

news2024/11/20 1:39:42

  

☁️主页 Nowl

🔥专栏《机器学习实战》 《机器学习》

📑君子坐而论道,少年起而行之 

文章目录

Pytorch与Keras介绍

Pytorch

模型定义

模型编译

模型训练

输入格式

完整代码

Keras

模型定义

模型编译

模型训练

输入格式

完整代码

区别与使用场景

结语


Pytorch与Keras介绍

pytorch和keras都是一种深度学习框架,使我们能很便捷地搭建各种神经网络,但它们在使用上有一些区别,也各自有其特性,我们一起来看看吧

Pytorch

模型定义

我们以最简单的网络定义来学习pytorch的基本使用方法,我们接下来要定义一个神经网络,包括一个输入层,一个隐藏层,一个输出层,这些层都是线性的,给隐藏层添加一个激活函数Relu,给输出层添加一个Sigmoid函数

import torch
import torch.nn as nn


class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(1, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.Sigmoid(x)
        return x

模型编译

我们在之前的机器学习文章中反复提到过,模型的训练是怎么进行的呢,要有一个损失函数与优化方法,我们接下来看看在pytorch中怎么定义这些

import torch.optim as optim


# 实例化模型对象
model = SimpleNet()
# 定义损失函数
criterion = nn.MSELoss()

# 定义优化器
learning_rate = 0.01
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

我们上面创建的神经网络是一个类,所以我们实例化一个对象model,然后定义损失函数为mse,优化器为随机梯度下降并设置学习率

模型训练

# 创建随机输入数据和目标数据
input_data = torch.randn((100, 1))  # 100个样本,每个样本有1个特征
target_data = torch.randn((100, 1))  # 100个样本,每个样本有1个目标值

# 训练模型
epochs = 100

for epoch in range(epochs):
    # 前向传播
    output = model(input_data)

    # 计算损失
    loss = criterion(output, target_data)

    # 反向传播和优化
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

以上步骤是先创建了一些随机样本,作为模型的训练集,然后定义训练轮次为100次,然后前向传播数据集,计算损失,再优化,如此反复

输入格式

关于输入格式是很多人在实战中容易出现问题的,对于pytorch创建的神经网络,我们的输入内容是一个torch张量,怎么创建呢

data = torch.Tensor([[1], [2], [3]])

很简单对吧,上面这个例子创建了一个torch张量,有三组数据,每组数据有1个特征

我们可以把这个数据输入到训练好的模型中,得到输出结果,如果输出不是torch张量,代码就会报错

完整代码

import torch
import torch.nn as nn
import torch.optim as optim


# 定义一个简单的神经网络模型
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(1, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return x



model = SimpleNet()
criterion = nn.MSELoss()

# 定义优化器
learning_rate = 0.01
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

# 创建随机输入数据和目标数据
input_data = torch.randn((100, 1))  # 100个样本,每个样本有1个特征
target_data = torch.randn((100, 1))  # 100个样本,每个样本有1个目标值

# 训练模型
epochs = 100

for epoch in range(epochs):
    # 前向传播
    output = model(input_data)

    # 计算损失
    loss = criterion(output, target_data)

    # 反向传播和优化
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


data = torch.Tensor([[1], [2], [3]])
prediction = model(data)

print(prediction)

可以看到模型输出了三个预测值

注意,这个任务本身没有意义,因为我们的训练集是随机生成的,这里主要学习框架的使用方法

Keras

我们在这里把和上面相同的神经网络结构使用keras框架实现一遍

模型定义

from keras.models import Sequential
from keras.layers import Dense


model = Sequential([
    Dense(32, input_dim=1, activation='relu'),
    Dense(1, activation='sigmoid')
])

注意这里也是一层输入层,一层隐藏层,一层输出层,和pytorch一样,输入层是隐式的,我们的输入数据就是输入层,上述代码定义了一个隐藏层,输入维度是1,输出维度是32,还定义了一个输出层,输入维度是32,输出维度是1,和pytorch环节的模型结构是一样的 

模型编译

那么在Keras中模型又是怎么编译的呢

model.compile(loss='mse', optimizer='sgd')

非常简单,只需要这一行代码 ,设置损失函数为mse,优化器为随机梯度下降

模型训练

模型的训练也非常简单

# 训练模型
model.fit(input_data, target_data, epochs=100)

 因为我们已经编译好了损失函数和优化器,在fit里只需要输入数据,输出数据和训练轮次这些参数就可以训练了

输入格式

对于Keras模型的输入,我们要把它转化为numpy数组,不然会报错

data = np.array([[1], [2], [3]])

完整代码

from keras.models import Sequential
from keras.layers import Dense
import numpy as np


# 定义模型
model = Sequential([
    Dense(32, input_dim=1, activation='relu'),
    Dense(1, activation='sigmoid')
])

# 创建随机输入数据和目标数据
input_data = np.random.randn(100, 1)  # 100个样本,每个样本有10个特征
target_data = np.random.randn(100, 1)  # 100个样本,每个样本有5个目标值

# 编译模型
model.compile(loss='mse', optimizer='sgd')
# 训练模型
model.fit(input_data, target_data, epochs=10)

data = np.array([[1], [2], [3]])

prediction = model(data)
print(prediction)

可以看到,同样的任务,Keras的代码量小很多

区别与使用场景

Keras代码量少,使用便捷,适用于快速实验和快速神经网络设计

而pytorch由于结构是由类定义的,可以更加灵活地组建神经网络层,这对于要求细节的任务更有利,同时,pytorch还采用动态计算图,使得模型的结构可以在运行时根据输入数据动态调整,但这个特点我还没有接触到,之后可能会详细讲解

结语

Keras和Pytorch都各有各的优点,请读者根据需求选择,同时有些深度学习教程偏向于使用某一种框架,最好都学习一点,以适应不同的场景

 

感谢阅读,觉得有用的话就订阅下本专栏吧 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1269671.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

GoLong的学习之路,进阶,Redis

这个redis和上篇rabbitMQ一样,在之前我用Java从原理上进行了剖析,这里呢,我做项目的时候,也需要用到redis,所以这里也将去从怎么用的角度去写这篇文章。 文章目录 安装redis以及原理redis概念redis的应用场景有很多red…

机器学习笔记 - 3D数据的常见表示方式

一、简述 从单一角度而自动合成3D数据是人类视觉和大脑的基本功能,这对计算机视觉算法来说是比较难的。但随着LiDAR、RGB-D 相机(RealSense、Kinect)和3D扫描仪等3D传感器的普及和价格的降低,3D 采集技术的最新进展取得了巨大飞跃。与广泛使用的 2D 数据不同,3D 数据具有丰…

C# 友元程序集

1.友元程序集 使用友元程序集可以将internal成员提供给其他的友元程序集访问。 程序集FriendTest1.dll [assembly:InternalsVisibleTo("FriendTest2")] namespace FriendTest1 {internal class Friend{string name;public string Name > name;public Friend(str…

删除list中除最后一个之外所有的数据

1.你可以新建一个list List<Integer> listnew ArrayList<>();int i0;while (i<100){list.add(i);}List<Integer> subList list.subList(list.size()-1, list.size());System.out.println("原list大小--"list.size());System.out.println("…

golang channel执行原理与代码分析

使用的go版本为 go1.21.2 首先我们写一个简单的chan调度代码 package mainimport "fmt"func main() {ch : make(chan struct{})go func() {ch <- struct{}{}ch <- struct{}{}}()fmt.Println("xiaochuan", <-ch)data, ok : <-chfmt.Println(&…

基础算法学习

文章目录 快速排序归并排序二分浮点数二分 高精度BigIntegerBigDecimal 前缀和差分双指针位运算离散化区间合并 快速排序 确定分界点x &#xff08;可以是左边界&#xff0c;右边界&#xff0c;中间随机&#xff09;将小于等于x的数放到左边&#xff0c;大于等于x的放右边递归…

Intellij idea 内存不够用了,怎么处理?

目录 如何判断内存不够用了 下面演示一下如何开启内存指示器&#xff08;Memory Indicator&#xff09; 解决方案 第一种&#xff1a;双击"内存指示器(Mempory Indicator)" 第二种&#xff1a;增大Intellij Idea 最大可使用内存 如何判断内存不够用了 运行项目后…

ExoPlayer - Failed to initialize OMX.qcom.video.decoder.avc

人莫鉴于流水而鉴于止水&#xff0c;唯止能止众止 1. 背景 使用ExoPlayer&#xff0c;我不信你没遇到过这个问题&#xff1a; java.lang.IllegalArgumentException: Failed to initialize OMX.qcom.video.decoder.avc 详细内容如下图所示&#xff1a; 2. MediaCodec(解码器) …

渲染到纹理:原理及WebGL实现

这篇文章是WebGL系列的延续。 第一个是从基础知识开始的&#xff0c;上一个是向纹理提供数据。 如果你还没有阅读过这些内容&#xff0c;请先查看它们。 NSDT在线工具推荐&#xff1a; Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - …

快速入门opencv(python版)

Open Source Computer Vision Library。OpenCV是一个&#xff08;开源&#xff09;发行的跨平台计算机视觉库&#xff0c;可以运行在Linux、Windows和Mac OS操作系统上。它轻量级而且高效——由一系列 C 函数和少量 C 类构成&#xff0c;同时提供了Python、Ruby、MATLAB等语言的…

PHP微信UI在线聊天系统源码 客服私有即时通讯系统 附安装教程

DuckChat是一套完整的私有即时通讯解决方案&#xff0c;包含服务器端程序和各种客户端程序&#xff08;包括iOS、Android、PC等&#xff09;。通过DuckChat&#xff0c;站点管理员可以快速在自己的服务器上建立私有的即时通讯服务&#xff0c;用户可以使用客户端连接至此服务器…

CAN网络出现错误帧从哪些方面去分析解决

标题&#xff1a;CAN网络出现错误帧从哪些方面去分析 实例1&#xff1a; 断电重启后&#xff0c;会有错误帧产生。 检查方案&#xff1a; 查看收发模块的初始化、使能是否在发送CAN报文之前完成&#xff1f; 实例2&#xff1a; 周期性报文&#xff0c;有时会冒出一帧错误帧&…

四则计算机实现(C++)(堆栈的应用)

算法要求&#xff1a; 输入一个数学表达式(假定表达式输入格式合法)&#xff0c;计算表达式结果并输出。数学表达式由单个数字和运算符“”、“-”、“*”、“/”、“(、) ”构成&#xff0c;例如 2 3 * ( 4 5 ) - 6 / 4。变量、输出采用整数&#xff0c;只舍不入。 图解算…

MySQL InnoDB Cluster

MySQL InnoDB Cluster 一、InnoDB Cluster 基本概述 MySQL InnoDB Cluster 为 MySQL 提供了一个完整的高可用解决方案。通过使用 MySQL Shell 提供的 AdminAPI,你可以轻松地配置和管理一组至少由3个MySQL服务器实例组成的 InnoDB 集群。 InnoDB 集群中的每个 MySQL 服务器实例…

linux无网络 无ip,显示网络未连接

标题:linux无网络 无ip&#xff0c;显示网络未连接 参考blog&#xff1a;Linux无网络连接问题排查 首先我们发现ens33没有ip地址&#xff0c;说明这个接口并没有被分到ip&#xff1b; 我们可以通过手动方式来给ens33获得网络ip sudo dhclient ens33&#xff0c;之后再输入ifc…

大数据Hadoop-HDFS_元数据持久化

大数据Hadoop-HDFS_元数据持久化 &#xff08;1&#xff09;在HDFS第一次格式化后&#xff0c;NameNode&#xff08;即图中的主NameNode&#xff09;就会生成fsimage和editslog两个文件&#xff1b; &#xff08;2&#xff09;备用NameNode&#xff08;即图中的备NameNode&…

ChatGPT Plus/GPT4高级数据分析和插件功能详解

ChatGPT 在论文写作与编程方面也具备强大的能力。无论是进行代码生成、错误调试还是解决编程难题&#xff0c;ChatGPT都能为您提供实用且高质量的建议和指导&#xff0c;提高编程效率和准确性。此外&#xff0c;ChatGPT是一位出色的合作伙伴&#xff0c;可以为您提供论文写作的…

微服务实战系列之Redis(cache)

前言 云淡天高&#xff0c;落木萧萧&#xff0c;一阵西北风掠过&#xff0c;似寒刀。冬天渐渐变得更名副其实了。“暖冬”的说法有点言过其实了。——碎碎念 微服务实战系列之Cache微服务实战系列之Nginx&#xff08;技巧篇&#xff09;微服务实战系列之Nginx微服务实战系列之F…

Structured Streaming: Apache Spark的流处理引擎

欢迎来到我们的技术博客&#xff01;今天&#xff0c;我们要探讨的主题是Apache Spark的一个核心组件——Structured Streaming。作为一个可扩展且容错的流处理引擎&#xff0c;Structured Streaming使得处理实时数据流变得更加高效和简便。 什么是Structured Streaming&#…

数据结构中的二分查找(折半查找)

二分法&#xff1a;顾名思义&#xff0c;把问题一分为2的处理&#xff0c;是一种常见的搜索算法&#xff0c;用于在有序数组或这有序列表中查找指定元素的位置&#xff0c;它的思想是将待搜索的区间不断二分&#xff0c;然后比较目标值与中间元素的大小关系&#xff0c;然后确定…