如何得到深度学习模型的参数量和计算复杂度

news2025/1/6 20:26:59

1.准备好网络模型代码

import torch
import torch.nn as nn
import torch.optim as optim

# BP_36: 输入2个节点,中间层36个节点,输出25个节点
class BP_36(nn.Module):
    def __init__(self):
        super(BP_36, self).__init__()
        self.fc1 = nn.Linear(2, 36)  # 输入2个节点,中间层36个节点
        self.fc2 = nn.Linear(36, 25)  # 输出25个节点

    def forward(self, x):
        x = torch.relu(self.fc1(x))  # 使用ReLU激活函数
        x = self.fc2(x)
        return x

# BP_64: 输入2个节点,中间层64个节点,输出25个节点
class BP_64(nn.Module):
    def __init__(self):
        super(BP_64, self).__init__()
        self.fc1 = nn.Linear(2, 64)  # 输入2个节点,中间层64个节点
        self.fc2 = nn.Linear(64, 25)  # 输出25个节点

    def forward(self, x):
        x = torch.relu(self.fc1(x))  # 使用ReLU激活函数
        x = self.fc2(x)
        return x

# Bi-LSTM: 输入2个节点,中间层36个节点,线性层输入72个节点,输出25个节点
class Bi_LSTM(nn.Module):
    def __init__(self):
        super(Bi_LSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=2, hidden_size=36, bidirectional=True, batch_first=True)  # 双向LSTM
        self.fc1 = nn.Linear(72, 25)  # LSTM的输出72维,经过线性层后输出25个节点

    def forward(self, x):
        # x的形状应该是(batch_size, seq_len, input_size)
        x, _ = self.lstm(x)  # 输出LSTM的结果
        x = self.fc1(x)
        return x

# Bi-GRU: 输入2个节点,中间层36个节点,线性层输入72个节点,输出25个节点
class Bi_GRU(nn.Module):
    def __init__(self):
        super(Bi_GRU, self).__init__()
        self.gru = nn.GRU(input_size=2, hidden_size=36, bidirectional=True, batch_first=True)  # 双向GRU
        self.fc1 = nn.Linear(72, 25)  # GRU的输出72维,经过线性层后输出25个节点

    def forward(self, x):
        # x的形状应该是(batch_size, seq_len, input_size)
        x, _ = self.gru(x)  # 输出GRU的结果
        x = self.fc1(x)
        return x

2.运行计算参数量和复杂度的脚本

import torch
# from net import BP_36
# from net import BP_64
# from net import Bi_LSTM
from net import Bi_GRU

from ptflops import get_model_complexity_info
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# 统计Transformer模型的参数量和计算复杂度
model_transformer = Bi_GRU()
model_transformer.to(device)
flops_transformer, params_transformer = get_model_complexity_info(model_transformer, (256,2), as_strings=True, print_per_layer_stat=False)
print('模型参数量:' + params_transformer)
print('模型计算复杂度:' + flops_transformer)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2271107.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PHP框架+gatewayworker实现在线1对1聊天--发送消息(6)

文章目录 发送消息原理说明发送功能实现html部分javascript代码PHP代码 发送消息原理说明 接下来我们发送聊天的文本信息。点击发送按钮的时候,会自动将文本框里的内容发送出去。过程是我们将信息发送到服务器,服务器再转发给对方。文本框的id为msgcont…

DuckDB:密钥管理器及其应用

密钥管理器(Secrets Manager)为所有使用密钥的后端提供了统一的用户界面。密钥信息可以被限定范围,因此不同的存储前缀可以有不同的密钥信息,例如允许在单个查询中连接跨组织的数据。密钥也可以持久化,这样就不需要在每次启动DuckDB时都指定它…

告别Kibana:Elasticsearch 桌面客户端的新变革

告别Kibana:Elasticsearch 桌面客户端的新变革 在大数据处理与分析领域,Elasticsearch 及其相关技术的应用日益广泛。长期以来,Kibana 在数据可视化与查询管理方面占据重要地位,但随着技术的不断发展,用户对于更高效、…

模块化通讯管理机在物联网系统中的应用

安科瑞刘鸿鹏 摘要 随着能源结构转型和智能化电网的推进,电力物联网逐渐成为智能电网的重要组成部分。本文以安科瑞ANet系列智能通信管理机为例,探讨其在电力物联网中的应用,包括数据采集、规约转换、边缘计算、远程控制等技术实践&#…

AAAI 2025论文分享┆一种接近全监督的无训练文档信息抽取方法:SAIL(文中附代码链接)

本推文详细介绍了一篇上海交通大学乐心怡老师课题组被人工智能顶级会议AAAI 2025录用的的最新论文《SAIL: Sample-Centric In-Context Learning for Document Information Extraction》。论文的第一作者为张金钰。该论文提出了一种无需训练的、以样本为中心的、基于上下文学习的…

SAP物料主数据界面增加客制化字段、客制化页签的方式

文章目录 前言一、不增加页签,只增加客制化字段二、增加物料主数据页签 前言 【SAP系统MM模块研究】 #SAP #MM #物料 #客制化 #物料主数据 项目上难免会遇到客户要在物料主数据的界面上,增加新字段的需求。 实现方式有: (1&…

