基于 BiLSTM+Attention 实现降雨预测多变量时序分类——明日是否降雨

news2024/12/24 0:10:13

降雨预测
前言

系列专栏:【深度学习:算法项目实战】✨︎
涉及医疗健康、财经金融、商业零售、食品饮料、运动健身、交通运输、环境科学、社交媒体以及文本和图像处理等诸多领域,讨论了各种复杂的深度神经网络思想,如卷积神经网络、循环神经网络、生成对抗网络、门控循环单元、长短期记忆、自然语言处理、深度强化学习、大型语言模型和迁移学习。

降雨预测作为气象学和水文学领域的重要研究课题,‌对于农业、‌城市规划、‌水资源管理及灾害预警等多个方面都具有极其重要的实际应用价值。‌传统的降雨预测方法主要依赖于数值天气预报模型和统计学方法,‌但这些方法在处理复杂非线性关系和时序依赖性时存在一定的局限性。‌近年来,‌随着深度学习技术的快速发展,‌尤其是循环神经网络(‌RNN)‌及其变体长短时记忆网络(‌LSTM)‌和双向长短时记忆网络(‌BiLSTM)‌在时序数据处理方面展现出了强大的能力。‌

BiLSTM网络通过引入双向机制,‌能够同时考虑过去和未来的信息,‌对于时序数据的特征提取尤为有效。‌而注意力机制(‌Attention)‌的加入,‌则进一步增强了模型对关键信息的聚焦能力,‌提高了预测的准确性。‌本研究旨在探索将BiLSTM与Attention机制相结合,‌应用于多变量时序分类任务中,‌具体针对“明日是否降雨”这一实际问题进行建模和预测。‌

本研究的主要贡献在于:‌首先,‌构建了一个基于BiLSTM+Attention的深度学习模型,‌该模型能够有效处理多变量时序数据,‌捕捉到降雨预测中的关键时序特征和变量间的复杂交互作用;‌其次,‌通过实际的气象数据集进行训练和验证,‌评估模型在降雨预测任务上的性能和泛化能力;‌最后,‌探讨不同模型参数和结构设计对预测结果的影响,‌为进一步优化降雨预测模型提供理论和实践依据。‌

综上所述,‌本研究期望通过深度学习方法的应用,‌为降雨预测领域带来新的视角和解决方案,‌提升预测的准确性和时效性,‌从而更好地服务于社会经济的可持续发展和人民生活的安全保障。‌

目录

  • 1. 数据集介绍
  • 2. 数据可视化
    • 2.1 检查数据是否缺失
    • 2.2 检查数据是否平衡
  • 3. 数据预处理
    • 3.1 数据清理——填补缺失值
      • 3.1.1 分类变量
      • 3.1.2 数值变量
    • 3.2 异常检测
      • 3.2.1 数值变量异常检测
      • 3.2.2 异常值离群值处理
  • 4. 特征工程
    • 4.1 Label编码和One-hot编码
    • 4.2 特征缩放(归一化)
    • 4.3 构建时间序列数据
    • 4.4 数据集过采样 SMOTE
    • 4.5 数据集划分
    • 4.6 数据集张量
  • 5. 构建时序模型(TSC)
    • 5.1 构建BiLSTM+Attention模型
    • 5.2 定义模型、损失函数与优化器
    • 5.3 模型概要
  • 6. 模型训练与可视化
    • 6.1 定义训练与评估函数
    • 6.2 绘制损失与准确率曲线
  • 7. 模型评估与可视化
    • 7.1 构建预测函数
    • 7.2 混淆矩阵
    • 7.3 ROC_AUC曲线
    • 7.4 分类报告

1. 数据集介绍

该数据集包括澳大利亚许多地点约 10 年的每日天气观测数据。RainTomorrow 是要预测的目标变量。它回答了一个关键问题:第二天会下雨吗? 是或否)。如果当天的降雨量达到或超过 1 毫米,则此列标记为 “是”。下载🔗

首先让我们导入必要的库和数据集

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from imblearn.over_sampling import SMOTE
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import StandardScaler

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, roc_curve, auc

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, Dataset
from torchinfo import summary

np.random.seed(0)
data = pd.read_csv("weatherAUS.csv")
print(data.head().T)
                        0           1           2           3           4
