parameters()函数 --- 获取模型参数量

news2024/9/24 7:16:43

  parameters() 函数是 PyTorch 中 torch.nn.Module 类的一个方法,用于返回模型中所有可训练的参数。下面是对这个函数的详细解释:

1. parameters() 方法工作机制

parameters() 方法工作机制:定义一个模型,通常会将多个层(如卷积层、线性层等)组合在一起,这些层就是主模块的子模块。而这些子模块中有些也可能包含自己的子模块,形成一个递归的层次结构。parameters() 方法会自动遍历整个层次结构,获取每个模块和子模块中的可训练参数。

2. 返回值

  • 返回一个迭代器,其中包含所有可训练的参数。

 3. 示例

import torch
import torch.nn as nn

# 定义一个包含多个子模块的模型
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)  # 子模块1
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)  # 子模块2
        self.fc = nn.Linear(32 * 16 * 16, 10)  # 子模块3

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # Flattening for the fully connected layer
        x = self.fc(x)
        return x

# 实例化模型
model = MyModel()

# 获取模型的参数
for param in model.parameters():
    print(param.size())

      上述代码中Conv2d函数bias没有指定值,则取默认值bias=True,也就是说,在上述代码中,bias是存在的。

分析:

  • MyModel 包含了 3 个层:两个卷积层(conv1conv2)以及一个全连接层(fc)。
  • 这些层都是 MyModel子模块
  • 当你调用 model.parameters() 时,它不仅返回 MyModel 自己的参数,还会递归地返回所有子模块(即 conv1conv2fc)的参数。

3. 递归参数获取

parameters() 函数递归地遍历模块中的所有子模块,获取每个模块的参数。每个 nn.Module 对象的 parameters() 方法会遍历它自己和它的所有子模块,将所有可训练参数打包在一起。

你可以通过以下代码验证这一点:

# 打印每个模块的参数名和大小 
for name, param in model.named_parameters(): 
    print(f"参数名: {name}, 大小: {param.size()}")

输出:

参数名: conv1.weight, 大小: torch.Size([16, 3, 3, 3]) 
参数名: conv1.bias, 大小: torch.Size([16]) 
参数名: conv2.weight, 大小: torch.Size([32, 16, 3, 3]) 
参数名: conv2.bias, 大小: torch.Size([32]) 
参数名: fc.weight, 大小: torch.Size([10, 8192]) 
参数名: fc.bias, 大小: torch.Size([10])

这里你会看到,不仅 MyModel 中的卷积层 conv1conv2 的权重和偏置参数被列出,连全连接层 fc 的参数也被列出了。

4. 参数过滤(过滤出可训练的参数)

可以通过条件过滤来获取特定类型的参数,例如仅获取可训练的参数:

trainable_params = [p for p in model.parameters() if p.requires_grad]

5. 计算参数总量

可以结合 numel() 方法来计算模型的参数总量:

total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 
print(f"总参数量: {total_params}")

在这个示例中,p.numel() 返回参数的元素数量,if p.requires_grad 确保只计算需要梯度的参数(即可训练的参数)。运行这个代码将输出模型的参数总量。 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2159772.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Oauth2整合gateway网关实现微服务单点登录】