ROS2软件架构全面解析-学习如何设计通信中间件框架

前言 ROS(Robot Operating System) 2 是一个用于开发机器人应用的软件平台,也称为机器人软件开发工具包 (SDK)。 ROS2是ROS1的迭代升级版本 ,最主要的升级点是引入DDS(Data Distribution Service)为基础的…

接口自动化测试流程、工具及其实践

🍅 点击文末小卡片,免费获取软件测试全套资料,资料在手,涨薪更快 一、接口自动化测试简介 接口自动化测试是指通过编写脚本或使用自动化工具,对软件系统的接口进行测试的过程。接口测试是软件测试中的一种重要测试类…

香橙派5plus单独编译并安装linux内核无法启动的原因分析与解决记录

1 说明 我依照官方手册编译单独编译linux内核,安装后重启出现内核启动失败的问题,编译和安装步骤如下:# 1. 克隆源码 git clone --depth1 -b orange-pi-6.1-rk35xx https://github.com/orangepi-xunlong/linux-orangepi# 2 配置源码 make rockchip_linu…

数据库知识汇总1

一. 数据库系统概述 信息需要媒体(文本、图像视频等)表现出来才能被人类所获取,媒体可以转换成比特或者符号,这些称为数据; 数据/信息的特点:爆炸式增长、无限复制、派生; 数据库是指长期长期…

Win32汇编学习笔记03.RadAsm和补丁

Win32汇编学习笔记03.RadAsm和补丁-C/C基础-断点社区-专业的老牌游戏安全技术交流社区 - BpSend.net 扫雷游戏啊下补丁 在扫雷游戏中,点关闭弹出一个确认框,确认之后再关闭,取消就不关闭 首先第一步就是确认关闭按钮响应的位置,一般都是 WM_CLOSE 的消息 ,消息响应一般都在过…

OSPF特殊区域(open shortest path first LSA Type7)

一、区域介绍 1、Stub区域 Stub区域是一种可选的配置属性。通常来说,Stub区域位于自治系统的边界,例如,只有一 个ABR的非骨干区域。在这些区域中,设备的路由表规模以及路由信息传递的数量都会大量减少。 kill 4 5类type 传递1 …

论文解读之Generative Dense Retrieval: Memory Can Be a Burden

本次论文解读,博主带来生成式稠密检索:记忆可能成为一种负担的论文分享 一、简介 生成式检索根据给定的查询,自回归地检索相关的文档标识符,在小规模的文档库中表现不错,通过使用模型参数记忆文档库,生成…

vue,使用unplugin-auto-import避免反复import,按需自动引入

项目库:https://github.com/unplugin/unplugin-auto-import 参考: https://juejin.cn/post/7012446423367024676 https://cloud.tencent.com/developer/article/2236166 背景: vue3项目中,基本所有页面都会引入vue3框架的api&…

[深度学习] 大模型学习1-大语言模型基础知识

大语言模型(Large Language Model,LLM)是一类基于Transformer架构的深度学习模型,主要用于处理与自然语言相关的各种任务。简单来说,当用户输入文本时,模型会生成相应的回复或结果。它能够完成许多任务&…

OCR图片中文字识别(Tess4j)

文章目录 Tess4J下载 tessdataJava 使用Tess4j 的 demo Tess4J Tess4J 是 Tesseract OCR 引擎的 Java 封装库,它让 Java 项目更轻松地实现 OCR(光学字符识别)功能。 下载 tessdata 下载地址:https://github.com/tesseract-ocr/…

Vue2/Vue3使用DataV

Vue2 注意vue2与3安装DataV命令命令是不同的Vue3 DataV - Vue3 官网地址 注意vue2与3安装DataV命令命令是不同的 vue3vite 与 Vue3webpack 对应安装也不同vue3vite npm install kjgl77/datav-vue3全局引入 // main.ts中全局引入 import { createApp } from vue import Da…

【JVM】总结篇-字节码篇

字节码篇 Java虚拟机的生命周期 JVM的组成 Java虚拟机的体系结构 什么是Java虚拟机 虚拟机:指以软件的方式模拟具有完整硬件系统功能、运行在一个完全隔离环境中的完整计算机系统 ,是物理机的软件实现。常用的虚拟机有VMWare,Visual Box&…

国内Ubuntu环境Docker部署Stable Diffusion入坑记录

国内Ubuntu环境Docker部署Stable Diffusion入坑记录 本文旨在记录使用dockerpython进行部署 stable-diffusion-webui 项目时遇到的一些问题,以及解决方案,原项目地址: https://github.com/AUTOMATIC1111/stable-diffusion-webui 问题一览: …

音频进阶学习九——离散时间傅里叶变换DTFT

文章目录 前言一、DTFT的解释1.DTFT公式2.DTFT右边释义1) 复指数 e − j ω n e^{-j\omega n} e−jωn2)序列与复指数相乘 x [ n ] ∗ e − j ω n x[n]*e^{-j\omega n} x[n]∗e−jωn复指数序列复数的共轭正交正交集 3)复指数序列求和 3.DTF…