Date           2008-12-01  2008-12-02  2008-12-03  2008-12-04  2008-12-05
Location           Albury      Albury      Albury      Albury      Albury
MinTemp              13.4         7.4        12.9         9.2        17.5
MaxTemp              22.9        25.1        25.7        28.0        32.3
Rainfall              0.6         0.0         0.0         0.0         1.0
Evaporation           NaN         NaN         NaN         NaN         NaN
Sunshine              NaN         NaN         NaN         NaN         NaN
WindGustDir             W         WNW         WSW          NE           W
WindGustSpeed        44.0        44.0        46.0        24.0        41.0
WindDir9am              W         NNW           W          SE         ENE
WindDir3pm            WNW         WSW         WSW           E          NW
WindSpeed9am         20.0         4.0        19.0        11.0         7.0
WindSpeed3pm         24.0        22.0        26.0         9.0        20.0
Humidity9am          71.0        44.0        38.0        45.0        82.0
Humidity3pm          22.0        25.0        30.0        16.0        33.0
Pressure9am        1007.7      1010.6      1007.6      1017.6      1010.8
Pressure3pm        1007.1      1007.8      1008.7      1012.8      1006.0
Cloud9am              8.0         NaN         NaN         NaN         7.0
Cloud3pm              NaN         NaN         2.0         NaN         8.0
Temp9am              16.9        17.2        21.0        18.1        17.8
Temp3pm              21.8        24.3        23.2        26.5        29.7
RainToday              No          No          No          No          No
RainTomorrow           No          No          No          No          No

该数据集包含澳大利亚各地约 10 年的每日天气观测数据。观测数据来自众多气象站。在本项目中,我将利用这些数据预测第二天是否会下雨。包括目标变量 “RainTomorrow ”在内的 23 个属性表明第二天是否会下雨。

2. 数据可视化

2.1 检查数据是否缺失

.info()方法打印有关DataFrame 的信息,包括索引 dtype 和列、非 null 值以及内存使用情况。

data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 145460 entries, 0 to 145459
Data columns (total 23 columns):
 #   Column         Non-Null Count   Dtype  
---  ------         --------------   -----  
 0   Date           145460 non-null  object 
 1   Location       145460 non-null  object 
 2   MinTemp        143975 non-null  float64
 3   MaxTemp        144199 non-null  float64
 4   Rainfall       142199 non-null  float64
 5   Evaporation    82670 non-null   float64
 6   Sunshine       75625 non-null   float64
 7   WindGustDir    135134 non-null  object 
 8   WindGustSpeed  135197 non-null  float64
 9   WindDir9am     134894 non-null  object 
 10  WindDir3pm     141232 non-null  object 
 11  WindSpeed9am   143693 non-null  float64
 12  WindSpeed3pm   142398 non-null  float64
 13  Humidity9am    142806 non-null  float64
 14  Humidity3pm    140953 non-null  float64
 15  Pressure9am    130395 non-null  float64
 16  Pressure3pm    130432 non-null  float64
 17  Cloud9am       89572 non-null   float64
 18  Cloud3pm       86102 non-null   float64
 19  Temp9am        143693 non-null  float64
 20  Temp3pm        141851 non-null  float64
 21  RainToday      142199 non-null  object 
 22  RainTomorrow   142193 non-null  object 
dtypes: float64(16), object(7)
memory usage: 25.5+ MB
sns.heatmap(data.isnull(), cbar=False, cmap='PuBu')

请添加图片描述

我们可以很明显的观察到数据集中有缺失值,数据集中包括数值和分类值

2.2 检查数据是否平衡

接下来,我们将检查数据集是不平衡还是平衡的。如果数据集是不平衡的,我们就需要对大多数数据进行降采样或对少数数据进行超采样,以达到平衡。

proportion = data.RainTomorrow.value_counts(normalize = True)

plt.style.use('bmh')
plt.figure(figsize=(8,5))
plt.bar(proportion.index, proportion.values, color=['teal', 'grey'])
plt.xlabel('RainTomorrow')
plt.ylabel('Proportion')
plt.title('Proportion of RainTomorrow')
plt.xticks(rotation=45, ha="right")
plt.show()

请添加图片描述

3. 数据预处理

现在,我将把日期解析为时间数据类型

print(type(data['RainTomorrow'].iloc[0]),type(data['Date'].iloc[0]))

# Let's convert the data type of timestamp column to datatime format
data['Date'] = pd.to_datetime(data['Date'])
print(type(data['RainTomorrow'].iloc[0]),type(data['Date'].iloc[0]))

print(data.shape)
<class 'str'> <class 'str'>
<class 'str'> <class 'pandas._libs.tslibs.timestamps.Timestamp'>
(145460, 23)

3.1 数据清理——填补缺失值

3.1.1 分类变量

用列值的众数 mode 填补缺失值

# Selecting columns of categorical variables
object_columns = data.select_dtypes(include=['object']).columns.tolist()

# Missing values in categorical variables
data[object_columns].isnull().sum()
Location            0
WindGustDir     10326
WindDir9am      10566
WindDir3pm       4228
RainToday        3261
RainTomorrow     3267
dtype: int64
# Filling missing values with mode of the column in value
for col in object_columns:
    data.fillna({col: data[col].mode()[0]}, inplace=True)

# Counting missing values
data[object_columns].isnull().sum()
Location        0
WindGustDir     0
WindDir9am      0
WindDir3pm      0
RainToday       0
RainTomorrow    0
dtype: int64

3.1.2 数值变量

用列值的中位数 median 填补缺失值

