PyTorch 实现动态输入

news2024/12/4 16:46:58

使用 PyTorch 实现动态输入:支持训练和推理输入维度不一致的 CNN 和 LSTM/GRU 模型

在深度学习中,处理不同大小的输入数据是一个常见的挑战。许多实际应用需要模型能够灵活地处理可变长度的输入。本文将介绍如何使用 PyTorch 实现支持动态输入的 CNN 和 LSTM/GRU 模型,并打印每一层的输入和输出。

  • 卷积神经网络(CNN):CNN 通常用于处理图像数据。它通过卷积层提取局部特征,并能够处理不同大小的输入图像。通过使用全局池化层,CNN 可以将不同大小的特征图转换为固定大小的输出。

  • 长短期记忆网络(LSTM)和门控循环单元(GRU):LSTM 和 GRU 是处理序列数据的 RNN 变体。它们能够捕捉时间序列中的长期依赖关系,并支持可变长度的输入序列。

模型搭建

1. CNN 模型

我们将构建一个简单的 CNN 模型,支持动态输入大小,并打印每一层的输入和输出。

import torch
import torch.nn as nn
import torch.nn.functional as F

class DynamicCNN(nn.Module):
    def __init__(self):
        super(DynamicCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))  # 自适应池化层
        self.fc = nn.Linear(32, 10)  # 输出10个类别

    def forward(self, x):
        print(f'Input to CNN: {x.shape}')
        x = F.relu(self.conv1(x))
        print(f'Output after conv1: {x.shape}')
        x = F.relu(self.conv2(x))
        print(f'Output after conv2: {x.shape}')
        x = self.pool(x)
        print(f'Output after pooling: {x.shape}')
        x = x.view(x.size(0), -1)  # 展平
        x = self.fc(x)
        print(f'Output after fc: {x.shape}')
        return x

# 创建模型
cnn_model = DynamicCNN()

# 测试动态输入
input_tensor_cnn = torch.randn(1, 3, 64, 64)  # 输入形状为 (batch_size, channels, height, width)
output_cnn = cnn_model(input_tensor_cnn)
Input to CNN: torch.Size([1, 3, 55, 64])
Output after conv1: torch.Size([1, 16, 53, 62])
Output after conv2: torch.Size([1, 32, 51, 60])
Output after pooling: torch.Size([1, 32, 1, 1])
Output after fc: torch.Size([1, 10])
Input to CNN: torch.Size([1, 3, 64, 64])
Output after conv1: torch.Size([1, 16, 62, 62])
Output after conv2: torch.Size([1, 32, 60, 60])
Output after pooling: torch.Size([1, 32, 1, 1])
Output after fc: torch.Size([1, 10])

2. LSTM/GRU 模型

接下来,我们将构建一个支持动态输入的 LSTM 模型,并打印每一层的输入和输出。

import torch
import torch.nn as nn


class DynamicLSTM(nn.Module):
    def __init__(self):
        super(DynamicLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)
        self.fc = nn.Linear(20, 1)  # 输出一个值

    def forward(self, x):
        print(f'Input to LSTM: {x.shape}')
        x, _ = self.lstm(x)
        print(f'Output after LSTM: {x.shape}')
        x = self.fc(x[:, -1, :])  # 取最后一个时间步的输出
        print(f'Output after fc: {x.shape}')
        return x


# 创建模型
lstm_model = DynamicLSTM()

# 测试动态输入
input_tensor_lstm = torch.randn(5, 15, 10)  # 输入形状为 (batch_size, seq_length, input_size)
output_lstm = lstm_model(input_tensor_lstm)

Input to LSTM: torch.Size([5, 15, 10])
Output after LSTM: torch.Size([5, 15, 20])
Output after fc: torch.Size([5, 1])
Input to LSTM: torch.Size([5, 20, 10])
Output after LSTM: torch.Size([5, 20, 20])
Output after fc: torch.Size([5, 1])

代码说明

  1. DynamicCNN:该模型包含两个卷积层和一个全连接层。使用自适应平均池化层将特征图的大小调整为 (1, 1),从而支持不同大小的输入图像。每一层的输入和输出形状在前向传播中被打印出来。

  2. DynamicLSTM:该模型包含一个 LSTM 层和一个全连接层。LSTM 层能够处理可变长度的输入序列,输出的形状在前向传播中被打印出来。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2253148.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ESP32-S3模组上跑通ES8388(13)

接前一篇文章:ESP32-S3模组上跑通ES8388(12) 二、利用ESP-ADF操作ES8388 2. 详细解析 上一回解析了es8388_init函数中的第6段代码,本回继续往下解析。为了便于理解和回顾,再次贴出es8388_init函数源码,在…

【Mac】安装Gradle

1、说明 Gradle 运行依赖 JVM,需要先安装JDK,Gradle 与 JDK的版本对应参见:Java Compatibility IDEA的版本也是有要求Gradle版本的,二者版本对应关系参见:Third-Party Software and Licenses 本次 Gradle 安装版本为…

根据YAML文件创建Conda环境

YAML(全称为YAML Ain’t Markup Language)是一种轻量级的标记语言。在Python中,YAML文件包含conda环境名和依赖,如图所示。 根据yaml文件创建Conda环境 1.切换路径 找到miniAnaconda或Anaconda,打开Anaconda Powersh…

【分组去重】.NET开源 ORM 框架 SqlSugar 系列

💥 .NET开源 ORM 框架 SqlSugar 系列 🎉🎉🎉 【开篇】.NET开源 ORM 框架 SqlSugar 系列【入门必看】.NET开源 ORM 框架 SqlSugar 系列【实体配置】.NET开源 ORM 框架 SqlSugar 系列【Db First】.NET开源 ORM 框架 SqlSugar 系列…

