使用PyTorch构建神经网络,并计算参数Params

news2024/11/18 7:35:25

文章目录

    • 使用PyTorch构建神经网络,并计算参数Params
    • 举例
      • 计算具有全连接层的神经网络的参数数量
      • 计算卷积神经网络的参数数量
        • Params计算过程
    • 总结

使用PyTorch构建神经网络,并计算参数Params

在深度学习中,模型的参数数量是一个非常重要的指标,通常会影响模型的大小、训练速度和准确度等多个方面。在本教程中,我们将介绍如何计算深度学习模型的参数数量。

本教程将以PyTorch为例,展示如何计算一个包含卷积、池化、归一化和全连接等多种层的卷积神经网络的参数数量。具体来说,我们将首先介绍一个具有全连接层的神经网络的参数计算方法,然后扩展到包含卷积、池化、归一化和全连接等多种层的卷积神经网络。

举例

计算具有全连接层的神经网络的参数数量

假设我们有一个输入向量 x x x,其维度为 d i n d_{in} din,我们想将其映射到一个输出向量 y y y,其维度为 d o u t d_{out} dout。我们可以使用一个具有 n n n个隐藏层的全连接神经网络来完成这个映射,其中每个隐藏层具有 h h h个神经元。

在PyTorch中,我们可以通过定义一个继承自nn.Module的类来实现这个神经网络。下面是一个定义了一个具有两个隐藏层的全连接神经网络的示例代码:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, d_in, h, d_out, n):
        super(Net, self).__init__()
        self.linear1 = nn.Linear(d_in, h)
        self.linear2 = nn.Linear(h, h)
        self.linear3 = nn.Linear(h, d_out)
        self.n = n

    def forward(self, x):
        h_relu = self.linear1(x).clamp(min=0)
        for i in range(self.n):
            h_relu = self.linear2(h_relu).clamp(min=0)
        y_pred = self.linear3(h_relu)
        return y_pred

其中,nn.Linear是PyTorch中的一个线性层,它将输入向量乘以一个权重矩阵,并加上一个偏置向量,得到输出向量。在这个例子中,我们定义了三个线性层,分别为self.linear1self.linear2self.linear3。在forward函数中,我们首先将输入向量x传递给self.linear1,然后通过ReLU非线性激活函数得到一个隐藏层输出h_relu。接下来,我们使用for循环多次将h_relu传递给self.linear2,再次使用ReLU非线性激活函数得到另一个隐藏层输出。最后,我们将最后一个隐藏层的输出传递给self.linear3,得到输出向量y_pred

现在让我们计算一下这个神经网络的参数数量。对于每个线性层,它都有一个权重矩阵和一个偏置向量,因此总的参数数量为:

参数数量 = d_in * h + h * h * (n-1) + h * d_out + h + d_out

其中,第一项 d i n ∗ h d_{in} * h dinh是输入层到第一个隐藏层的权重矩阵的参数数量;第二项 h ∗ h ∗ ( n − 1 ) h * h * (n-1) hh(n1)是每个隐藏层之间的权重矩阵的参数数量;第三项 h ∗ d o u t h * d_{out} hdout是最后一个隐藏层到输出层的权重矩阵的参数数量;第四项 h h h d o u t d_{out} dout分别是偏置向量的参数数量。

因此,这个具有两个隐藏层的全连接神经网络的参数数量取决于输入向量的维度 d i n d_{in} din,输出向量的维度 d o u t d_{out} dout,每个隐藏层的神经元数量 h h h和隐藏层数量 n n n

计算卷积神经网络的参数数量

现在让我们将上述方法扩展到卷积神经网络中。卷积神经网络是一种常用的深度学习模型,通常用于图像识别、自然语言处理等领域。它由多个卷积层、池化层、归一化层和全连接层等多种层组成。

为了计算卷积神经网络的参数数量,我们需要考虑每一层的参数数量。下面是一个简单的卷积神经网络的示例代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from thop import profile

