PyTorch三种主流模型构建方式:nn.Sequential、nn.Module子类、nn.Module容器开发实践,以真实烟雾识别场景数据为例

news2025/4/17 0:52:08

Keras和PyTorch是两个常用的深度学习框架,它们都提供了用于构建和训练神经网络的高级API。

Keras:
Keras是一个高级神经网络API,可以在多个底层深度学习框架上运行,如TensorFlow和CNTK。以下是Keras的特点和优点:

优点:

简单易用:Keras具有简洁的API设计,易于上手和使用,适合快速原型设计和实验。
灵活性:Keras提供了高级API和模块化的架构,可以灵活地构建各种类型的神经网络模型。
复用性:Keras模型可以轻松保存和加载,可以方便地共享、部署和迁移模型。
社区支持:Keras拥有庞大的社区支持和活跃的开发者社区,提供了大量的文档、教程和示例代码。
缺点:

功能限制:相比于底层框架如TensorFlow和PyTorch,Keras在某些高级功能和自定义性方面可能有所限制。
可扩展性:虽然Keras提供了易于使用的API,但在需要大量定制化和扩展性的复杂模型上可能会有限制。
灵活程度:Keras主要设计用于简单的流程,当需要处理复杂的非标准任务时,使用Keras的灵活性较差。
适用场景:

初学者:对于新手来说,Keras是一个理想的选择,因为它简单易用,有丰富的文档和示例来帮助快速入门。
快速原型设计:Keras可以快速搭建和迭代模型,适用于快速原型设计和快速实验验证。
常规计算机视觉和自然语言处理任务:Keras提供了大量用于计算机视觉和自然语言处理的预训练模型和工具,适用于常规任务的开发与应用。
PyTorch:
PyTorch是一个动态图深度学习框架,强调易于使用和低延迟的调试功能。以下是PyTorch的特点和优点:

优点:

动态图:PyTorch使用动态图,使得模型构建和调试更加灵活和直观,可以实时查看和调试模型。
自由控制:相比于静态图框架,PyTorch能够更自由地控制模型的复杂逻辑和探索新的网络架构。
算法开发:PyTorch提供了丰富的数学运算库和自动求导功能,适用于算法研究和定制化模型开发。
社区支持:PyTorch拥有活跃的社区和大量的开源项目,提供了丰富的资源和支持。
缺点:

部署复杂性:相比于Keras等高级API框架,PyTorch需要开发者更多地处理模型的部署和生产环境的问题。
静态优化:相对于静态图框架,如TensorFlow,PyTorch无法进行静态图优化,可能在性能方面略逊一筹。
入门门槛:相比于Keras,PyTorch对初学者来说可能有一些陡峭的学习曲线。
适用场景:

研究和定制化模型:PyTorch适合进行研究和实验,以及需要灵活性和自由度较高的定制化模型开发。
高级计算机视觉和自然语言处理任务:PyTorch在计算机视觉和自然语言处理领域有广泛的应用,并且各类预训练模型和资源丰富。
在前面的两篇文章中整体系统总结记录了Keras和PyTroch这两大主流框架各自开发构建模型的三大主流方式,并对应给出来的基础的实例实现,感兴趣的话可以自行移步阅读即可:

《总结记录Keras开发构建神经网络模型的三种主流方式:序列模型、函数模型、子类模型》

《总结记录PyTorch构建神经网络模型的三种主流方式:nn.Sequential按层顺序构建模型、继承nn.Module基类构建自定义模型、继承nn.Module基类构建模型并辅助应用模型容器来封装》

本文的主要目的就是想要基于真实业务数据场景来实地开发实践这三种不同类型的模型构建方式,并对结果进行对比分析。

首先来看下数据集:

 首先来看序列模型构建实现:

def initModel():
    """
    nn.Sequential按层顺序构建模型
    """
    model = nn.Sequential()
    model.add_module("conv1", nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3))
    model.add_module("pool1", nn.MaxPool2d(kernel_size=2, stride=2))
    model.add_module("conv2", nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5))
    model.add_module("dropout", nn.Dropout2d(p=0.1))
    model.add_module("pool2", nn.AdaptiveMaxPool2d((1, 1)))
    model.add_module("flattens", nn.Flatten())
    model.add_module("linear1", nn.Linear(64, 32))
    model.add_module("relu1", nn.ReLU())
    model.add_module("linear2", nn.Linear(32, 1))
    return model

接下来是继承nn.Module基类构建自定义模型,如下所示:

class initModel(nn.Module):
    """
    继承nn.Module基类构建自定义模型
    """

    def __init__(self):
        super(initModel, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5)
        self.dropout = nn.Dropout2d(p=0.1)
        self.pool2 = nn.AdaptiveMaxPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(64, 32)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(32, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.dropout(x)
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

最后是继承nn.Module基类并辅助应用模型容器进行封装构建方式,这里在前文中提到共有三种模型容器可用,分别是:

nn.Sequential
nn.ModuleList
nn.ModuleDict

代码实现如下所示:

class initModel(nn.Module):
    """
    继承nn.Module基类并辅助应用模型容器进行封装
    nn.Sequential作为模型容器
    """

    def __init__(self):
        super(initModel, self).__init__()
        self.model = nn.Sequential()
        self.model.add_module(
            "conv1", nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3)
        )
        self.model.add_module("pool1", nn.MaxPool2d(kernel_size=2, stride=2))
        self.model.add_module(
            "conv2", nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5)
        )
        self.model.add_module("dropout", nn.Dropout2d(p=0.1))
        self.model.add_module("pool2", nn.AdaptiveMaxPool2d((1, 1)))
        self.model.add_module("flatten", nn.Flatten())
        self.model.add_module("linear1", nn.Linear(64, 32))
        self.model.add_module("relu", nn.ReLU())
        self.model.add_module("linear2", nn.Linear(32, 1))

    def forward(self, x):
        y = self.model(x)
        return y


class initModel(nn.Module):
    """
    继承nn.Module基类并辅助应用模型容器进行封装
    nn.ModuleList作为模型容器
    """

    def __init__(self):
        super(initModel, self).__init__()
        self.layers = nn.ModuleList(
            [
                nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5),
                nn.Dropout2d(p=0.1),
                nn.AdaptiveMaxPool2d((1, 1)),
                nn.Flatten(),
                nn.Linear(64, 32),
                nn.ReLU(),
                nn.Linear(32, 1),
            ]
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class initModel(nn.Module):
    """
    继承nn.Module基类并辅助应用模型容器进行封装
    nn.ModuleDict作为模型容器
    """

    def __init__(self):
        super(initModel, self).__init__()
        self.layers_dict = nn.ModuleDict(
            {
                "conv1": nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3),
                "pool1": nn.MaxPool2d(kernel_size=2, stride=2),
                "conv2": nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5),
                "dropout": nn.Dropout2d(p=0.1),
                "pool2": nn.AdaptiveMaxPool2d((1, 1)),
                "flatten": nn.Flatten(),
                "linear1": nn.Linear(64, 32),
                "relu": nn.ReLU(),
                "linear2": nn.Linear(32, 1),
            }
        )

    def forward(self, x):
        layers = [
            "conv1",
            "pool1",
            "conv2",
            "dropout",
            "pool2",
            "flatten",
            "linear1",
            "relu",
            "linear2",
        ]
        for layer in layers:
            x = self.layers_dict[layer](x)
        return x

跟keras框架一样,默认都是设定100次epoch的迭代计算,这里直接来看结果图:

 感兴趣都可以自行实践一下,很多内容或者是方法本质上都是触类旁通的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/914993.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

解决git上传远程仓库时的最大文件大小限制

git默认限制最大的单文件100M,当某个文件到达50M时会给你提示。解决办法如下 首先,打开终端,进入项目所在的文件夹; 输入命令:git config http.postBuffer 524288000 执行完上面的语句后输入:git config…

Stable Diffusion 系列教程 | 图生图基础

前段时间有一个风靡全网的真人转漫画风格,受到了大家的喜欢 而在SD里,就可以通过图生图来实现类似的效果 当然图生图还有更好玩的应用,我们一点一点来探索 首先我们来简单进行一下图生图的这一个实践---真人转动漫 1. 图生图基本界面 和…

代码之美:探索可维护性的核心与实践

为什么可维护性如此重要 项目的长期健康 在软件开发的早期阶段,团队可能会对代码的可维护性不太重视,因为他们更关心的是功能的快速交付。但随着时间的推移,随着代码库的增长和复杂性的增加,不重视代码的可维护性可能会导致严重的…

docker使用安装教程

docker使用安装教程 一、docker安装及下载二、使用教程2.1 镜像2.2 容器2.3 docker安装Redis 一、docker安装及下载 一、安装 安装执行命令:curl -fsSL https://get.docker.com | bash -s docker --mirror Aliyun 二、启停常用命令 启动docker,执行命令&#xf…

分支和循环语句-C语言(初阶)

目录 一、什么是语句 二、分支语句 2.1 if语句 2.2 switch语句 三、循环语句 3.1 while循环 3.2 for循环 3.3 do...while循环 一、什么是语句 C语言语句有五类:表达式语句、函数调用语句、控制语句、复合语句、空语句。 控制语句用于控制程序的执行流程&#xff0…

在vue3+ts+vite中使用svg图片

