pytorch04:网络模型创建

news2024/11/16 8:24:44

目录

  • 一、模型创建过程
    • 1.1 以LeNet网络为例
    • 1.2 LeNet结构
    • 1.3 nn.Module
  • 二、网络层容器(Containers)
    • 2.1 nn.Sequential
      • 2.1.1 常规方法实现
      • 2.1.2 OrderedDict方法实现
    • 2.2 nn.ModuleList
    • 2.3 nn.ModuleDict
    • 2.4 三种容器构建总结
  • 三、AlexNet网络构建

一、模型创建过程

在这里插入图片描述

1.1 以LeNet网络为例

在这里插入图片描述

网络代码如下:

class LeNet(nn.Module):
    def __init__(self, classes):
        super(LeNet, self).__init__()  # 调用父类方法,作用是调用nn.Module类的构造函数,
        # 确保LeNet类被正确地初始化,并继承了nn.Module 的所有属性和方法
        self.conv1 = nn.Conv2d(3, 6, 5) # 卷积层
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120) # 全连接层
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, classes)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

1.2 LeNet结构

在这里插入图片描述

LeNet:conv1–>pool1–>conv2–>pool2–>fc1–>fc2–>fc3
在这里插入图片描述

1.3 nn.Module

Module是nn模块中的功能,nn模块还有Parameter、functional等模块。
在这里插入图片描述
nn.Module主要有以下参数:
• parameters : 存储管理nn.Parameter类
• modules : 存储管理nn.Module类
• buffers:存储管理缓冲属性,如BN层中的running_mean

二、网络层容器(Containers)

在这里插入图片描述

2.1 nn.Sequential

nn.Sequential 是 nn.module的容器,也是最常用的容器,用于按顺序包装一组网络层
• 顺序性:各网络层之间严格按照顺序构建
• 自带forward():自带的forward里,通过for循环依次执行前向传播运算

2.1.1 常规方法实现

LeNet网络由两部分构成,中间的卷积池化特征提取部分(features),以及最后的分类部分(classifier)。
在这里插入图片描述
具体代码如下:

class LeNetSequential(nn.Module):
    def __init__(self, classes):
        super(LeNetSequential, self).__init__()
        self.features = nn.Sequential(  #特征提取部分
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),)

        self.classifier = nn.Sequential(  #分类部分
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes),)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

打印网络层:
在这里插入图片描述

2.1.2 OrderedDict方法实现

使用有序字典的方法构建Sequential
代码如下:

class LeNetSequentialOrderDict(nn.Module):
    def __init__(self, classes):
        super(LeNetSequentialOrderDict, self).__init__()

        self.features = nn.Sequential(OrderedDict({
            'conv1': nn.Conv2d(3, 6, 5),
            'relu1': nn.ReLU(inplace=True),
            'pool1': nn.MaxPool2d(kernel_size=2, stride=2),

            'conv2': nn.Conv2d(6, 16, 5),
            'relu2': nn.ReLU(inplace=True),
            'pool2': nn.MaxPool2d(kernel_size=2, stride=2),
        }))

        self.classifier = nn.Sequential(OrderedDict({
            'fc1': nn.Linear(16 * 5 * 5, 120),
            'relu3': nn.ReLU(),

            'fc2': nn.Linear(120, 84),
            'relu4': nn.ReLU(inplace=True),

            'fc3': nn.Linear(84, classes),
        }))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

先看一下Sequential函数中init初始化的两种方法,当我们使用OrderedDict方法时,会进行判断,使用self.add_module(key, module)方法将字典中的key和value取出来添加到Sequential中。

class Sequential(Module):
    def __init__(self, *args):
        super().__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

通过这种方法构建可以给每一网络层添加一个名称,网络输出结果如下:
在这里插入图片描述

2.2 nn.ModuleList

nn.ModuleList是 nn.module的容器,用于包装一组网络层,以迭代方式调用网络层
主要方法:
• append():在ModuleList后面添加网络层
• extend():拼接两个ModuleList
• insert():指定在ModuleList中位置插入网络层

使用列表生成式,通过一行代码就能构建20个网络层。
代码演示:

class ModuleList(nn.Module):
    def __init__(self):
        super(ModuleList, self).__init__()
        # 使用列表生成式构建20个全连接层,每个全连接层10个神经元的网络
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(20)])

    def forward(self, x):
        for i, linear in enumerate(self.linears):
            x = linear(x)
        return x


net = ModuleList()

2.3 nn.ModuleDict

nn.ModuleDict是 nn.module的容器,用于包装一组网络层,以索引方式调用网络层,可以用过参数的形式选取想要调用的网络层。
主要方法:
• clear():清空ModuleDict
• items():返回可迭代的键值对(key-value pairs)
• keys():返回字典的键(key)
• values():返回字典的值(value)
• pop():返回一对键值,并从字典中删除

代码展示,只选取conv和relu两个网络层:

class ModuleDict(nn.Module):
    def __init__(self):
        super(ModuleDict, self).__init__()
        self.choices = nn.ModuleDict({
            'conv': nn.Conv2d(10, 10, 3),
            'pool': nn.MaxPool2d(3)
        })

        # 激活函数
        self.activations = nn.ModuleDict({
            'relu': nn.ReLU(),
            'prelu': nn.PReLU()
        })

    def forward(self, x, choice, act):  # 传入两个参数 用来选择网络层
        x = self.choices[choice](x)
        x = self.activations[act](x)
        return x
net = ModuleDict()
fake_img = torch.randn((4, 10, 32, 32))
output = net(fake_img, 'conv', 'relu')  #只选取conv和relu两个网络层。
print(output)

2.4 三种容器构建总结

• nn.Sequential:顺序性,各网络层之间严格按顺序执行,常用于block构建
• nn.ModuleList:迭代性,常用于大量重复网构建,通过for循环实现重复构建
• nn.ModuleDict:索引性,常用于可选择的网络层

三、AlexNet网络构建

AlexNet:2012年以高出第二名10多个百分点的准确率获得ImageNet分类任务冠军,开创了卷积神经网络的新时代
AlexNet特点如下:

  1. 采用ReLU:替换饱和激活函数,减轻梯度消失
  2. 采用LRN(Local Response Normalization):对数据归一化,减轻梯度消失
  3. Dropout:提高全连接层的鲁棒性,增加网络的泛化能力
  4. Data Augmentation:TenCrop,色彩修改

网络结构图如下:
在这里插入图片描述
构建代码:

import torch.nn as nn
import torch
from torchsummary import summary
# 定义一个名为AlexNet的神经网络模型,继承自nn.Module基类
class AlexNet(nn.Module):
    # 构造函数,初始化网络的参数
    def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None:
        # 调用父类的构造函数
        super().__init__()
        # 定义神经网络的特征提取部分,包含多个卷积层和池化层
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),  # 输入通道3,输出通道64,卷积核大小11x11,步长4,填充2
            nn.ReLU(inplace=True),  # 使用ReLU激活函数,inplace=True表示原地操作,节省内存
            nn.MaxPool2d(kernel_size=3, stride=2),  # 最大池化层,核大小3x3,步长2
            nn.Conv2d(64, 192, kernel_size=5, padding=2),  # 输入通道64,输出通道192,卷积核大小5x5,填充2
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        
        # 定义自适应平均池化层,将输入的任意大小的特征图池化为固定大小6x6
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        
        # 定义分类器部分,包含全连接层和Dropout层
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),  # 使用Dropout进行正则化,随机丢弃一部分神经元以防止过拟合
            nn.Linear(256 * 6 * 6, 4096),  # 输入大小为256*6*6,输出大小为4096
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),  # 最后的全连接层输出类别数
        )

    # 前向传播函数,定义数据在网络中的传播过程
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)  # 特征提取
        x = self.avgpool(x)  # 平均池化
        x = torch.flatten(x, 1)  # 将特征图展平成一维向量
        x = self.classifier(x)  # 分类器
        return x
 if __name__ == '__main__':
    net = AlexNet().cuda()
    summary(net, (3, 256, 256))

打印出的网络结构图如下:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1352885.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【力扣题解】P501-二叉搜索树中的众数-Java题解

👨‍💻博客主页:花无缺 欢迎 点赞👍 收藏⭐ 留言📝 加关注✅! 本文由 花无缺 原创 收录于专栏 【力扣题解】 文章目录 【力扣题解】P501-二叉搜索树中的众数-Java题解🌏题目描述💡题解&#x1f…

文本表示模型简介

文本是一类非常重要的非结构化数据,如何表示文本数据一直是机器学习领域的一个重要研究方向。那么有哪些文本表示模型?以及它们各有什么优缺点? 1. 常见的文本表示模型 1.1. 词袋模型和N-gram模型 最基础的文本表示模型是词袋模型。顾名思…

Java最新技术介绍和分析 (202305)

说明:本文完成了2023年5月份,当时最新的LTS版本是Java17,本文在撰写时参考了美团技术团队和阿里JDK团队相关的文章,以及本文也引了用文章中的图片。在此表示感谢! Java版本火车 相信老牌的Java开发者和爱好者把Java的…

计算机创新协会冬令营——暴力枚举题目01

首先是欢迎大家参加此次的冬令营,我们协会欢迎所有志同道合的同学们。话不多说,先来看看今天的题目吧。 题目 力扣题号:2351. 第一个出现两次的字母 注:下述题目和示例均来自力扣 题目 给你一个由小写英文字母组成的字符串 s &…

(ros2)控制gazebo移动的话题:

gazebo并不是ros2内自带的,是一个独立的软件,需要安装ros_gz功能被把ros2消息转换为gazebo可以识别的命令,并且把gazebo状态转换为ros2信息,所以可以认为ros_gz_bridge节点就是gazebo, 这个节点只接收2个话题&#xf…

Yarn的安装与使用详细介绍