# Selecting columns of neumeric variables
neumeric_columns = data.select_dtypes(include=['float64']).columns.tolist()

# Missing values in numeric variables
data[neumeric_columns].isnull().sum()
MinTemp           1485
MaxTemp           1261
Rainfall          3261
Evaporation      62790
Sunshine         69835
WindGustSpeed    10263
WindSpeed9am      1767
WindSpeed3pm      3062
Humidity9am       2654
Humidity3pm       4507
Pressure9am      15065
Pressure3pm      15028
Cloud9am         55888
Cloud3pm         59358
Temp9am           1767
Temp3pm           3609
dtype: int64
# Filling missing values with median of the column in value
for col in neumeric_columns:
    data.fillna({col: data[col].median()}, inplace=True)

# Counting missing values
data[neumeric_columns].isnull().sum()
MinTemp          0
MaxTemp          0
Rainfall         0
Evaporation      0
Sunshine         0
WindGustSpeed    0
WindSpeed9am     0
WindSpeed3pm     0
Humidity9am      0
Humidity3pm      0
Pressure9am      0
Pressure3pm      0
Cloud9am         0
Cloud3pm         0
Temp9am          0
Temp3pm          0
dtype: int64

3.2 异常检测

3.2.1 数值变量异常检测

在统计学和数据科学中,‌识别和处理异常值(‌outliers)‌是一个重要的步骤,‌因为它们可能会对分析产生重大影响。‌异常值是指那些与其他数据点显著不同的观测值。‌使用四分位数范围(‌IQR)‌来检测异常值是一种常见且有效的方法。‌

k = 1.5
# Initialize the figure with a logarithmic x axis
fig, axes = plt.subplots(nrows=len(neumeric_columns), ncols=1, figsize=(20, 20))
axes = axes.flatten()

for i, feature in enumerate(neumeric_columns):
    # Plot the orbital period with horizontal boxes
    sns.boxenplot(x=data[feature], ax=axes[i], color= 'thistle', linecolor= 'grey', orient='h')
    
    # Detecting outliers with IQR
    Q1 = data[feature].quantile(0.25)
    Q3 = data[feature].quantile(0.75)
    IQR = Q3 - Q1
    lower = Q1 - k * IQR
    upper = Q3 + k * IQR
    axes[i].axvline(x=lower, color='r', linestyle='--', label='lower')
    axes[i].axvline(x=upper, color='b', linestyle='--', label='upper')
    
    # Tweak the visual presentation
    #sns.despine(ax=axes[i],trim=True, left=True)
    axes[i].text(lower, 0.8, f'lower = {lower}',  color='red', fontsize=12)
    axes[i].text(upper, 0.8, f'upper = {upper}',  color='blue', fontsize=12)
    axes[i].set_title(f'{feature}')
    axes[i].set_xlabel('')
    
plt.tight_layout()
plt.show()

在这里插入图片描述

3.2.2 异常值离群值处理

接下来我们将使用 IQR 方法检测并替换异常值

def Handle_outlier(data, column):   
    # 检测并替换异常值
    # Detecting outliers with IQR
    Q1 = data[column].quantile(0.25)
    Q3 = data[column].quantile(0.75)
    IQR = Q3 - Q1
    lower = Q1 - k * IQR
    upper = Q3 + k * IQR
    # 使用 np.where 和 np.logical_or 处理异常值
    data[column] = np.where(
        np.logical_or(data[column] < lower, data[column] > upper),
        np.select([data[column] < lower, data[column] > upper], [lower, upper]), data[column])
    
    # data = data[(data[column] >= lower) & (data[column] <= upper)]
    return data
data = Handle_outlier(data, column = 'MinTemp')
data = Handle_outlier(data, column = 'MaxTemp')
data = Handle_outlier(data, column = 'Rainfall')
data = Handle_outlier(data, column = 'Evaporation')
data = Handle_outlier(data, column = 'Sunshine')
data = Handle_outlier(data, column = 'WindGustSpeed')
data = Handle_outlier(data, column = 'WindSpeed9am')
data = Handle_outlier(data, column = 'WindSpeed3pm')
data = Handle_outlier(data, column = 'Humidity9am')
data = Handle_outlier(data, column = 'Humidity3pm')
data = Handle_outlier(data, column = 'Pressure9am')
data = Handle_outlier(data, column = 'Pressure3pm')
data = Handle_outlier(data, column = 'Cloud9am')
data = Handle_outlier(data, column = 'Cloud3pm')
data = Handle_outlier(data, column = 'Temp9am')
data = Handle_outlier(data, column = 'Temp3pm')

4. 特征工程

4.1 Label编码和One-hot编码

对分类变量进行标签编码,离散值特征进行One-hot编码

# Apply label encoder to RainToday, RainTomorrow
le = LabelEncoder()
data['RainToday'] = le.fit_transform(data['RainToday'])
data['RainTomorrow'] = le.fit_transform(data['RainTomorrow'])
    