目录 前言 步骤 1.安装svg-sprite-loader,这里使用的是6.0.11版本 2.项目的svg图片存放在src/icons下,我们在这里创建两个文件index.ts和index.vue(在哪创建和文件名字并没有任何要求) 3.在index.ts中加入下列代码(如果报错找不到fs模块请…

Redis的基本操作

文章目录 1.Redis简介2.Redis的常用数据类型3.Redis的常用命令1.字符串操作命令2.哈希操作命令3.列表操作命令4.集合操作命令5.有序集合操作命令6.通用操作命令 4.Springboot配置Redis1.导入SpringDataRedis的Maven坐标2.配置Redis的数据源3.编写配置类,创还能Redis…

ubuntu修改默认文件权限umask

最近在使用ubuntu的过程中发现一个问题: 环境是AWS EC2,登录用户ubuntu,系统默认的umask是027,修改/etc/profile文件中umask 027为022后,发现从ubuntu用户sudo su过去root用户登录查询到的umask还是027,而…

2023-8-22 单调栈

题目链接&#xff1a;单调栈 #include <iostream>using namespace std;const int N 100010;int n; int stk[N], tt;int main() {cin >> n;for(int i 0; i < n; i ){int x;cin >> x;while(tt && stk[tt] > x) tt--;if(tt) cout << st…

第十章,搜索模块

10.1添加搜索框 <template><div class="navbar-form navbar-left hidden-sm"><div class="form-group"><inputv-model.trim="value"type="text"class="form-control search-input mac-style"placeho…

数据传输过程

2 数据传输过程 了解网络中常用的分层模型后&#xff0c;现在来学习一下数据在各层之间是如何传输的。 2.1数据封装与解封装过程(一) 下面我们将以TCP/IP五层结构为基础来学习数据在网络中传输的“真相”。由于这个过程比较 抽象&#xff0c;我们可以类比给远在美国的朋友邮寄…

人工智能深度估计技术

人工智障&#xff08;能&#xff09;走起&#xff01;&#xff01;&#xff01; 下面是基本操作&#xff1a; 在Hugging Face网页中找到Depth Estimation的model&#xff0c;如下图&#xff1a; Hugging Face – The AI community building the future. &#xff08;上Huggin…

从自动驾驶到智能助理:AI和ML技术的革命性应用与前景

人工智能&#xff08;AI&#xff09;和机器学习&#xff08;ML&#xff09;的快速发展正在改变我们的世界。它们以惊人的速度渗透到各个领域&#xff0c;从自动驾驶汽车到智能助理、语音识别和自然语言处理等。AI和ML技术的应用范围和影响力越来越广泛&#xff0c;为我们的日常…

SpringMVC拦截器学习笔记

SpringMVC拦截器 拦截器知识 拦截器(Interceptor)用于对URL请求进行前置/后置过滤 Interceptor与Filter用途相似但实现方式不同 Interceptor底层就是基于Spring AOP面向切面编程实现 拦截器开发流程 Maven添加依赖包servlet-api <dependency><groupId>javax.se…

【Rust】Rust学习 第十八章模式用来匹配值的结构

模式是 Rust 中特殊的语法&#xff0c;它用来匹配类型中的结构&#xff0c;无论类型是简单还是复杂。结合使用模式和 match 表达式以及其他结构可以提供更多对程序控制流的支配权。模式由如下一些内容组合而成&#xff1a; 字面值解构的数组、枚举、结构体或者元组变量通配符占…

CSS笔记

介绍 CSS导入方式 三种方法都将文字设置成了红色 CSS选择器 元素选择器 id选择器 图中div将颜色控制为红色&#xff0c;#name将颜色控制为蓝色&#xff0c;谁控制的范围最小&#xff0c;谁就生效&#xff0c;所以第二个div是蓝色的。id属性值要唯一&#xff0c;否则报错。 clas…

【STM32RT-Thread零基础入门】 6. 线程创建应用(线程挂起与恢复)

硬件&#xff1a;STM32F103ZET6、ST-LINK、usb转串口工具、4个LED灯、1个蜂鸣器、4个1k电阻、2个按键、面包板、杜邦线 文章目录 前言一、RT-Thread相关接口函数1. 挂起线程2. 恢复线程 二、程序设计1. car_led.c2.car_led.h3. main.c 三、程序测试总结 前言 在上一个任务中&a…

Mysql group by使用示例

文章目录 1. groupby时不能查询*2. 查询出的列必须在group by的条件列中3. group by多个字段&#xff0c;这些字段都有索引也会索引失效&#xff0c;只有group by单个字段索引才能起作用4. having条件必须跟group by相关联5. 用group by做去重6. 使用聚合函数做数量统计7. havi…

ShardingSphere02-MySQL主从同步配置

1、MySQL主从同步原理 基本原理&#xff1a; slave会从master读取binlog来进行数据同步 具体步骤&#xff1a; step1&#xff1a;master将数据改变记录到二进制日志&#xff08;binary log&#xff09;中。step2&#xff1a; 当slave上执行 start slave 命令之后&#xff0c…

mysql------做主从复制,读写分离

1.为什么要做主从复制&#xff08;主从复制的作用&#xff09; 做数据的热备&#xff0c;作为后备数据库&#xff0c;主数据库服务器故障后&#xff0c;可切换到从数据库继续工作&#xff0c;避免数据丢失。 架构的扩展。业务量越来越大,I/O访问频率过高&#xff0c;单机无法满…