什么是yarn Apache Hadoop YARN (Yet Another Resource Negotiator,另一种资源协调者)是一种新的 Hadoop 资源管理器,它是一个通用资源管理系统,可为上层应用提供统一的资源管理和调度,它的引入为集群在利用…

2024年天津仁爱学院专升本专业课考试考场安排及准考证打印的通知

天津仁爱学院2024年高职升本科专业课考试通知 一、考试科目及时间 天津仁爱学院2024年高职升本科专业课考试定于2024年1月9日9:00-14:00举行。考试地点:天津市静海区团泊新城博学苑,天津仁爱学院。具体考试科目时间安排如下表&a…

提高工作效率的Postman环境变量使用方法

在 Postman 中,用 Environments 来管理环境变量。我们在开发的过程中,往往会用到多个环境:开发环境,测试环境,UAT 环境,生产环境等。我们要调用不同环境的 API 时,只需切换 Postman 的 Environm…

ISP 基础知识积累

Amber:现有工作必要的技术补充,认识需要不断深入,这个文档后续还会增加内容进行完善。 镜头成像资料 ——干货满满,看懂了这四篇文章,下面的问题基本都能解答 看完思考 1、ISP 是什么,有什么作用&#xff…

php安装扩展event 提示 No package ‘openssl‘ found 解决方法

在使用pecl编译安装最新版event模块的时候提示 No package openssl found , 可是本机是安装了openssl的, 编译时找不到, 大概率就是环境配置的问题了, 增加 OPENSSL_CFLAGS OPENSSL_LIBS环境变量即可解决. 异常提示信息: checking for openssl > 1.0.2... no configure: …

【数据库原理】(5)关系数据库的关系数据结构

关系及相关概念 在关系模型中,无论是实体还是实体之间的联系均由关系(二维表)来表示。 1.域(Domain) 定义:域是一组具有相同数据类型的值的集合。例子:实数集合、整数集合、英文字母集合等。 2.笛卡儿积(Cartesian…

学习Vue单文件组件总结

今天主要学习了组件实例对象的一个重要内置关系和单文件组件。先说一下实例对象的内置关系,在这里要对JS中的原型链有一定的基础,Vue构造函数的prototype原型指向的是Vue的原型对象,new出来的Vue实例对__proto__同样指向的是Vue的原型对象&am…

Hex 文件类型字段详解

文章目录 地址类型字段详解02(扩展段地址记录):指定后续数据记录的向左移动4位后开始计算地址04(扩展线性地址记录):指定后续数据记录的起始地址的高16位 地址类型字段详解 02(扩展段地址记录&…

18、BLIP

简介 github BLIP提出了一种基于预训练的方法,通过联合训练视觉和语言模型来提升多模态任务的性能。 BLIP(Bootstrapping Language-Image Pretraining)是salesforce在2022年提出的多模态框架,是理解和生成的统一,引入了跨模态的编码器和解码…

RocketMQ5.0延时消息时间轮算法

前言 RocketMQ 相较于其它消息队列产品的一个特性是支持延时消息,也就是说消息发送到 Broker 不会立马投递给消费者,要等待一个指定的延迟时间再投递,适用场景例如:下单后多长时间没付款系统自动关闭订单。 RocketMQ 4.x 版本的延…

【LeetCode-剑指offer】-- 9.乘积小于K的子数组

9.乘积小于K的子数组 方法:滑动窗口 关于为什么子数组数目为j-11。这时候就要理解采用滑动窗口的思路其实是枚举子数组的右端点,然后来找到满足条件的最小左端点。也即当得到满足条件的窗口时,就意味着得到了以 j 作为右端点时满足条件的左端…

框架的灵魂之笔-反射

反射:框架的灵魂 类加载器 概述:当程序要使用某个类的时候,如果该类还未被加载到内存中,则系统会通过以下三个步骤 ①类的加载 ②类的连接 ③类的初始化来对类进行初始化。如果不出现意外情况,JVM将会连续完成这三个步…

momentjs计算两个时间差返回时分秒

// 导入 Moment.js 模块 const moment require(moment);// 定义起始时间和结束时间 const startTime 2021-09-30T14:30:00; // 格式必须符合 ISO8601(YYYY-MM-DDTHH:mm:ss) const endTime 2021-09-30T15:45:00;// 创建 Moment.js 对象来表示起始时间和…

CentOS 7 实战指南:文本处理命令详解

前言 在Linux系统中,文本处理是非常基础却又必不可少的一项技能。如果你正在使用CentOS系统,那么学会如何利用文本操作命令来高效地处理文本文件无疑将会是一个强有力的工具。 本篇文章将介绍一些最常用和最实用的文本操作命令,并通过详尽的…

如何自动生成 API 接口文档 - 一份详细指南

本篇文章详细教你如何使用 Apifox 的 IDEA 插件实现自动生成接口代码。好处简单总结有以下几点: 自动生成接口文档: 不用手写,一键点击就可以自动生成文档,当有更新时,点击一下就可以自动同步接口文档;代码…