故障诊断 | Transformer-LSTM组合模型的故障诊断(Matlab)

效果一览 文章概述 故障诊断 | Transformer-LSTM组合模型的故障诊断(Matlab) 源码设计 %% 初始化 clear close all clc disp(此程序务必用2023b及其以上版本的MATLAB!否则会报错!) warning off %

亚马逊云(AWS)使用root用户登录

最近在AWS新开了服务器(EC2),用于学习,遇到一个问题就是默认是用ec2-user用户登录,也需要密钥对。 既然是学习用的服务器,还是想直接用root登录,下面开始修改: 操作系统是&#xff1…

Android笔记【12】脚手架Scaffold和导航Navigation

一、前言 学习课程时,对于自己不懂的点的记录。 对于cy老师第二节课总结。 二、内容 1、PPT介绍scaffold 2、开始代码实操 先新建一个screen包,写一个Homescreen函数,包括四个页面。 再新建一个compenent包,写一个displayText…

HookVip4.0.3 | 可解锁各大应用会员

HookVip是一款可以解锁会员的模块工具,需要搭配相应框架结合使用。这款插件工具支持多种框架如LSPosed、LSPatch、太极、应用转生等,并且完全免费,占用内存小。支持的软件包括now要想、神奇脑波、塔罗牌占卜、爱剪辑、人人视频、咪萌桌面宠物…

猎板 PCB特殊工艺:铸就电子行业核心竞争力新高度

在当今竞争激烈且技术驱动的电子制造领域,印制电路板(PCB)作为电子产品的关键基石,其特殊工艺的发展水平直接影响着整个行业的创新步伐与产品品质。猎板 PCB 凭借在厚铜板、孔口铺铜、HDI 板、大尺寸板以及高频高速板等特殊工艺方…

【教学类-43-25】20241203 数独3宫格的所有可能-使用模版替换(12套样式,空1格-空8格,每套510张,共6120小图)

前期做数独惨宫格的所有排列,共有12套样式,空1格-空8格,每套510张,共6120小图) 【教学类-43-24】20241127 数独3宫格的所有可能(12套样式,空1格-空8格,每套510张,共6120…

Redis+Caffeine 多级缓存数据一致性解决方案

RedisCaffeine 多级缓存数据一致性解决方案 背景 之前写过一篇文章RedisCaffeine 实现两级缓存实战,文章提到了两级缓存RedisCaffeine可以解决缓存雪等问题也可以提高接口的性能,但是可能会出现缓存一致性问题。如果数据频繁的变更,可能会导…

echarts地图立体效果,echarts地图点击事件,echarts地图自定义自定义tooltip

一.地图立体效果 方法1:两层地图叠加 实现原理:geo数组中放入两个地图对象,通过修改zlevel属性以及top,left,right,bottom形成视觉差 配置项参考如下代码: geo: [{zlevel: 2,top: 96,map: map,itemStyle: {color: #091A51ee,opacity: 1,borderWidth: 2,borderColor: #16BAFA…

D87【python 接口自动化学习】- pytest基础用法

day87 pytest运行参数 -m -k 学习日期:20241203 学习目标:pytest基础用法 -- pytest运行参数-m -k 学习笔记: 常用运行参数 pytest运行参数-m -k pytest -m 执行特定的测试用例,markers最好使用英文 [pytest] testpaths./te…

总结拓展十七:特殊采购业务——委外业务

SAP中委外采购业务,又称供应商分包(或外协、转包、、外包、托外等),是企业将部分生产任务委托给外部供应商/集团其他分子公司完成的一种特殊采购业务模式。 委外业务主要有2大类型,分别是标准委外(委外采购…

ESP8266作为TCP客户端或者服务器使用

ESP8266模块,STA模式(与手机搭建TCP通讯,EPS8266为服务端)_esp8266作为station-CSDN博客 ESP8266模块,STA模式(与电脑搭建TCP通讯,ESP8266 为客户端)_esp8266 sta 连接tcp-CSDN博客…

ATTCK红队评估实战靶场(四)

靶机链接:http://vulnstack.qiyuanxuetang.net/vuln/detail/6/ 环境搭建 新建两张仅主机网卡,一张192.168.183.0网段(内网网卡),一张192.168.157.0网段(模拟外网网段),然后按照拓补…

C 语言 “神秘魔杖”—— 指针初相识,解锁编程魔法大门(一)

文章目录 一、概念1、取地址操作符(&)2、解引用操作符(*)3、指针变量1、 声明和初始化2、 用途 二、内存和地址三、指针变量类型的意义1、 指针变量类型的基本含义2、 举例说明不同类型指针变量的意义 四、const修饰指针1、co…

封装loding加载动画的请求

图片 /*** Loading 状态管理类*/ export class Loading {constructor(timer300) {this.value falsethis.timer timer}/*** 执行异步操作并自动管理 loading 状态* param {Promise|Function|any} target - Promise、函数或其他值* returns {Promise} - 返回请求结果*/async r…

人形机器人训练、机器臂远程操控、VR游戏交互、影视动画制作,一副手套全部解决!

广州虚拟动力基于自研技术推出了多节点mHand Pro动捕数据手套,其最大的特点就是功能集成与高精度捕捉,可以用于人形机器人训练、机器臂远程操控、VR游戏交互、影视动画制作等多种场景。 一、人形机器人训练 mHand Pro动捕数据手套双手共装配16个9轴惯性…

Nginx Web服务器管理、均衡负载、访问控制与跨域问题

Nginx Web 服务器的均衡负载、访问控制与跨域问题 Nginx 的配置 1. 安装Nginx 首先安装Nginx apt install nginx -ycaccpurgatory-v:~$ sudo apt install nginx [sudo] password for cacc: Reading package lists... Done Building dependency tree... Done Reading state i…