ohe = OneHotEncoder() # 离散值特征One-hot编码
encoded = ohe.fit_transform(data[['Location',
                                  'WindGustDir',
                                  'WindDir9am',
                                  'WindDir3pm'
                                 ]])
encoded_data = pd.DataFrame(encoded.toarray(),columns = ohe.get_feature_names_out())
data = pd.concat([data,encoded_data],axis=1)
data = data.drop(['Location',
                  'WindGustDir',
                  'WindDir9am', 
                  'WindDir3pm'], axis =1)
print(data.info())
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 145460 entries, 0 to 145459
Columns: 116 entries, Date to WindDir3pm_WSW
dtypes: datetime64[ns](1), float64(113), int32(2)
memory usage: 127.6 MB
None

现在让我们使用 .corr() 函数来看看数据之间的相关性:

correlation = data.corr()
print(correlation["RainTomorrow"].sort_values(ascending=False))
RainTomorrow    1.000000
Humidity3pm     0.433167
Rainfall        0.323354
RainToday       0.305744
Cloud3pm        0.291963
                  ...   
MaxTemp        -0.156313
Temp3pm        -0.187675
Pressure3pm    -0.209378
Pressure9am    -0.228542
Sunshine       -0.288223
Name: RainTomorrow, Length: 116, dtype: float64

4.2 特征缩放(归一化)

StandardScaler()函数将数据的特征值转换为符合正态分布的形式,它将数据缩放到均值为0,‌标准差为1的区间‌。在机器学习中,StandardScaler()函数常用于不同尺度特征数据的标准化,以提高模型的泛化能力。

# dividing the future and the target from the dataset 
features = data.drop(['Date', 'RainTomorrow'], axis=1)
target = data['RainTomorrow'].values.reshape(-1, 1)
# 创建 StandardScaler实例,对特征进行拟合和变换,生成NumPy数组
scaler = StandardScaler()
features_scaled = scaler.fit_transform(features)
print(features_scaled)

4.3 构建时间序列数据

time_steps = 10
X_list = []
y_list = []

for i in range(len(features_scaled) - time_steps):
    X_list.append(features_scaled[i:i+time_steps])
    y_list.append(target[i+time_steps])

X = np.array(X_list) # [samples, time_steps, num_features]
y = np.array(y_list) # [target]

4.4 数据集过采样 SMOTE

SMOTE (synthetic minority oversampling technique) 合成少数群体超采样技术是解决不平衡问题最常用的超采样方法之一。它的目的是通过复制少数类实例来随机增加少数类实例,从而平衡类的分布。SMOTE 在现有的少数类实例之间合成新的少数类实例。它通过线性插值为少数类生成虚拟训练记录。这些合成训练记录是通过为少数群体中的每个实例随机选择一个或多个 k 近邻来生成的。过采样过程结束后,数据将被重建,并可对处理后的数据应用多个分类模型。

samples, time_steps, num_features = X.shape

# 将 X重塑为二维数组,‌因为SMOTE期望二维输入
X_reshaped = X.reshape(samples, time_steps * num_features)

# Oversampling
smote = SMOTE(random_state=42)
X_resampled, y_resampled = smote.fit_resample(X_reshaped, y)

# 这里我们将 X_resampled 重新塑形为原始的三维形状
X_resampled = X_resampled.reshape(-1, time_steps, num_features)

4.5 数据集划分

X_train, X_valid,\
    y_train, y_valid = train_test_split(X_resampled, y_resampled, 
                                        test_size=0.25, 
                                        random_state=12345)
print(X_train.shape, X_valid.shape, y_train.shape, y_valid.shape)

4.6 数据集张量

# 将 NumPy数组转换为 tensor张量
X_train_tensor = torch.from_numpy(X_train).type(torch.Tensor)
X_valid_tensor = torch.from_numpy(X_valid).type(torch.Tensor)
y_train_tensor = torch.from_numpy(y_train).type(torch.Tensor).view(-1, 1)
y_valid_tensor = torch.from_numpy(y_valid).type(torch.Tensor).view(-1, 1)

print(X_train_tensor.shape, X_valid_tensor.shape, y_train_tensor.shape, y_valid_tensor.shape)
torch.Size([170361, 10, 114]) torch.Size([56787, 10, 114]) torch.Size([170361, 1]) torch.Size([56787, 1])

.type(torch.Tensor) 明确将该张量的数据类型指定为 torch.Tensor, 而.type(torch.long) 明确将标签的张量数据类型指定为长整型torch.long。这通常用于表示整数类型的标签