device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(32 * 8 * 8, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.pool2(x)
        x = x.view(-1, 32 * 8 * 8)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

net = Net().to(device)
input_shape = (3, 224, 224)
summary(net, input_shape)

input_tensor = torch.randn(1, *input_shape).to(device)

flops, params = profile(net, inputs=(input_tensor,))
print('FLOPs: {:.2f}M'.format(flops / 1e6))

在这个示例中,我们定义了一个包含两个卷积层、两个池化层、两个归一化层和两个全连接层的卷积神经网络。我们使用nn.Conv2d定义了两个卷积层,使用nn.BatchNorm2d定义了两个归一化层,使用nn.MaxPool2d定义了两个池化层,使用nn.Linear定义了两个全连接层。

该网络结构输出如下:

在这里插入图片描述

Params计算过程

我们可以使用如下的方法计算这个卷积神经网络的参数数量:

  1. 对于每个卷积层,它有一个包含卷积核参数的权重张量和一个包含偏置参数的向量。因此,卷积层的参数数量为out_channels * (in_channels * kernel_size^2 + 1)。
  2. 对于每个归一化层,它有两个参数:缩放因子和偏移量。因此,归一化层的参数数量为2 * out_channels。
  3. 对于每个全连接层,它有一个包含权重参数的权重矩阵和一个包含偏置参数的向量。因此,全连接层的参数数量为(in_features + 1) * out_features。

根据上述公式,我们可以计算这个示例卷积神经网络的参数数量:

参数数量 = conv1参数数量 + bn1参数数量 + conv2参数数量 + bn2参数数量 + fc1参数数量 + fc2参数数量
         = 16 * (3 * 3^2 + 1) + 16 * 2 + 32 * (16 * 3^2 + 1) + 32 * 2 + (32 * 8 * 8 + 1) * 64 + (64 + 1) * 10
         = 136,970

因此,这个示例卷积神经网络的参数数量为136,970。

它计算了模型中各层的参数数量,包括卷积层、全连接层和BatchNorm层的参数数量。具体来说,公式计算了:

  • 第一层卷积层的参数数量:输入通道数为3,输出通道数为16,卷积核大小为3x3,因此共有16个卷积核,每个卷积核有3x3=9个参数,另外还有16个偏置参数,因此该层参数数量为16x(3x3+1)=448。
  • 第一层BatchNorm层的参数数量:该层有16个输出通道,每个通道有2个参数(缩放因子和偏置项),因此该层参数数量为16x2=32。
  • 第二层卷积层的参数数量:输入通道数为16,输出通道数为32,卷积核大小为3x3,因此共有32个卷积核,每个卷积核有16x3x3=144个参数,另外还有32个偏置参数,因此该层参数数量为32x(16x3x3+1)=4608。
  • 第二层BatchNorm层的参数数量:该层有32个输出通道,每个通道有2个参数,因此该层参数数量为32x2=64。
  • 第一个全连接层的参数数量:该层输入特征数为32x8x8=2048,输出特征数为64,因此该层参数数量为2048x64+64=131,136。
  • 第二个全连接层的参数数量:该层输入特征数为64,输出特征数为10,因此该层参数数量为64x10+10=650。

将上述各层的参数数量相加,即可得到模型的总参数数量。

另外,需要注意的是,参数数量和FLOPs是不同的概念。FLOPs是指在模型推理过程中,需要进行的浮点运算次数,而参数数量则是指模型中需要学习的参数的数量。在计算FLOPs时,需要考虑到每个卷积层、池化层和全连接层的输入输出形状,以及各层的卷积核大小、步长等参数信息。

总结

计算深度学习模型的参数数量是深度学习中非常基础的知识点,掌握好这一知识点有助于更好地理解和设计深度学习模型。

在本教程中,我们介绍了如何计算具有全连接层的神经网络和卷积神经网络的参数数量。对于具有全连接层的神经网络,我们可以使用简单的公式计算参数数量;对于卷积神经网络,我们需要考虑每一层的参数数量,并将它们相加得到总的参数数量。

需要注意的是,计算参数数量时需要注意每个层的超参数,例如卷积层的输入和输出通道数、卷积核大小等等。此外,某些特殊的层,如Dropout层或者BatchNorm层,可能需要特殊的计算方法。

在实际应用中,我们通常使用现有的深度学习框架(如PyTorch、TensorFlow等)来构建和训练深度学习模型,这些框架通常会自动计算模型的参数数量。但是,对于自己实现的模型或者需要手动调整模型参数的情况,了解计算参数数量的方法仍然非常有用。

希望本教程对您有所帮助!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/503082.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

零基础学Java好找工作吗?好程序员告诉你入行Java有多惨?

为什么小源建议普通背景、零基础的大专生、本科生去学java编程呢? 因为真的香啊!小白容易上岸,而且工作3年的话,基本年薪就能到50w(只要你工作后也一直不断努力学习)。java岗位属于技术岗,没有任…

Android进阶:Activity的生命周期和启动模式

Activity的生命周期和启动模式 作为Android四大组件之中存在感最强的组件,Activity应该是我们在学习Android中第一个碰到的新概念。在日常开发过程中我们肯定会用到Activity,但是关于Activity的一些细节问题运行机制我们可能还有一些不清楚的问题。今天…

调用百度文心AI作画API实现中文-图像跨模态生成

作者介绍 乔冠华,女,西安工程大学电子信息学院,2020级硕士研究生,张宏伟人工智能课题组。 研究方向:机器视觉与人工智能。 电子邮件:1078914066qq.com 一.文心AI作画API介绍 1. 文心AI作画 文…

全开源ChatGPT聊天机器人商业版源码/支持魔改/完全开放源代码

🎈 限时活动领体验会员:可下载程序网创项目短视频素材 🎈 ☑️ 品牌:ChatGPT ☑️ 语言:PHP ☑️ 类型:ChatGPT ☑️ 支持:PCWAP 🎉 有需要的朋友记得关赞评,需要的底部获…

C++哈希

目录 一、认识哈希表 1.unordered_set和unordered_map 2.哈希表的概念 二、闭散列哈希表的实现 1.底层本质 (1)哈希表的存储结构 (2)元素的插入与查找 (3)哈希冲突 (4)负载…

深入浅出C++ ——线程库

文章目录 线程库thread类的简单介绍线程函数参数原子性操作库 mutex的种类std::mutexstd::recursive_mutexstd::timed_mutexstd::recursive_timed_mutex lock_guard与unique_locklock_guardunique_lock condition_variable 线程库 thread类的简单介绍 在C11之前,涉…

“广东省五一劳动奖章”获得者卫晓欣:“她”力量让新兴技术更获认可

近日,2023年广东省庆祝“五一”国际劳动节暨五一劳动奖表彰大会顺利召开,大会表彰了2023年全国和省五一劳动奖、工人先锋号代表。 其中,来自FISCO BCOS开源社区产业应用合作伙伴广电运通的创新中心总监卫晓欣,凭借在区块链领域的…

分布式锁Redisson对于(不可重入、不可重试、超时释放、主从一致性)四个问题的应对

文章目录 1 Redisson介绍2 Redisson快速入门3 Redisson可重入锁原理4 Redisson锁重试和WatchDog机制5 Redisson锁的MutiLock原理 基于setnx实现的分布式锁存在下面的问题: 重入问题:重入问题是指 获得锁的线程可以再次进入到相同的锁的代码块中&#xff…

Ai作图可控性演进——从SD到MJ

背景 Ai作图从Diffusion模型开始,作图进入稳步发展快车道。然后用过diffusion系列作图的同学对产图稳定性,以及可控性都会颇有微词。diffusion系列作图方法在宏观层面上确实能够比较好的做出看上去还不错的图。然后当你细抠细节时候,发现这东…

远程服务器搭建jupyter lab并在本地访问

1、安装jupyter pip install jupyter 可以直接在base环境下安装 2、配置jupyter 2.1 密钥生成 进入python交互模式,输入以下代码: from jupyter_server.auth import passwd passwd()然后输入密码,得到一串密钥,保存一下 2.2…

Java多线程入门到精通学习大全?了解几种线程池的基本原理、代码示例!(第五篇:线程池的学习)

本文介绍了Java中三种常用的线程池:FixedThreadPool、CachedThreadPool和ScheduledThreadPool,分别介绍了它们的原理、代码示例以及使用注意事项。FixedThreadPool适用于并发量固定的场景,CachedThreadPool适用于执行时间短的任务&#xff0c…

Linux C/C++后台开发面试重点知识

Linux C/C后台开发面试重点知识 文章转载自个人博客: Linux C/C后台开发面试重点知识 查看目录 一、C 面试重点 本篇主要是关于 C 语言本身,如果是整个后台技术栈的学习路线,可以看这篇文章: Linux C 后台开发学习路线 对于 C 后台开发面试来说&…

27岁转行学云计算值得吗?能就业不?

27岁转行学云计算值得吗?能就业不? 首先,云计算当然值得转行了,如此肯定的观点,应该没有人会反对吧,尤其是对IT行业的现状以及就业市场有所了解的人。如果你对这一点有所怀疑也很正常,只要通过各…

Spring Boot集成ShardingSphere分片利器 AutoTable (一)—— 简单体验 | Spring Cloud 45

一、背景 Sharding是 Apache ShardingSphere 的核心特性,也是 ShardingSphere 最被人们熟知的一项能力。在过去,用户若需要进行分库分表,一种典型的实施流程(不含数据迁移)如下: 用户需要准确的理解每一张…

详解快速排序的类型和优化

详解快速排序的优化 前言快排的多种写法霍尔法实现快排代码部分 挖坑法思路讲解代码部分 双指针法思路讲解代码部分 针对排序数类型的优化针对接近或已经有序数列和逆序数列三数取中代码实现 随机数 针对数字中重复度较高的数三路划分思路讲解代码部分 根据递归的特点进行优化插…

JSP招投标管理系统myeclipse开发mysql数据库WEB结构java编程

一、源码特点 JSP 招投标管理系统 是一套完善的web设计系统,对理解JSP java编程开发语言有帮助,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。 JSP招投标管理系统myeclipse开发mysql数据库W 二、功能介绍 主要功能: …

BPMN2.0 任务-接收任务手动任务

接收任务 接收任务是一个简单的任务,它等待特定消息的到来。 当流程执行到接收任务时,流程状态将提交给持久性存储。流程将保持这种等待状态,直到流程引擎接收到特定的消息,这将触发接收任务之外流程的继续进行。 接收任务用左上角有一个消息图标的标准BPMN 2.0任务(圆角…

C++新特性总结

(智能指针,一些关键字,自动类型推导auto,右值引用移动语义完美转发,列表初始化,std::function & std::bind & lambda表达式使回调更方便,c11关于并发引入了好多好东西,有&am…

vivado工程转换到quartus下联合modelsim仿真

vivado用习惯了,现在快速换到quartus下仿真测试。写一个操作文档,以fpga实现pcm编码为例。 目录 一、建立工程 1、准备源码和仿真文件 2、新建工程 3、加载源文件 4、选择器件 5、仿真器配置 6、工程信息 二、配置工程 7、设置顶层文件 8、配置…

【多线程】初识线程,基础了解

目录 认识线程 概念 什么是线程? 为啥要有线程 进程和线程的区别 Java 的线程 和 操作系统线程 的关系 创建线程 1.继承 Thread 类 2.实现 Runnable 接口 3.通过匿名内部类方式创建Thread与实现Runnable 4.Lmabda表达式 Thread 类及常见方法 Thread 的常见构造方法…