文章目录 一.什么是单点登录?二.Oauth2整合网关实现微服务单点登录三.时序图四.代码实现思路1.基于OAuth2独立一个认证中心服务出来2.网关微服务3产品微服务4.订单微服务5.开始测试单点登录 一.什么是单点登录? 单点登录(Single Sign On&…

sql语法学习:关键点和详细解释

学习SQL语法是掌握数据库操作的基础。以下是SQL语法的一些关键点和详细解释: 1. SQL基础 SQL(Structured Query Language)是一种用于管理和操作关系型数据库的标准语言。它主要包括以下几个部分: 数据定义语言(DDL&…

全栈开发(五):初始化前端项目(nuxt3+vue3+element-plus)+前端代理

1.初始化前端项目 Nuxt3:搭建项目_nuxt3 项目搭建-CSDN博客、 2.配置代理 nuxt.config.ts // https://nuxt.com/docs/api/configuration/nuxt-configexport default defineNuxtConfig({devtools: { enabled: true },modules: ["element-plus/nuxt", "pinia/n…

智能PPT行业赋能用户画像

智能PPT市场在巨大的需求前景下,已吸引一批不同类型的玩家投入参与竞争。从参与玩家类型来看,不乏各类与PPT创作有关的上下游企业逐步向智能PPT赛道转型进入,也包括顺应生成式AI技术热潮所推出的创业企业玩家。当前,智能PPT赛道发…

在虚幻引擎中创建毛发/头发

在虚幻引擎中创建毛发/头发 , 首先开启两个插件 Groom 和 Alembic Groom Importer 打开蒙皮缓存 导出人物模型 将人物导入Blender , 选择需要种植头发的点 指定并选择 点击毛发 这里变成爆炸头了 , 把数量和长度调一下 切换到梳子模式 调整发型 导出为abc , 文件路径不…

基于opencv的车牌检测和识别系统(代码+教程)

车牌检测与识别技术详解 车牌检测和识别(License Plate Recognition, LPR)是一项重要的计算机视觉任务,它在交通管理、安全监控以及智能门禁系统等多个领域都有着广泛的应用。随着深度学习技术的发展,LPR系统的准确性和鲁棒性得到…

【算法业务】基于Multi-Armed Bandits的个性化push文案自动优选算法实践

1. 背景介绍 该工作属于多年之前的用户增长算法业务项目。在个性化push中,文案扮演非常重要的角色,是用户与push的商品之间的桥梁,文案是用户最直接能感知的信息。应该说在push产品信息之外,最重要的就是文案,直接能…

机器学习 | Scikit Learn中的普通最小二乘法和岭回归

在统计建模中,普通最小二乘法(OLS)和岭回归是两种广泛使用的线性回归分析技术。OLS是一种传统的方法,它通过最小化预测值和实际值之间的平方误差之和来找到数据的最佳拟合线。然而,OLS可以遭受高方差和过拟合时&#x…

Unreal Engine 5 C++: 插件编写03 | MessageDialog

在虚幻引擎编辑器中编写Warning弹窗 准备工作 FMessageDialog These functions open a message dialog and display the specified informations there. EAppReturnType::Type 是 Unreal Engine 中用于表示应用程序对话框(如消息对话框)返回结果的枚举…

vue.js 展示树状结构数据,动态生成 HTML 内容

展示树状结构数据: 从 jsonData 读取树状结构的 JSON 数据,将其解析并生成 HTML 列表来展示。树状结构数据根据 id 和 label 属性组织,节点可以包含子节点 children。 展示评级信息: 从预定义的表单字段 form 中读取 arRateFlag 和…

GS-SLAM论文阅读笔记--GLC-SLAM

前言 最近GS-SLAM回环检测的工作已经逐步发展了,看一下这篇新文章。 文章目录 前言1.背景介绍2.关键内容2.1 tracking2.2 local mapping2.3 Loop Closing2.4总体流程 3.文章贡献 1.背景介绍 现有的基于3dgs的SLAM方法往往存在累积的跟踪误差和地图漂移&#xff0c…

三菱FX5U CPU模块的初始化“(格式化PLC)”

1、连接FX5U PLC 1、使用以太网电缆连接计算机与CPU模块。 2、从工程工具的菜单选择[在线]中[当前连接目标]。 3、在“简易连接目标设置 Connection”画面中,在与CPU模块的直接连接方法中选择[以太网]。点击[通信测试]按钮,确认能否与CPU模块连接。 FX5…

小柴冲刺软考中级嵌入式系统设计师系列二、嵌入式系统硬件基础知识(1)数字电路基础

目录 一、信号特征 二、组合逻辑电路和时序逻辑电路 1、组合逻辑电路 2、时序逻辑线路 三、信号转换 1、数字集成电路的分类 2、常用电平接口技术 四、可编程逻辑器件 flechazohttps://www.zhihu.com/people/jiu_sheng 小柴冲刺嵌入式系统设计师系列总目录https://blo…

[vulnhub] Prime 1

https://www.vulnhub.com/entry/prime-1,358/ 主机发现端口扫描 探测存活主机,137是靶机 nmap -sP 192.168.75.0/24 // Starting Nmap 7.93 ( https://nmap.org ) at 2024-09-22 16:25 CST Nmap scan report for 192.168.75.1 Host is up (…

Rust - 字符串:str 与 String

在其他语言中,字符串通常都会比较简单,例如 “hello, world” 就是字符串章节的几乎全部内容了。 但是Rust中的字符串与其他语言有所不同,若带着其他语言的习惯来学习Rust字符串,将会波折不断。 所以最好先忘记脑中已有的关于字…

MMD模型一键完美导入UE5-VRM4U插件方案(一)

1、下载pmx模型 1、去模之屋官网下载MMD模型,模之屋 2、下载完成得到pmx和Texture文件 2、下载并启用VRM4U插件 1、下载VRM4U插件, VRM4U,点击Latest下载对应引擎版本 2、将插件放到Plugins目录,然后

GB28181语音对讲协议详解

GB28181-2016语音对讲流程如下图1所示: 图1.语音对讲流程。 其中, 信令 1 、2 、 3 、 4 为语音广播通知、 语音广播应答消息流程; 信令 5 、 1 2 、 1 3 、 1 4 、 1 5 、 1 6 为 S I P 服务器接收到客户端的呼叫请求通过 B 2 B UA 代理方式建立语音流接收者与媒…

DevExpress WPF中文教程:如何解决行焦点、选择的常见问题?

DevExpress WPF拥有120个控件和库,将帮助您交付满足甚至超出企业需求的高性能业务应用程序。通过DevExpress WPF能创建有着强大互动功能的XAML基础应用程序,这些应用程序专注于当代客户的需求和构建未来新一代支持触摸的解决方案。 无论是Office办公软件…

【HarmonyOS】应用权限原理和封装

背景 在项目中,避免不了需要调用系统资源和系统能力,比如:日历读写、摄像头等。因此,需要了解对系统资源访问权限的申请方式方法。 授权方式 包括两种授权方式,分别是system_grant(系统授权) 和 user_grant(用户授权)…

ruoyi源码解析学习 - 微服务版 - ruoyi-gateway

com.ruoyi.gateway 今天简单看看若依的gateway的配置模块干了啥 最近面试很多外包公司,都对低代码平台有点要求,这些代码虽说用起来不费劲,但是其中还是有很多细节能让我学习学习的。(微服务版,上次搞jeecgboot的笔试…