class DataHandler(Dataset):
    def __init__(self, X_train_tensor, y_train_tensor, X_valid_tensor, y_valid_tensor):
        self.X_train_tensor = X_train_tensor
        self.y_train_tensor = y_train_tensor
        self.X_valid_tensor = X_valid_tensor
        self.y_valid_tensor = y_valid_tensor
        
    def __len__(self):
        return len(self.X_train_tensor)

    def __getitem__(self, idx):
        sample = self.X_train_tensor[idx]
        labels = self.y_train_tensor[idx]
        return sample, labels
        
    def train_loader(self):
        train_dataset = TensorDataset(self.X_train_tensor, self.y_train_tensor)
        return DataLoader(train_dataset, batch_size=32, shuffle=True)

    def valid_loader(self):
        valid_dataset = TensorDataset(self.X_valid_tensor, self.y_valid_tensor)
        return DataLoader(valid_dataset, batch_size=32, shuffle=False)

在上述代码中,定义了一个名为 TSCDataset 的类,它继承自 torch.utils.data.Dataset
__init__ 方法用于接收数据和标签。
__len__ 方法返回数据集的长度。
__getitem__ 方法根据给定的索引 idx 返回相应的数据样本和标签。

data_handler = DataHandler(X_train_tensor, y_train_tensor, X_valid_tensor, y_valid_tensor)
train_loader = data_handler.train_loader()
valid_loader = data_handler.valid_loader()

5. 构建时序模型(TSC)

5.1 构建BiLSTM+Attention模型

该组合模型能够综合利用 LSTM 对序列数据的长期依赖处理能力以及注意力机制对不同特征重要性的动态关注能力,适用于处理具有复杂时空特征的序列数据

