- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客R5中的内容,为了便于自己整理总结起名为R2
- 🍖 原作者:K同学啊 | 接辅导、项目定制
目录
- 0. 总结
- 1. RNN介绍
- a. 什么是 RNN?
- RNN 的一般应用场景
- b. 传统 RNN 的基本结构
- 关键特征
- c. RNN 的优势与局限
- 优势
- 局限与改进
- d. RNN 的常见变体:LSTM 和 GRU
- LSTM (Long Short-Term Memory)
- GRU (Gated Recurrent Unit)
- e. RNN 的应用案例
- f. RNN 在 PyTorch 中的实现方式
- g. 如何更进一步学习 RNN?
- h. 总结
- 2. 数据导入
- 3. 数据探索性分析
- a. 数据相关性探索
- b. 是否会下雨
- c. 地理位置与下雨的关系
- d. 湿度和压力对下雨的影响
- e. 气温对下雨的影响
- 4. 数据预处理
- 5. 构建数据集
- 6. 定义模型
- 7. 初始化模型与优化器
- 8. 训练函数
- 9. 测试函数
- 10. 执行训练
- 11. 过程可视化
0. 总结
数据导入及处理部分:在 PyTorch 中,我们通常先将 NumPy 数组转换为 torch.Tensor,再封装到 TensorDataset 或自定义的 Dataset 里,然后用 DataLoader 按批次加载。
模型构建部分:RNN
设置超参数:在这之前需要定义损失函数,学习率(动态学习率),以及根据学习率定义优化器(例如SGD随机梯度下降),用来在训练中更新参数,最小化损失函数。
定义训练函数:函数的传入的参数有四个,分别是设置好的DataLoader(),定义好的模型,损失函数,优化器。函数内部初始化损失准确率为0,接着开始循环,使用DataLoader()获取一个批次的数据,对这个批次的数据带入模型得到预测值,然后使用损失函数计算得到损失值。接下来就是进行反向传播以及使用优化器优化参数,梯度清零放在反向传播之前或者是使用优化器优化之后都是可以的,一般是默认放在反向传播之前。
定义测试函数:函数传入的参数相比训练函数少了优化器,只需传入设置好的DataLoader(),定义好的模型,损失函数。此外除了处理批次数据时无需再设置梯度清零、返向传播以及优化器优化参数,其余部分均和训练函数保持一致。
训练过程:定义训练次数,有几次就使用整个数据集进行几次训练,初始化四个空list分别存储每次训练及测试的准确率及损失。使用model.train()开启训练模式,调用训练函数得到准确率及损失。使用model.eval()将模型设置为评估模式,调用测试函数得到准确率及损失。接着就是将得到的训练及测试的准确率及损失存储到相应list中并合并打印出来,得到每一次整体训练后的准确率及损失。
结果可视化
模型的保存,调取及使用。在PyTorch中,通常使用 torch.save(model.state_dict(), ‘model.pth’) 保存模型的参数,使用 model.load_state_dict(torch.load(‘model.pth’)) 加载参数。
需要改进优化的地方:确保模型和数据的一致性,都存到GPU或者CPU;注意numclasses不要直接用默认的1000,需要根据实际数据集改进;实例化模型也要注意numclasses这个参数;此外注意测试模型需要用(3,224,224)3表示通道数,这和tensorflow定义的顺序是不用的(224,224,3),做代码转换时需要注意。
1. RNN介绍
下面是对 RNN(Recurrent Neural Network) 的一个循序渐进、相对通俗的介绍,帮助你从原理上理解 RNN 的本质与应用,希望对你有所帮助。
a. 什么是 RNN?
RNN,全称 Recurrent Neural Network,即“循环神经网络”。它是一类专门处理序列数据的神经网络模型,与传统的前馈网络(如全连接网络 MLP、卷积网络 CNN 等)最大的区别在于:
- 序列性:RNN 可以在序列的时间步之间传递信息,具备“记忆”先前输入的能力。
- 循环结构:在每一个时间步,网络都会基于当前输入和上一时刻的隐藏状态来更新当前隐藏状态,然后输出结果。
RNN 的一般应用场景
- 自然语言处理(NLP):如情感分析、文本分类、机器翻译、文本生成等。
- 时间序列预测:如股票预测、温度预测、信号处理等。
- 语音识别或合成:处理音频序列。
b. 传统 RNN 的基本结构
以下是一个最基础(经典版)的 RNN 结构示意:
┌───────┐ ┌───────┐ ┌───────┐
│x(t-1) │ │x(t) │ │x(t+1) │ ← 输入序列
└──┬────┘ └──┬────┘ └──┬────┘
│ │ │
┌─▼──────────────▼──────────────▼─────────────────────────┐
│ RNN 单元 (循环体) │
│ │
│ h(t-1) ──┐ ┌─────────┐ ┌─────────┐ │
│ │ │激活函数 f│ │激活函数 g│ │
│ x(t), h(t-1) → │ 线性运算 → │ (如 tanh) → h(t) │
│ │ └─────────┘ └─────────┘ │
└────────────┴─────────────────────────────────────────────┘
↑
通过时间传递
(隐藏状态 h)
- 输入序列:( x(1), x(2), …, x(T) )
- 隐藏状态:( h(t) ) 表示网络在时间步 ( t ) 的内部记忆。
- 更新公式(经典 RNN 的简单形式):
[
h(t) = \sigma(W_{hh} \cdot h(t-1) + W_{xh} \cdot x(t) + b_h)
]
其中 (\sigma) 通常是一个非线性激活函数,如 (\tanh) 或 (\text{ReLU}) 等。
关键特征
-
循环(Recurrent)
- RNN 通过将过去的隐藏状态 ( h(t-1) ) 反复输入到网络,与当前输入 ( x(t) ) 一起决策新的隐藏状态 ( h(t) )。因此它在时间序列上“循环”展开。
-
参数共享(Parameter Sharing)
- 对于序列中每个时间步,RNN 使用相同的一组权重((W_{hh}, W_{xh}) 等),这与一般的多层感知器(MLP)不同,MLP 每一层都会有一组新的权重。
-
序列建模
- 借助隐藏状态的更新,RNN 在一定程度上能够“记住”之前输入的信息,从而可以用来处理依赖于上下文或时间顺序的任务(如语言模型,每个单词与前面单词息息相关)。
c. RNN 的优势与局限
优势
- 适合序列数据:相比于传统的全连接网络,RNN 能够更好地处理变长的序列输入,捕捉序列中的时序依赖关系。
- 参数共享:节省模型参数,防止过度膨胀。
局限与改进
- 长期依赖问题:经典 RNN 里,随着序列长度增大,早期输入的信息往往无法传播到后面时间步,会导致梯度消失或梯度爆炸。
- 训练效率:由于存在序列展开 + 反向传播(BPTT: Back Propagation Through Time)的特殊性,训练速度通常慢于并行度高的卷积网络。
- 改进模型:
- LSTM(Long Short-Term Memory)
- GRU(Gated Recurrent Unit)
这两种模型通过门控机制(忘记门、输入门、输出门等)来缓解或部分解决长期依赖问题,在实际中广泛使用。
d. RNN 的常见变体:LSTM 和 GRU
由于传统 RNN 在对长序列进行建模时,容易遗忘早期信息,为了解决这个问题,人们提出了带有 “门控” 机制的循环神经网络结构。其中最典型的就是 LSTM 和 GRU。
LSTM (Long Short-Term Memory)
- 由 记忆单元(Cell state)和 门控机制(input gate、forget gate、output gate)来控制信息的流动,保留长期的梯度信息,从而缓解梯度消失问题。
- 在很多 NLP 任务中,LSTM 大多表现优于传统 RNN。
GRU (Gated Recurrent Unit)
- 结构上比 LSTM 更简化,只有 更新门 和 重置门,虽然结构更简单,但也能保留一定的长期依赖能力。
- 在某些任务中,GRU 的性能与 LSTM 不相上下,而且训练速度更快。
e. RNN 的应用案例
-
语言模型
- 给定前面的单词,预测下一个单词;或给定一段前文,生成下一段文本。
- 例如早期的机器翻译系统,输入序列是原语言单词,输出序列是翻译后的目标语言单词。
- 现在更多使用了 Transformer 这种基于自注意力机制的模型,但 RNN 依然是重要的基石概念。
-
序列分类
- 对一段文本或语音做分类,如情感分析(正向/负向)、语音识别(识别说的是哪一句话)等。
-
时间序列预测
- 比如股票预测、流量预测、天气预测,通过过去若干时刻的数据预测未来走向。
f. RNN 在 PyTorch 中的实现方式
在 PyTorch 里,最常见的循环网络层包括:
nn.RNN
:经典单层 RNN,可选激活函数tanh
或ReLU
。nn.LSTM
:LSTM 结构nn.GRU
:GRU 结构
输入通常需要形状 (batch_size, seq_len, input_size)
(当 batch_first=True
时)。
输出需要自己选择:
- 如果只需要最后一个时间步的输出,往往取
output[:, -1, :]
; - 如果需要所有时间步的输出(比如生成序列时),则直接使用
output
。 - 训练时要记得将 hidden state(以及 cell state)正确地传递或重置。
g. 如何更进一步学习 RNN?
- 从小例子入手:
- 用 RNN 来解决简单的序列学习任务(例如正弦波预测、小规模字符级语言模型),查看网络是如何随时间迭代的。
- 阅读论文与教程:
- LSTM 的原始论文 (Hochreiter & Schmidhuber, 1997)
- GRU (Cho et al., 2014)
- 深入理解门控机制,体会为什么能让 RNN 更好地记住/遗忘信息。
- 与 Transformer 对比:
- 在大多数 NLP 任务上,目前已被 Transformer 结构占据主流,但 RNN 思想仍是许多研究的基础。理解 RNN 有助于理解注意力机制为什么行之有效。
- 深入到框架实现:
- 看 PyTorch 中
nn.RNN
、nn.LSTM
、nn.GRU
的源代码或官方文档,了解参数含义及前向、后向的具体计算流程。
- 看 PyTorch 中
h. 总结
- 核心思想:RNN 可以“循环”地将过去的信息传递到现在,从而在一定程度上捕捉序列数据的依赖关系。
- 传统 RNN 的问题:容易出现梯度消失或爆炸,难以捕捉长程依赖。
- 常见改进:LSTM、GRU 等门控结构缓解了长期依赖难题,也成为 RNN 家族的主力。
- 现今趋势:NLP 等领域更多使用 Transformer,但 RNN 在许多对序列长度不太长的场合依旧可以使用,而且对初学者理解神经网络的“记忆”能力非常有帮助。
如果你刚开始学习,可以:
- 多动手调试:写一些小规模 RNN 代码,训练简单的序列数据,观察 loss 和隐藏状态如何变化。
- 多画图:用纸笔画 RNN 在时序上的展开图,有助于理解反向传播的流程。
- 分门别类:清楚哪些任务用 LSTM/GRU,哪些任务需要 CNN 或 Transformer,知道各种模型的优势与局限。
2. 数据导入
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import copy
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import classification_report,confusion_matrix
from sklearn.metrics import r2_score
from sklearn.metrics import mean_absolute_error , mean_absolute_percentage_error , mean_squared_error
data = pd.read_csv("./data/weatherAUS.csv")
df = data.copy()
data.head()
Date | Location | MinTemp | MaxTemp | Rainfall | Evaporation | Sunshine | WindGustDir | WindGustSpeed | WindDir9am | ... | Humidity9am | Humidity3pm | Pressure9am | Pressure3pm | Cloud9am | Cloud3pm | Temp9am | Temp3pm | RainToday | RainTomorrow | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 2008-12-01 | Albury | 13.4 | 22.9 | 0.6 | NaN | NaN | W | 44.0 | W | ... | 71.0 | 22.0 | 1007.7 | 1007.1 | 8.0 | NaN | 16.9 | 21.8 | No | No |
1 | 2008-12-02 | Albury | 7.4 | 25.1 | 0.0 | NaN | NaN | WNW | 44.0 | NNW | ... | 44.0 | 25.0 | 1010.6 | 1007.8 | NaN | NaN | 17.2 | 24.3 | No | No |
2 | 2008-12-03 | Albury | 12.9 | 25.7 | 0.0 | NaN | NaN | WSW | 46.0 | W | ... | 38.0 | 30.0 | 1007.6 | 1008.7 | NaN | 2.0 | 21.0 | 23.2 | No | No |
3 | 2008-12-04 | Albury | 9.2 | 28.0 | 0.0 | NaN | NaN | NE | 24.0 | SE | ... | 45.0 | 16.0 | 1017.6 | 1012.8 | NaN | NaN | 18.1 | 26.5 | No | No |
4 | 2008-12-05 | Albury | 17.5 | 32.3 | 1.0 | NaN | NaN | W | 41.0 | ENE | ... | 82.0 | 33.0 | 1010.8 | 1006.0 | 7.0 | 8.0 | 17.8 | 29.7 | No | No |
5 rows × 23 columns
data.describe()
MinTemp | MaxTemp | Rainfall | Evaporation | Sunshine | WindGustSpeed | WindSpeed9am | WindSpeed3pm | Humidity9am | Humidity3pm | Pressure9am | Pressure3pm | Cloud9am | Cloud3pm | Temp9am | Temp3pm | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 143975.000000 | 144199.000000 | 142199.000000 | 82670.000000 | 75625.000000 | 135197.000000 | 143693.000000 | 142398.000000 | 142806.000000 | 140953.000000 | 130395.00000 | 130432.000000 | 89572.000000 | 86102.000000 | 143693.000000 | 141851.00000 |
mean | 12.194034 | 23.221348 | 2.360918 | 5.468232 | 7.611178 | 40.035230 | 14.043426 | 18.662657 | 68.880831 | 51.539116 | 1017.64994 | 1015.255889 | 4.447461 | 4.509930 | 16.990631 | 21.68339 |
std | 6.398495 | 7.119049 | 8.478060 | 4.193704 | 3.785483 | 13.607062 | 8.915375 | 8.809800 | 19.029164 | 20.795902 | 7.10653 | 7.037414 | 2.887159 | 2.720357 | 6.488753 | 6.93665 |
min | -8.500000 | -4.800000 | 0.000000 | 0.000000 | 0.000000 | 6.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 980.50000 | 977.100000 | 0.000000 | 0.000000 | -7.200000 | -5.40000 |
25% | 7.600000 | 17.900000 | 0.000000 | 2.600000 | 4.800000 | 31.000000 | 7.000000 | 13.000000 | 57.000000 | 37.000000 | 1012.90000 | 1010.400000 | 1.000000 | 2.000000 | 12.300000 | 16.60000 |
50% | 12.000000 | 22.600000 | 0.000000 | 4.800000 | 8.400000 | 39.000000 | 13.000000 | 19.000000 | 70.000000 | 52.000000 | 1017.60000 | 1015.200000 | 5.000000 | 5.000000 | 16.700000 | 21.10000 |
75% | 16.900000 | 28.200000 | 0.800000 | 7.400000 | 10.600000 | 48.000000 | 19.000000 | 24.000000 | 83.000000 | 66.000000 | 1022.40000 | 1020.000000 | 7.000000 | 7.000000 | 21.600000 | 26.40000 |
max | 33.900000 | 48.100000 | 371.000000 | 145.000000 | 14.500000 | 135.000000 | 130.000000 | 87.000000 | 100.000000 | 100.000000 | 1041.00000 | 1039.600000 | 9.000000 | 9.000000 | 40.200000 | 46.70000 |
data.dtypes
Date object
Location object
MinTemp float64
MaxTemp float64
Rainfall float64
Evaporation float64
Sunshine float64
WindGustDir object
WindGustSpeed float64
WindDir9am object
WindDir3pm object
WindSpeed9am float64
WindSpeed3pm float64
Humidity9am float64
Humidity3pm float64
Pressure9am float64
Pressure3pm float64
Cloud9am float64
Cloud3pm float64
Temp9am float64
Temp3pm float64
RainToday object
RainTomorrow object
dtype: object
3. 数据探索性分析
#将数据转换为日期时间格式
data['Date'] = pd.to_datetime(data['Date'])
data['year'] = data['Date'].dt.year
data['Month'] = data['Date'].dt.month
data['day'] = data['Date'].dt.day
data.head()
Date | Location | MinTemp | MaxTemp | Rainfall | Evaporation | Sunshine | WindGustDir | WindGustSpeed | WindDir9am | ... | Pressure3pm | Cloud9am | Cloud3pm | Temp9am | Temp3pm | RainToday | RainTomorrow | year | Month | day | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 2008-12-01 | Albury | 13.4 | 22.9 | 0.6 | NaN | NaN | W | 44.0 | W | ... | 1007.1 | 8.0 | NaN | 16.9 | 21.8 | No | No | 2008 | 12 | 1 |
1 | 2008-12-02 | Albury | 7.4 | 25.1 | 0.0 | NaN | NaN | WNW | 44.0 | NNW | ... | 1007.8 | NaN | NaN | 17.2 | 24.3 | No | No | 2008 | 12 | 2 |
2 | 2008-12-03 | Albury | 12.9 | 25.7 | 0.0 | NaN | NaN | WSW | 46.0 | W | ... | 1008.7 | NaN | 2.0 | 21.0 | 23.2 | No | No | 2008 | 12 | 3 |
3 | 2008-12-04 | Albury | 9.2 | 28.0 | 0.0 | NaN | NaN | NE | 24.0 | SE | ... | 1012.8 | NaN | NaN | 18.1 | 26.5 | No | No | 2008 | 12 | 4 |
4 | 2008-12-05 | Albury | 17.5 | 32.3 | 1.0 | NaN | NaN | W | 41.0 | ENE | ... | 1006.0 | 7.0 | 8.0 | 17.8 | 29.7 | No | No | 2008 | 12 | 5 |
5 rows × 26 columns
data.drop('Date',axis=1,inplace=True)
data.columns
Index(['Location', 'MinTemp', 'MaxTemp', 'Rainfall', 'Evaporation', 'Sunshine',
'WindGustDir', 'WindGustSpeed', 'WindDir9am', 'WindDir3pm',
'WindSpeed9am', 'WindSpeed3pm', 'Humidity9am', 'Humidity3pm',
'Pressure9am', 'Pressure3pm', 'Cloud9am', 'Cloud3pm', 'Temp9am',
'Temp3pm', 'RainToday', 'RainTomorrow', 'year', 'Month', 'day'],
dtype='object')
a. 数据相关性探索
plt.figure(figsize=(15,13))
# data.corr()表示了data中的两个变量之间的相关性
ax = sns.heatmap(data.corr(), square=True, annot=True, fmt='.2f')
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
plt.show()
b. 是否会下雨
# 设置样式和调色板
sns.set(style="whitegrid", palette="Set2")
# 创建一个 1 行 2 列的图像布局
fig, axes = plt.subplots(1, 2, figsize=(10, 4)) # 图形尺寸调大 (10, 4)
# 图表标题样式
title_font = {'fontsize': 14, 'fontweight': 'bold', 'color': 'darkblue'}
# 第一张图:RainTomorrow
sns.countplot(x='RainTomorrow', data=data, ax=axes[0], edgecolor='black') # 添加边框
axes[0].set_title('Rain Tomorrow', fontdict=title_font) # 设置标题
axes[0].set_xlabel('Will it Rain Tomorrow?', fontsize=12) # X轴标签
axes[0].set_ylabel('Count', fontsize=12) # Y轴标签
axes[0].tick_params(axis='x', labelsize=11) # X轴刻度字体大小
axes[0].tick_params(axis='y', labelsize=11) # Y轴刻度字体大小
# 第二张图:RainToday
sns.countplot(x='RainToday', data=data, ax=axes[1], edgecolor='black') # 添加边框
axes[1].set_title('Rain Today', fontdict=title_font) # 设置标题
axes[1].set_xlabel('Did it Rain Today?', fontsize=12) # X轴标签
axes[1].set_ylabel('Count', fontsize=12) # Y轴标签
axes[1].tick_params(axis='x', labelsize=11) # X轴刻度字体大小
axes[1].tick_params(axis='y', labelsize=11) # Y轴刻度字体大小
sns.despine() # 去除图表顶部和右侧的边框
plt.tight_layout() # 调整布局,避免图形之间的重叠
plt.show()
x=pd.crosstab(data['RainTomorrow'],data['RainToday'])
x
RainToday | No | Yes |
---|---|---|
RainTomorrow | ||
No | 92728 | 16858 |
Yes | 16604 | 14597 |
y=x/x.transpose().sum().values.reshape(2,1)*100
y
RainToday | No | Yes |
---|---|---|
RainTomorrow | ||
No | 84.616648 | 15.383352 |
Yes | 53.216243 | 46.783757 |
-
如果今天不下雨,那么明天下雨的机会 = 53.22%
-
如果今天下雨明天下雨的机会 = 46.78%
y.plot(kind="bar",figsize=(4,3),color=['#006666','#d279a6']);
c. 地理位置与下雨的关系
x=pd.crosstab(data['Location'],data['RainToday'])
# 获取每个城市下雨天数和非下雨天数的百分比
y=x/x.transpose().sum().values.reshape((-1, 1))*100
# 按每个城市的雨天百分比排序
y=y.sort_values(by='Yes',ascending=True )
color=['#cc6699','#006699','#006666','#862d86','#ff9966' ]
y.Yes.plot(kind="barh",figsize=(15,20),color=color)
<Axes: ylabel='Location'>
位置影响下雨,对于 Portland 来说,有 36% 的时间在下雨,而对于 Woomers 来说,只有6%的时间在下雨
d. 湿度和压力对下雨的影响
data.columns
Index(['Location', 'MinTemp', 'MaxTemp', 'Rainfall', 'Evaporation', 'Sunshine',
'WindGustDir', 'WindGustSpeed', 'WindDir9am', 'WindDir3pm',
'WindSpeed9am', 'WindSpeed3pm', 'Humidity9am', 'Humidity3pm',
'Pressure9am', 'Pressure3pm', 'Cloud9am', 'Cloud3pm', 'Temp9am',
'Temp3pm', 'RainToday', 'RainTomorrow', 'year', 'Month', 'day'],
dtype='object')
plt.figure(figsize=(8,6))
sns.scatterplot(data=data,x='Pressure9am',
y='Pressure3pm',hue='RainTomorrow');
plt.figure(figsize=(8,6))
sns.scatterplot(data=data,x='Humidity9am',
y='Humidity3pm',hue='RainTomorrow');
低压与高湿度会增加第二天下雨的概率,尤其是下午 3 点的空气湿度。
e. 气温对下雨的影响
plt.figure(figsize=(8,6))
sns.scatterplot(x='MaxTemp', y='MinTemp',
data=data, hue='RainTomorrow');
4. 数据预处理
处理缺损值
# 每列中缺失数据的百分比
data.isnull().sum()/data.shape[0]*100
Location 0.000000
MinTemp 1.020899
MaxTemp 0.866905
Rainfall 2.241853
Evaporation 43.166506
Sunshine 48.009762
WindGustDir 7.098859
WindGustSpeed 7.055548
WindDir9am 7.263853
WindDir3pm 2.906641
WindSpeed9am 1.214767
WindSpeed3pm 2.105046
Humidity9am 1.824557
Humidity3pm 3.098446
Pressure9am 10.356799
Pressure3pm 10.331363
Cloud9am 38.421559
Cloud3pm 40.807095
Temp9am 1.214767
Temp3pm 2.481094
RainToday 2.241853
RainTomorrow 2.245978
year 0.000000
Month 0.000000
day 0.000000
dtype: float64
# 在该列中随机选择数进行填充
lst=['Evaporation','Sunshine','Cloud9am','Cloud3pm']
for col in lst:
fill_list = data[col].dropna()
data[col] = data[col].fillna(pd.Series(np.random.choice(fill_list, size=len(data.index))))
s = (data.dtypes == "object")
object_cols = list(s[s].index)
object_cols
['Location',
'WindGustDir',
'WindDir9am',
'WindDir3pm',
'RainToday',
'RainTomorrow']
# inplace=True:直接修改原对象,不创建副本
# data[i].mode()[0] 返回频率出现最高的选项,众数
for i in object_cols:
data[i].fillna(data[i].mode()[0], inplace=True)
t = (data.dtypes == "float64")
num_cols = list(t[t].index)
num_cols
['MinTemp',
'MaxTemp',
'Rainfall',
'Evaporation',
'Sunshine',
'WindGustSpeed',
'WindSpeed9am',
'WindSpeed3pm',
'Humidity9am',
'Humidity3pm',
'Pressure9am',
'Pressure3pm',
'Cloud9am',
'Cloud3pm',
'Temp9am',
'Temp3pm']
# .median(), 中位数
for i in num_cols:
data[i].fillna(data[i].median(), inplace=True)
data.isnull().sum()
Location 0
MinTemp 0
MaxTemp 0
Rainfall 0
Evaporation 0
Sunshine 0
WindGustDir 0
WindGustSpeed 0
WindDir9am 0
WindDir3pm 0
WindSpeed9am 0
WindSpeed3pm 0
Humidity9am 0
Humidity3pm 0
Pressure9am 0
Pressure3pm 0
Cloud9am 0
Cloud3pm 0
Temp9am 0
Temp3pm 0
RainToday 0
RainTomorrow 0
year 0
Month 0
day 0
dtype: int64
5. 构建数据集
from sklearn.preprocessing import LabelEncoder
label_encoder = LabelEncoder()
for i in object_cols:
data[i] = label_encoder.fit_transform(data[i])
X = data.drop(['RainTomorrow','day'],axis=1).values
y = data['RainTomorrow'].values
X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=0.25,random_state=101)
scaler = MinMaxScaler()
scaler.fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)
# 创建pytorch dataset 和 dataloader
"""
在 PyTorch 中,我们通常先将 NumPy 数组转换为 torch.Tensor,
再封装到 TensorDataset 或自定义的 Dataset 里,然后用 DataLoader 按批次加载。
"""
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
X_train = X_train.reshape(X_train.shape[0],X_train.shape[1],1)
X_test = X_test.reshape(X_test.shape[0],X_test.shape[1],1)
# 如果要做二分类 + Sigmoid + nn.BCELoss,那么标签可以用 float32
# 如果要做多分类(例如 softmax + CrossEntropy),则需把标签转为 long
y_train = y_train.astype(np.float32) # 二分类: float32
y_test = y_test.astype(np.float32) # 二分类: float32
# 转换为张量
X_train_tensor = torch.from_numpy(X_train).float() # shape:[samples, 13, 1]
y_train_tensor = torch.from_numpy(y_train) # shape:[samples]
X_test_tensor = torch.from_numpy(X_test).float()
y_test_tensor = torch.from_numpy(y_test)
# 如果后续需要在训练中对标签执行 pred>0.5 判定,可以保持 y 的 shape=[samples] 即可
# 也可 reshape([-1,1]) 保持和网络输出尺寸一致,不过这并非必须。
# y_train_tensor = y_train_tensor.view(-1,1)
# y_test_tensor = y_test_tensor.view(-1,1)
# 用 TensorDataset 直接封装
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)
# 创建 DataLoader
batch_size = 32
train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
6. 定义模型
### 构建RNN模型
# -----------------------------
# 1. 定义模型结构
# -----------------------------
class SimpleRNNModel(nn.Module):
def __init__(self):
super(SimpleRNNModel, self).__init__()
# TensorFlow 中 input_shape=(13,1),即序列长度 seq_len = 13,特征维度 input_dim = 1
# PyTorch RNN 层若设置 batch_first=True:
# 输入张量形状: (batch_size, seq_len, input_dim)
# 输出张量形状: (batch_size, seq_len, hidden_size)
self.rnn = nn.RNN(
input_size=1, # 对应 TF 的 input_dim=1
hidden_size=200, # 对应 TF 的 RNN(200)
batch_first=True,
nonlinearity='relu' # 对应 TF 的 activation='relu'
)
self.fc1 = nn.Linear(200, 100) # 对应 Dense(100, activation='relu')
self.fc2 = nn.Linear(100, 1) # 对应 Dense(1, activation='sigmoid')
self.sigmoid = nn.Sigmoid()
def forward(self, x):
# x: [batch_size, 13, 1]
# RNN 输出: output, hidden
# output shape = [batch_size, seq_len, hidden_size]
# hidden shape = [num_layers, batch_size, hidden_size]
out, hidden = self.rnn(x)
# 取最后一个 time_step 的输出, 与 TensorFlow 里 SimpleRNN 的默认行为一致
out = out[:, -1, :] # shape: [batch_size, hidden_size]
# 与 Dense(100, relu)
out = F.relu(self.fc1(out)) # [batch_size, 100]
# 与 Dense(1, sigmoid)
out = self.sigmoid(self.fc2(out)) # [batch_size, 1]
return out
7. 初始化模型与优化器
# -----------------------------
# 2. 初始化模型与优化器
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleRNNModel().to(device)
print(model)
# 与 TF 中 loss='binary_crossentropy' 对应,PyTorch 用 BCE:nn.BCELoss
loss_fn = nn.BCELoss()
# 多分类问题使用nn.CrossEntropyLoss()
# criterion = nn.CrossEntropyLoss()
learn_rate = 1e-4
# learn_rate = 3e-4
lambda1 = lambda epoch:(0.92**(epoch//2))
optimizer = torch.optim.Adam(model.parameters(),lr = learn_rate)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda=lambda1) # 选定调整方法
SimpleRNNModel(
(rnn): RNN(1, 200, batch_first=True)
(fc1): Linear(in_features=200, out_features=100, bias=True)
(fc2): Linear(in_features=100, out_features=1, bias=True)
(sigmoid): Sigmoid()
)
8. 训练函数
def train(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset) # 训练集大小
num_batches = len(dataloader) # 批次数目
train_loss, train_acc = 0, 0
for X, y in dataloader:
X, y = X.to(device), y.to(device)
# 计算预测
pred = model(X).view(-1) # [batch_size]
loss = loss_fn(pred, y)
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 记录acc与loss
# 情况1: 如果是多分类(N>1), pred.shape=[batch_size, N],可以用argmax(1).
# train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
# 情况2: 如果是二分类且只有1个输出(使用 Sigmoid),则 pred.shape=[batch_size,1],
# 那么可用 (pred>0.5) 转为0/1来比较:
pred_label = (pred > 0.5).long() # [batch_size]
train_acc += (pred_label == y.long()).sum().item()
train_loss += loss.item()
train_acc /= size
train_loss /= num_batches
return train_acc, train_loss
9. 测试函数
def test(dataloader, model, loss_fn):
size = len(dataloader.dataset)
num_batches = len(dataloader)
test_acc, test_loss = 0, 0
with torch.no_grad():
for X, y in dataloader:
X, y = X.to(device), y.to(device)
# 计算预测
pred = model(X).view(-1) # [batch_size]
loss = loss_fn(pred, y)
# 情况1: 多分类(N>1):
# test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
# 情况2: 二分类单输出:
pred_label = (pred > 0.5).long() # [batch_size]
# test_acc += (pred_label.view(-1) == y).type(torch.float).sum().item()
test_acc += (pred_label == y.long()).sum().item()
test_loss += loss.item()
test_acc /= size
test_loss /= num_batches
return test_acc, test_loss
10. 执行训练
# -----------------------------
# 打印可用 GPU 信息
# -----------------------------
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
print(f"Initial Memory Allocated: {torch.cuda.memory_allocated(i)/1024**2:.2f} MB")
print(f"Initial Memory Reserved: {torch.cuda.memory_reserved(i)/1024**2:.2f} MB")
else:
print("No GPU available. Using CPU.")
# -----------------------------
# 训练主循环
# -----------------------------
epochs = 60
train_acc_list = []
train_loss_list = []
test_acc_list = []
test_loss_list = []
best_acc = 0.0
best_model = None
for epoch in range(epochs):
# 更新学习率——使用自定义学习率时使用
# adjust_learning_rate(optimizer,epoch,learn_rate)
# 切换为训练模式
model.train()
epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)
# 更新学习率
scheduler.step() # 更新学习率——调用官方动态学习率时使用
# 切换为评估模式
model.eval()
epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
# 保存最佳模型
if epoch_test_acc > best_acc:
best_acc = epoch_test_acc
best_model = copy.deepcopy(model)
train_acc_list.append(epoch_train_acc)
train_loss_list.append(epoch_train_loss)
test_acc_list.append(epoch_test_acc)
test_loss_list.append(epoch_test_loss)
# 当前学习率
lr = optimizer.state_dict()['param_groups'][0]['lr']
template = (
'Epoch:{:2d}, '
'Train_acc:{:.1f}%, Train_loss:{:.3f}, '
'Test_acc:{:.1f}%, Test_loss:{:.3f}, '
'Lr:{:.2E}'
)
print(template.format(
epoch+1,
epoch_train_acc*100, epoch_train_loss,
epoch_test_acc*100, epoch_test_loss,
lr
))
# 实时监控 GPU 状态
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
print(f"GPU {i} Usage:")
print(f" Memory Allocated: {torch.cuda.memory_allocated(i)/1024**2:.2f} MB")
print(f" Memory Reserved: {torch.cuda.memory_reserved(i)/1024**2:.2f} MB")
print(f" Max Memory Allocated: {torch.cuda.max_memory_allocated(i)/1024**2:.2f} MB")
print(f" Max Memory Reserved: {torch.cuda.max_memory_reserved(i)/1024**2:.2f} MB")
print('Done. Best test acc: ', best_acc)
No GPU available. Using CPU.
Epoch: 1, Train_acc:80.1%, Train_loss:0.460, Test_acc:82.7%, Test_loss:0.397, Lr:1.00E-04
Epoch: 2, Train_acc:83.4%, Train_loss:0.387, Test_acc:83.8%, Test_loss:0.374, Lr:9.20E-05
Epoch: 3, Train_acc:83.9%, Train_loss:0.375, Test_acc:84.1%, Test_loss:0.367, Lr:9.20E-05
Epoch: 4, Train_acc:83.9%, Train_loss:0.370, Test_acc:84.2%, Test_loss:0.365, Lr:8.46E-05
Epoch: 5, Train_acc:84.1%, Train_loss:0.368, Test_acc:83.9%, Test_loss:0.375, Lr:8.46E-05
Epoch: 6, Train_acc:84.3%, Train_loss:0.366, Test_acc:84.3%, Test_loss:0.364, Lr:7.79E-05
Epoch: 7, Train_acc:84.3%, Train_loss:0.365, Test_acc:84.3%, Test_loss:0.363, Lr:7.79E-05
Epoch: 8, Train_acc:84.3%, Train_loss:0.364, Test_acc:84.3%, Test_loss:0.362, Lr:7.16E-05
Epoch: 9, Train_acc:84.3%, Train_loss:0.364, Test_acc:84.4%, Test_loss:0.362, Lr:7.16E-05
Epoch:10, Train_acc:84.4%, Train_loss:0.362, Test_acc:84.3%, Test_loss:0.363, Lr:6.59E-05
Epoch:11, Train_acc:84.3%, Train_loss:0.361, Test_acc:84.4%, Test_loss:0.363, Lr:6.59E-05
Epoch:12, Train_acc:84.4%, Train_loss:0.361, Test_acc:84.4%, Test_loss:0.359, Lr:6.06E-05
Epoch:13, Train_acc:84.5%, Train_loss:0.360, Test_acc:84.4%, Test_loss:0.362, Lr:6.06E-05
Epoch:14, Train_acc:84.4%, Train_loss:0.360, Test_acc:84.5%, Test_loss:0.359, Lr:5.58E-05
Epoch:15, Train_acc:84.5%, Train_loss:0.358, Test_acc:84.4%, Test_loss:0.358, Lr:5.58E-05
Epoch:16, Train_acc:84.5%, Train_loss:0.358, Test_acc:84.5%, Test_loss:0.361, Lr:5.13E-05
Epoch:17, Train_acc:84.6%, Train_loss:0.357, Test_acc:84.5%, Test_loss:0.358, Lr:5.13E-05
Epoch:18, Train_acc:84.6%, Train_loss:0.357, Test_acc:84.6%, Test_loss:0.357, Lr:4.72E-05
Epoch:19, Train_acc:84.6%, Train_loss:0.356, Test_acc:84.6%, Test_loss:0.357, Lr:4.72E-05
Epoch:20, Train_acc:84.7%, Train_loss:0.356, Test_acc:84.6%, Test_loss:0.356, Lr:4.34E-05
Epoch:21, Train_acc:84.6%, Train_loss:0.355, Test_acc:84.6%, Test_loss:0.356, Lr:4.34E-05
Epoch:22, Train_acc:84.6%, Train_loss:0.355, Test_acc:84.6%, Test_loss:0.356, Lr:4.00E-05
Epoch:23, Train_acc:84.7%, Train_loss:0.354, Test_acc:84.6%, Test_loss:0.356, Lr:4.00E-05
Epoch:24, Train_acc:84.7%, Train_loss:0.354, Test_acc:84.5%, Test_loss:0.358, Lr:3.68E-05
Epoch:25, Train_acc:84.7%, Train_loss:0.353, Test_acc:84.6%, Test_loss:0.357, Lr:3.68E-05
Epoch:26, Train_acc:84.8%, Train_loss:0.353, Test_acc:84.7%, Test_loss:0.354, Lr:3.38E-05
Epoch:27, Train_acc:84.7%, Train_loss:0.352, Test_acc:84.7%, Test_loss:0.353, Lr:3.38E-05
Epoch:28, Train_acc:84.8%, Train_loss:0.352, Test_acc:84.7%, Test_loss:0.354, Lr:3.11E-05
Epoch:29, Train_acc:84.8%, Train_loss:0.352, Test_acc:84.8%, Test_loss:0.354, Lr:3.11E-05
Epoch:30, Train_acc:84.9%, Train_loss:0.352, Test_acc:84.8%, Test_loss:0.353, Lr:2.86E-05
Epoch:31, Train_acc:84.9%, Train_loss:0.351, Test_acc:84.7%, Test_loss:0.356, Lr:2.86E-05
Epoch:32, Train_acc:84.9%, Train_loss:0.351, Test_acc:84.6%, Test_loss:0.354, Lr:2.63E-05
Epoch:33, Train_acc:84.8%, Train_loss:0.350, Test_acc:84.8%, Test_loss:0.352, Lr:2.63E-05
Epoch:34, Train_acc:84.9%, Train_loss:0.350, Test_acc:84.7%, Test_loss:0.354, Lr:2.42E-05
Epoch:35, Train_acc:84.9%, Train_loss:0.350, Test_acc:84.8%, Test_loss:0.352, Lr:2.42E-05
Epoch:36, Train_acc:84.9%, Train_loss:0.350, Test_acc:84.6%, Test_loss:0.354, Lr:2.23E-05
Epoch:37, Train_acc:84.9%, Train_loss:0.349, Test_acc:84.8%, Test_loss:0.353, Lr:2.23E-05
Epoch:38, Train_acc:84.9%, Train_loss:0.349, Test_acc:84.6%, Test_loss:0.356, Lr:2.05E-05
Epoch:39, Train_acc:85.0%, Train_loss:0.348, Test_acc:85.0%, Test_loss:0.351, Lr:2.05E-05
Epoch:40, Train_acc:85.0%, Train_loss:0.349, Test_acc:84.8%, Test_loss:0.351, Lr:1.89E-05
Epoch:41, Train_acc:85.0%, Train_loss:0.348, Test_acc:84.8%, Test_loss:0.351, Lr:1.89E-05
Epoch:42, Train_acc:85.0%, Train_loss:0.348, Test_acc:84.9%, Test_loss:0.351, Lr:1.74E-05
Epoch:43, Train_acc:85.0%, Train_loss:0.347, Test_acc:85.0%, Test_loss:0.350, Lr:1.74E-05
Epoch:44, Train_acc:85.0%, Train_loss:0.347, Test_acc:85.0%, Test_loss:0.351, Lr:1.60E-05
Epoch:45, Train_acc:85.1%, Train_loss:0.347, Test_acc:84.9%, Test_loss:0.350, Lr:1.60E-05
Epoch:46, Train_acc:85.1%, Train_loss:0.347, Test_acc:84.9%, Test_loss:0.350, Lr:1.47E-05
Epoch:47, Train_acc:85.0%, Train_loss:0.347, Test_acc:84.9%, Test_loss:0.351, Lr:1.47E-05
Epoch:48, Train_acc:85.0%, Train_loss:0.346, Test_acc:84.9%, Test_loss:0.350, Lr:1.35E-05
Epoch:49, Train_acc:85.1%, Train_loss:0.346, Test_acc:85.0%, Test_loss:0.349, Lr:1.35E-05
Epoch:50, Train_acc:85.1%, Train_loss:0.346, Test_acc:84.9%, Test_loss:0.350, Lr:1.24E-05
Epoch:51, Train_acc:85.1%, Train_loss:0.346, Test_acc:84.9%, Test_loss:0.349, Lr:1.24E-05
Epoch:52, Train_acc:85.1%, Train_loss:0.346, Test_acc:84.9%, Test_loss:0.350, Lr:1.14E-05
Epoch:53, Train_acc:85.1%, Train_loss:0.345, Test_acc:84.9%, Test_loss:0.349, Lr:1.14E-05
Epoch:54, Train_acc:85.1%, Train_loss:0.345, Test_acc:85.0%, Test_loss:0.349, Lr:1.05E-05
Epoch:55, Train_acc:85.1%, Train_loss:0.345, Test_acc:85.0%, Test_loss:0.349, Lr:1.05E-05
Epoch:56, Train_acc:85.1%, Train_loss:0.345, Test_acc:84.8%, Test_loss:0.350, Lr:9.68E-06
Epoch:57, Train_acc:85.1%, Train_loss:0.345, Test_acc:85.0%, Test_loss:0.348, Lr:9.68E-06
Epoch:58, Train_acc:85.1%, Train_loss:0.344, Test_acc:84.8%, Test_loss:0.350, Lr:8.91E-06
Epoch:59, Train_acc:85.2%, Train_loss:0.344, Test_acc:85.0%, Test_loss:0.348, Lr:8.91E-06
Epoch:60, Train_acc:85.2%, Train_loss:0.344, Test_acc:84.9%, Test_loss:0.349, Lr:8.20E-06
Done. Best test acc: 0.8500206242265915
11. 过程可视化
epochs_range = range(epochs)
plt.figure(figsize=(12, 5))
# 准确率曲线
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc_list, label='Training Accuracy')
plt.plot(epochs_range, test_acc_list, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
# 损失曲线
plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss_list, label='Training Loss')
plt.plot(epochs_range, test_loss_list, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()