文章目录
- 前言
- 1. ImprovedResBlock(改进的残差块)
- 结构组成
- 卷积层1
- 卷积层2
- 跳跃连接(Downsample)
- 前向传播流程
- 主路径
- 跳跃路径
- 残差连接
- 2. EnhancedCNN(主模型)
- 2.1 初始特征提取层
- 功能
- 参数变化
- 2.2 残差块堆叠
- 2.3 分类器
- 3. 数据流示例
- Initial Layer
- ResBlock 1
- ResBlock 2
- ResBlock 3
- AdaptiveAvgPool1d
- Classifier
- 4. 设计优点
- 残差连接
- 逐步下采样
- 自适应池话
- 正则化
- 5. 适用场景
- 时间序列分类
- 音频处理
- 文本分类
前言
这个网络结构是一个改进的卷积神经网络(CNN),专为一维数据(如时间序列、音频信号或文本序列)设计,结合了残差学习(ResNet的思想)和深度卷积特征提取。
1. ImprovedResBlock(改进的残差块)
这是网络的核心组件,通过跳跃连接(Shortcut Connection)缓解梯度消失问题,增强深层网络的训练稳定性。
结构组成
卷积层1
nn.Conv1d(in_channels, out_channels, 5, stride, 2)
5个核大小,步长可调(默认为1),填充2,确保输出长度与输入一致(当stride=1时)。
后接批归一化(BatchNorm1d)和ReLU激活。
卷积层2
nn.Conv1d(out_channels, out_channels, 3, 1, 1)
3个核大小,固定步长1,填充1,保持输出长度不变。
后接批归一化,但无激活函数。
跳跃连接(Downsample)
当输入/输出通道数不同或stride≠1时,通过1x1卷积调整通道数和长度,匹配残差路径的输出形状。
包含Conv1d(步长与主路径一致)和BatchNorm1d。
前向传播流程
主路径
主路径:Conv1 → BN1 → ReLU → Conv2 → BN2。
跳跃路径
跳跃路径:通过downsample调整输入identity的形状。
残差连接
残差连接:主路径输出与跳跃路径相加,再经过ReLU激活。
2. EnhancedCNN(主模型)
整体结构分为特征提取器(卷积层+残差块)和分类器(全连接层)。
2.1 初始特征提取层
nn.Sequential(
nn.Conv1d(input_channels, 64, 7, stride=2, padding=3),
nn.BatchNorm1d(64),
nn.ReLU(),
nn.MaxPool1d(3, 2, 1)
)
功能
功能:快速降低序列长度,提取低级特征。
参数变化
输入通道数由input_channels升维到64。
卷积核7,步长2,填充3:输出长度为原长度的一半(L//2)。
最大池化(核3,步长2,填充1):长度进一步减半(约L//4)。
2.2 残差块堆叠
nn.Sequential(
ImprovedResBlock(64, 128, stride=2),
ImprovedResBlock(128, 256, stride=2),
ImprovedResBlock(256, 512, stride=2),
nn.AdaptiveAvgPool1d(1)
)
通道数变化:64 → 128 → 256 → 512,逐层翻倍以增强特征表达能力。
每个残差块步长2,长度逐层减半(例如输入长度L//4 → 经过3个块后为L//(4×2³) = L//32)。
最终通过AdaptiveAvgPool1d(1)将每个通道的序列压缩为1个值,输出形状为(batch, 512, 1)。
2.3 分类器
nn.Sequential(
nn.Linear(512, 256),
nn.Dropout(0.5),
nn.ReLU(),
nn.Linear(256, num_classes)
)
将512维特征映射到256维,经Dropout(防过拟合)和ReLU后,输出最终的类别预测。
3. 数据流示例
Initial Layer
假设输入形状为 (batch, input_channels, seq_len):
输出形状:(batch, 64, seq_len//4)(卷积步长2 + 池化步长2)。
ResBlock 1
输入:(64, seq_len//4) → 输出:(128, seq_len//8)(步长2)。
ResBlock 2
输入:(128, seq_len//8) → 输出:(256, seq_len//16)。
ResBlock 3
输入:(256, seq_len//16) → 输出:(512, seq_len//32)。
AdaptiveAvgPool1d
输出形状:(batch, 512, 1) → 展平为 (batch, 512)。
Classifier
最终输出:(batch, num_classes)。
4. 设计优点
残差连接
残差连接:缓解梯度消失,允许构建更深的网络。
逐步下采样
逐步下采样:通过步长2的卷积和池化逐步压缩序列,平衡计算效率和特征保留。
自适应池话
自适应池化:无论输入长度如何,最终输出固定维度,适应变长输入。
正则化
正则化:BatchNorm和Dropout提升泛化能力。
5. 适用场景
时间序列分类
时间序列分类(如传感器数据、ECG信号)。
音频处理
音频处理(如语音识别、声纹识别)。
文本分类
文本分类(需配合嵌入层将文本转为序列)。
通过调整input_channels和num_classes,可灵活适配不同任务。