class BiLSTM_Attention(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        super(BiLSTM_Attention, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, output_dim)
        self.dropout = nn.Dropout(0.5)
        # Attention 机制
        self.attention_weights = nn.Parameter(torch.randn(hidden_dim * 2))

    def attention(self, outputs):
        # 计算注意力权重
        attention_weights = torch.nn.Softmax(dim=1)(torch.matmul(outputs, self.attention_weights))
        # 应用注意力权重
        attended_output = torch.sum(outputs * attention_weights.unsqueeze(-1), dim=1)
        return attended_output

    def forward(self, x):
        # 初始化隐藏状态和单元状态
        h0 = torch.zeros(self.num_layers * 2, x.size(1), self.hidden_dim).to(x.device)
        c0 = torch.zeros(self.num_layers * 2, x.size(1), self.hidden_dim).to(x.device)

        # 前向传播通过 BiLSTM
        out, _ = self.lstm(x, (h0, c0))

        # 应用注意力机制
        attn_out = self.attention(out)
        attn_out = self.dropout(attn_out)
        # 全连接层
        output = self.fc(attn_out)
        return output

5.2 定义模型、损失函数与优化器

model = BiLSTM_Attention(input_dim = 114, hidden_dim = 8, num_layers = 1, output_dim = 1)
criterion = torch.nn.BCEWithLogitsLoss() # 定义二进制交叉熵损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # 定义优化器

5.3 模型概要

summary(model, (32, time_steps, num_features)) # batch_size, seq_len(time_steps), input_size(in_channels)
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
BiLSTM_Attention                         [32, 1]                   16
├─LSTM: 1-1                              [32, 10, 16]              7,936
├─Dropout: 1-2                           [32, 16]                  --
├─Linear: 1-3                            [32, 1]                   17
==========================================================================================
Total params: 7,969
Trainable params: 7,969
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 2.54
==========================================================================================
Input size (MB): 0.15
Forward/backward pass size (MB): 0.04
Params size (MB): 0.03
Estimated Total Size (MB): 0.22
==========================================================================================

6. 模型训练与可视化

6.1 定义训练与评估函数

定义binary_accuracy函数来衡量模型性能

def binary_accuracy(outputs, labels):
    # 通过 sigmoid 函数将输出值映射到 [0, 1] 区间
    outputs = torch.sigmoid(outputs)
    # 将输出值与 0.5 比较,得到预测的类别(0 或 1)
    predicted = (outputs > 0.5).float()
    # 计算预测正确的数量
    correct = (predicted == labels).float().sum()
    # 计算总样本数量
    total = labels.size(0)
    # 计算准确率
    accuracy = correct / total
    return accuracy

上述代码,定义了一个名为 binary_accuracy 的函数,用于计算二分类任务中的准确率。它接收模型的输出结果 outputs和真实标签 labels 作为参数,并返回计算得到的准确率值。

def train(model, iterator, optimizer, criterion):
    epoch_loss = 0
    epoch_acc = 0

    model.train()  # 确保模型处于训练模式
    for batch in iterator:
        optimizer.zero_grad()  # 清空梯度
        inputs, labels = batch  # 获取输入和标签
        outputs = model(inputs)  # 前向传播
        
        # 计算损失和准确率
        loss = criterion(outputs, labels)
        acc = binary_accuracy(outputs, labels)

        loss.backward()
        optimizer.step()

        # 累积损失和准确率
        epoch_loss += loss.item()
        epoch_acc += acc

    # 计算平均损失和准确率
    average_loss = epoch_loss / len(iterator)
    average_acc = epoch_acc / len(iterator)

    return average_loss, average_acc

上述代码定义了一个名为 train 的函数,用于训练给定的模型。它接收模型、数据迭代器、优化器和损失函数作为参数,并返回训练过程中的平均损失和平均准确率。

def evaluate(model, iterator, criterion):
    epoch_loss = 0
    epoch_acc = 0

    model.eval()  # 将模型设置为评估模式,例如关闭 Dropout 等
    with torch.no_grad():  # 不需要计算梯度
        for batch in iterator:
            inputs, labels = batch
            outputs = model(inputs)  # 前向传播

            # 计算损失和准确率
            loss = criterion(outputs, labels)
            acc = binary_accuracy(outputs, labels)

            # 累计损失和准确率
            epoch_loss += loss.item()
            epoch_acc += acc

    return epoch_loss / len(iterator), epoch_acc / len(iterator)

上述代码定义了一个名为 evaluate 的函数,用于评估给定模型在给定数据迭代器上的性能。它接收模型、数据迭代器和损失函数作为参数,并返回评估过程中的平均损失和平均准确率。这个函数通常在模型训练的过程中定期被调用,以监控模型在验证集或测试集上的性能。通过评估模型的性能,可以了解模型的泛化能力和训练的进展情况。

best_acc = 0
epoch = 100
train_losses = []
valid_losses = []
train_accs = []
valid_accs = []

for epoch in range(epoch):
    train_loss, train_acc = train(model, train_loader, optimizer, criterion)
    valid_loss, valid_acc = evaluate(model, valid_loader, criterion)
    
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    train_accs.append(train_acc)
    valid_accs.append(valid_acc)
    
    print(f'Epoch: {epoch+1:02}, Train Loss: {train_loss:.3f}, Train Acc: {train_acc * 100:.2f}%, Val. Loss: {valid_loss:.3f}, Val. Acc: {valid_acc * 100:.2f}%')
    
    if best_acc <= valid_acc:
        best_acc = valid_acc
        pth = model.state_dict()
Epoch: 01, Train Loss: 0.633, Train Acc: 63.90%, Val. Loss: 0.594, Val. Acc: 68.16%
Epoch: 02, Train Loss: 0.590, Train Acc: 68.66%, Val. Loss: 0.548, Val. Acc: 72.32%
Epoch: 03, Train Loss: 0.543, Train Acc: 72.48%, Val. Loss: 0.483, Val. Acc: 77.23%
Epoch: 04, Train Loss: 0.489, Train Acc: 75.86%, Val. Loss: 0.424, Val. Acc: 80.48%
Epoch: 05, Train Loss: 0.452, Train Acc: 78.21%, Val. Loss: 0.397, Val. Acc: 81.64%
******
Epoch: 96, Train Loss: 0.345, Train Acc: 83.99%, Val. Loss: 0.342, Val. Acc: 84.17%
Epoch: 97, Train Loss: 0.346, Train Acc: 83.95%, Val. Loss: 0.339, Val. Acc: 84.39%
Epoch: 98, Train Loss: 0.346, Train Acc: 83.91%, Val. Loss: 0.340, Val. Acc: 84.15%
Epoch: 99, Train Loss: 0.345, Train Acc: 84.03%, Val. Loss: 0.341, Val. Acc: 84.23%
Epoch: 100, Train Loss: 0.345, Train Acc: 83.91%, Val. Loss: 0.340, Val. Acc: 84.30%

6.2 绘制损失与准确率曲线

# 绘制损失图
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(valid_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Train and Validation Loss')
plt.legend()
plt.grid(True)

# 绘制准确率图
plt.subplot(1, 2, 2)
plt.plot(train_accs, label='Train Accuracy')
plt.plot(valid_accs, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Train and Validation Accuracy')
plt.legend()
plt.grid(True)

plt.show()

损失与准确率曲线

7. 模型评估与可视化

7.1 构建预测函数

定义预测函数prediction 方便调用

# 定义 prediction函数
def prediction(model, valid_loader): 
    all_labels = []
    all_predictions = []
    all_predictions_prob = []

    model.eval()
    with torch.no_grad():
        for inputs, labels in valid_loader:
            outputs = model(inputs)
            predictions_prob = torch.sigmoid(outputs)
            predicted = (predictions_prob > 0.5).float()
            all_labels.extend(labels.numpy())
            all_predictions.extend(predicted.numpy())
            all_predictions_prob.extend(predictions_prob.numpy())
    return all_labels, all_predictions, all_predictions_prob

上述代码定义了一个名为 prediction 的函数,用于对给定的模型在验证数据加载器(valid_loader)上进行预测,并返回真实标签、预测的类别以及预测的概率。这个函数通常在模型训练完成后,用于对新的数据进行预测。通过收集所有的预测结果,可以进一步分析模型的性能,例如计算准确率、绘制混淆矩阵等。它也可以用于实际应用中,对未知数据进行预测并做出决策。

# 预测结果
labels, predictions, predictions_prob = prediction(model, valid_loader)

7.2 混淆矩阵

def plot_confusion_matrix(labels, predictions, classes):
    cm = confusion_matrix(labels, predictions)
    plt.figure(figsize=(8, 6))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title("Confusion Matrix")
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], 'd'),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.show()

上述代码定义一个名为 plot_confusion_matrix 的函数,用于绘制给定真实标签和预测结果的混淆矩阵。混淆矩阵是一种用于评估分类模型性能的可视化工具,它展示了模型在不同类别上的预测准确性。

classes = ['Class 0', 'Class 1']

绘制混淆矩阵

plot_confusion_matrix(labels, predictions, classes)

混淆矩阵

7.3 ROC_AUC曲线

def plot_roc_curve(labels, predictions_prob):
    fpr, tpr, _ = roc_curve(labels, predictions_prob)
    roc_auc = auc(fpr, tpr)
    plt.figure()
    plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % roc_auc)
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic')
    plt.legend(loc="lower right")
    plt.show()
# 绘制 ROC曲线
plot_roc_curve(labels, predictions_prob)

请添加图片描述

7.4 分类报告

from sklearn.metrics import classification_report
print(classification_report(labels, predictions))
              precision    recall  f1-score   support

         0.0       0.77      0.97      0.86     28186
         1.0       0.96      0.72      0.82     28601

    accuracy                           0.84     56787
   macro avg       0.87      0.84      0.84     56787
weighted avg       0.87      0.84      0.84     56787

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2108718.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

uniapp解决页面跳转时,含有base64的数据丢失问题

由于url长度的限制&#xff0c;base64数据过长可能导致数据丢失&#xff0c;以至于base64图片显示不出来或者格式错误。 解决办法&#xff1a; 跳转前进行base64编码&#xff1a;encodeURIComponent 接收数据时&#xff0c;对base64进行解码&#xff1a;decodeURIComponent

【2024数模国赛赛题思路公开】国赛D题思路丨附可运行代码丨无偿自提

2024年国赛D题解题思路 问题一 【题目】 投射一枚深弹&#xff0c;潜艇中心位置的深度定位没有误差&#xff0c;两个水平坐标定位均服从正态分布。分析投弹最大命中概率与投弹落点平面坐标及定深引信引爆深度之间的关系&#xff0c;并给出使得投弹命中概率最大的投弹方案&…

One-Shot Hierarchical Imitation Learning of Compound Visuomotor Tasks

发表时间&#xff1a;25 Oct 2018 论文链接&#xff1a;https://readpaper.com/pdf-annotate/note?pdfId4500198746683498497&noteId2453372035670907392 作者单位&#xff1a;Berkeley AI Research Motivation&#xff1a;我们考虑从执行任务的人类的单个视频中学习真…

ITK-重采样

作者&#xff1a;翟天保Steven 版权声明&#xff1a;著作权归作者所有&#xff0c;商业转载请联系作者获得授权&#xff0c;非商业转载请注明出处 什么是重采样 重采样&#xff08;Resampling&#xff09; 是一种用于图像处理的技术&#xff0c;主要应用于对图像进行尺寸调整、…

【专项刷题】— 栈

1、删除字符串中的所有相邻重复项 - 力扣&#xff08;LeetCode&#xff09; 思路&#xff1a; 使用栈进行操作&#xff0c;每次入栈的时候和栈顶元素进行比对&#xff0c;如果相同的话就弹出栈顶元素也可以用数组来模拟栈进行操作代码&#xff1a; public String removeDuplica…

基于人工智能的交通标志识别系统

目录 引言项目背景环境准备 硬件要求软件安装与配置系统设计 系统架构关键技术代码示例 数据预处理模型训练模型预测应用场景结论 1. 引言 交通标志识别系统是自动驾驶和智能交通的重要组成部分&#xff0c;能够帮助车辆自动识别路边的交通标志并作出相应的决策。通过使用深…

C语言-数据结构 无向图普里姆Prim算法(邻接矩阵存储)

Prim算法使用了贪心的思想&#xff0c;在算法中使用了两个数组&#xff0c;这两个数组会非常巧妙的操作整个算法的灵魂过程 lowcost的功能&#xff1a; 1.帮助算法寻找到当前距离已完成的最小生成树集合的最小的边长&#xff08;找到新边&#xff09; 2.在整个过程中记录新结…

分拣机介绍及解决方案细节

导语 大家好&#xff0c;我是社长&#xff0c;老K。专注分享智能制造和智能仓储物流等内容。 新书《智能物流系统构成与技术实践》人俱乐部 完整版文件和更多学习资料&#xff0c;请球友到知识星球【智能仓储物流技术研习社】自行下载。 这份文件是关于交叉带式分拣机的介绍及解…

openSSL 如何降版本

文章目录 前言openSSL 如何降版本1. 卸载2. 安装新的openssl版本3. 验证 前言 如果您觉得有用的话&#xff0c;记得给博主点个赞&#xff0c;评论&#xff0c;收藏一键三连啊&#xff0c;写作不易啊^ _ ^。   而且听说点赞的人每天的运气都不会太差&#xff0c;实在白嫖的话&…

RT-Thread 使用HTTP固件下载方式进行OTA远程升级

参考资料:RT-T官网资料如下链接所示 STM32通用Bootloader (rt-thread.org) 1.app程序env配置过程 参考上述资料中"制作 app 固件"章节&#xff0c;分区大小根据自己设备而定&#xff0c;以下是我以407VET6为例设置的fal分区 notes:上述分区是由片内flash(on-chip)…

机械革命imini Pro820迷你主机评测和拆解,8845H小主机使用政府补贴仅需两千三

机械革命imini Pro820迷你主机评测和拆解&#xff0c;8845H小主机使用政府补贴仅需两千三。 最近上线了家电补贴相关的活动&#xff0c;最高可以补贴20%&#xff0c;然后就看到了这款mini主机感觉很划算就下单了&#xff0c;用来替换我旧的N5095小主机&#xff0c;当服务器用。…

电子技术基础

目录 二极管 二极管的概念二极管的整流 二极管的防反接 二极管的钳位稳压二极管 三极管 NPN型三极管PNP型三极管三极管的三种状态三极管三个极之间电流的关系 放大电路 三极管共射极放大电路分压式偏置电路静态工作点多级放大功率放大电路 运算放大器 同相比例放大器反相…

旅行商问题 | Matlab基于混合粒子群算法GA-PSO的旅行商问题TSP

目录 效果一览基本介绍建模步骤程序设计参考资料 效果一览 基本介绍 混合粒子群算法GA-PSO是一种结合了遗传算法&#xff08;Genetic Algorithm, GA&#xff09;和粒子群优化算法&#xff08;Particle Swarm Optimization, PSO&#xff09;的优化算法。在解决旅行商问题&#…

「Python数据分析」Pandas进阶,使用groupby分组聚合数据(三)

​在实际数据分析和处理过程中&#xff0c;我们可能需要灵活对分组数据进行聚合操作。这个时候&#xff0c;我们就需要用到用户自定义函数&#xff08;User-Defined Functions&#xff0c;UDFs&#xff09;。 使用用户自定义函数进行聚合 使用用户自定义函数聚合时的性能&…

联想泄露显示本月推出更便宜的Copilot Plus电脑

联想似乎准备推出新的更实惠的 Copilot Plus 电脑。可靠的爆料者Evan Blass发布了一份来自联想的新闻稿&#xff0c;详细介绍了将在本周晚些时候的IFA展会上宣布的各种Copilot Plus电脑&#xff0c;其中包括两款采用尚未公布的8核高通骁龙X Plus芯片的电脑。 这些新的高通芯片…

Qt 创建一个json数组对象写入文档并从文档读出q

void createJsonArray() { // 创建一个JSON数组 QJsonArray jsonArray; // 创建一些JSON对象并添加到数组中 for (int i 0; i < 3; i) { QJsonObject jsonObject; jsonObject["key" QString::number(i)] "value" QStri…

原点安全荣获“AutoSec Awards 安全之星”优秀汽车数据安全合规方案奖

9月3日&#xff0c;「AutoSec 2024第八届中国汽车网络安全周暨第五届智能汽车数据安全展」在上海盛大开幕。本届大会由谈思实验室和谈思汽车主办、上海市车联网协会联合主办&#xff0c;以汽车“网络数据安全、软件安全、功能安全”为主题&#xff0c;汇聚了国内外的技术专家、…

Meta关闭Spark AR平台:未来规划与影响分析

Meta宣布将关闭其移动AR创作平台Spark AR&#xff0c;这一消息在业界引起了广泛关注。尽管Snap和TikTok在AR滤镜领域取得了巨大成功&#xff0c;但Meta却选择了另一条发展道路。本文将探讨这一决策背后的可能原因及其对未来的影响。 关闭Spark AR平台的背后 硬件为主&#xff…

PyTorch 创建数据集

图片数据和标签数据准备 1.本文所用图片数据在同级文件夹中 ,文件路径为train/’ 2.标签数据在同级文件&#xff0c;文件路径为train.csv 3。将标签数据提取 train_csvpd.read_csv(train.csv)创建继承类 第一步&#xff0c;首先创建数据类对象 此时可以想象为单个数据单元的…

【PyTorch】基础环境如何打开

前期安装可以基于这个视频&#xff0c;本文是为了给自己存档如何打开pycharm和jupyter notebookPyTorch深度学习快速入门教程&#xff08;绝对通俗易懂&#xff01;&#xff09;【小土堆】_哔哩哔哩_bilibili Pycharm 配置 新建项目的时候选择解释器pytorch-gpu即可。 Jupyte…