深度学习:基于人工神经网络ANN的降雨预测

news2025/1/11 14:19:45

Rain

前言

系列专栏:【深度学习:算法项目实战】✨︎
本专栏涉及创建深度学习模型、处理非结构化数据以及指导复杂的模型,如卷积神经网络(CNN)、递归神经网络 (RNN),包括长短期记忆 (LSTM) 、门控循环单元 (GRU)、自动编码器 (AE)、受限玻尔兹曼机(RBM)、深度信念网络 (DBN)、生成对抗网络 (GAN)、深度强化学习(DRL)、大型语言模型(LLM)和迁移学习

降雨预测是对人类社会有重大影响的困难和不确定任务之一。及时准确的预测可以主动帮助减少人员和经济损失。本研究介绍了一组实验,其中涉及使用常见的神经网络技术创建模型,根据澳大利亚主要城市当天的天气数据预测明天是否会下雨。

目录

  • 1. 相关库和数据集
    • 1.1 相关库介绍
    • 1.2 数据集介绍
    • 1.3 数据的信息
  • 2. 数据可视化和清理
    • 2.1 目标列的计数图(检查数据是否平衡)
    • 2.2 特征属性之间的相关性
    • 2.3 将日期转换为时间序列
    • 2.4 将日和月编码为连续循环特征
    • 2.5 数据清理——填补缺失值
      • 2.5.1 分类变量
      • 2.5.2 数值变量
    • 2.6 绘制历年降雨量折线图
    • 2.7 推算历年阵风风速
  • 3. 数据预处理
    • 3.1 对分类变量进行编码标签
    • 3.2 观察比例特征
    • 3.3 观察无离群值的缩放特征
  • 4. 模型建立
    • 4.1 数据准备(拆分为训练集和测试集)
    • 4.2 模型构建
    • 4.3 绘制训练和验证损失的Loss曲线
    • 4.4 绘制训练和验证的accuracy曲线
  • 5. 模型评估
    • 5.1 混淆矩阵
    • 5.2 分类报告

1. 相关库和数据集

1.1 相关库介绍

Python 库使我们能够非常轻松地处理数据并使用一行代码执行典型和复杂的任务。

  • Pandas – 该库有助于以 2D 数组格式加载数据框,并具有多种功能,可一次性执行分析任务。
  • Numpy – Numpy 数组速度非常快,可以在很短的时间内执行大型计算。
  • Matplotlib/Seaborn – 此库用于绘制可视化效果,用于展现数据之间的相互关系。
  • Keras – 是一个由Python编写的开源人工神经网络库,可以作为 Tensorflow 的高阶应用程序接口,进行深度学习模型的设计、调试、评估、应用和可视化。
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import datetime
from sklearn.preprocessing import LabelEncoder
from sklearn import preprocessing
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

from keras.layers import Dense, BatchNormalization, Dropout, LSTM
from keras.models import Sequential
from keras.utils import to_categorical
from keras.optimizers import Adam
from tensorflow.keras import regularizers
from sklearn.metrics import precision_score, recall_score, confusion_matrix, classification_report, accuracy_score, f1_score
from keras import callbacks

np.random.seed(0)

1.2 数据集介绍

该数据集包含澳大利亚各地约 10 年的每日天气观测数据。观测数据来自众多气象站。在本项目中,我将利用这些数据预测第二天是否会下雨。包括目标变量 "RainTomorrow "在内的 23 个属性表明第二天是否会下雨。

data = pd.read_csv("weatherAUS.csv")
data.head()

数据集

1.3 数据的信息

.info()方法打印有关DataFrame的信息,包括索引dtype和列、非null值以及内存使用情况。

data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 145460 entries, 0 to 145459
Data columns (total 23 columns):
 #   Column         Non-Null Count   Dtype  
---  ------         --------------   -----  
 0   Date           145460 non-null  object 
 1   Location       145460 non-null  object 
 2   MinTemp        143975 non-null  float64
 3   MaxTemp        144199 non-null  float64
 4   Rainfall       142199 non-null  float64
 5   Evaporation    82670 non-null   float64
 6   Sunshine       75625 non-null   float64
 7   WindGustDir    135134 non-null  object 
 8   WindGustSpeed  135197 non-null  float64
 9   WindDir9am     134894 non-null  object 
 10  WindDir3pm     141232 non-null  object 
 11  WindSpeed9am   143693 non-null  float64
 12  WindSpeed3pm   142398 non-null  float64
 13  Humidity9am    142806 non-null  float64
 14  Humidity3pm    140953 non-null  float64
 15  Pressure9am    130395 non-null  float64
 16  Pressure3pm    130432 non-null  float64
 17  Cloud9am       89572 non-null   float64
 18  Cloud3pm       86102 non-null   float64
 19  Temp9am        143693 non-null  float64
 20  Temp3pm        141851 non-null  float64
 21  RainToday      142199 non-null  object 
 22  RainTomorrow   142193 non-null  object 
dtypes: float64(16), object(7)
memory usage: 25.5+ MB

注意事项:

  • 数据集中存在缺失值
  • 数据集中包含数值和分类值

2. 数据可视化和清理

2.1 目标列的计数图(检查数据是否平衡)

#first of all let us evaluate the target and find out if our data is imbalanced or not
data.RainTomorrow.value_counts(normalize = True).plot(kind='bar', color= ["#C2C4E2","#EED4E5"], alpha = 0.6, rot=0)

请添加图片描述

2.2 特征属性之间的相关性

# Correlation amongst numeric attributes
corrmat = data.corr(numeric_only=True)
cmap = sns.diverging_palette(260,-10,s=50, l=75, n=6, as_cmap=True)
plt.subplots(figsize=(18,18))
sns.heatmap(corrmat,cmap= cmap,annot=True, square=True)

请添加图片描述

2.3 将日期转换为时间序列

我的目标是建立一个人工神经网络(ANN)。我将对日期进行适当的编码,也就是说,我更倾向于将月和日作为一个周期性的连续特征。因为,日期和时间本身就是循环的。为了让 ANN 模型知道某个特征是周期性的,我将其分成周期性的子部分。即年、月和日。现在,我为每个分节创建了两个新特征,分别是分节特征的正弦变换和余弦变换。

#Parsing datetime
#exploring the length of date objects
lengths = data["Date"].str.len()
lengths.value_counts()
Date
10    145460
Name: count, dtype: int64
#There don't seem to be any error in dates so parsing values into datetime
data['Date']= pd.to_datetime(data["Date"])
#Creating a collumn of year
data['year'] = data.Date.dt.year

# function to encode datetime into cyclic parameters. 
#As I am planning to use this data in a neural network I prefer the months and days in a cyclic continuous feature. 

def encode(data, col, max_val):
    data[col + '_sin'] = np.sin(2 * np.pi * data[col]/max_val)
    data[col + '_cos'] = np.cos(2 * np.pi * data[col]/max_val)
    return data

data['month'] = data.Date.dt.month
data = encode(data, 'month', 12)

data['day'] = data.Date.dt.day
data = encode(data, 'day', 31)

2.4 将日和月编码为连续循环特征

# roughly a year's span section 
section = data[:360] 
tm = section["day"].plot(color="#C2C4E2")
tm.set_title("Distribution Of Days Over Year")
tm.set_ylabel("Days In month")
tm.set_xlabel("Days In Year")
Text(0.5, 0, 'Days In Year')

每年天的分布
不出所料,数据的 "年份 "属性会重复出现。然而,在这种情况下,真正的周期性并没有以连续的方式呈现出来。将月和日拆分为正弦和余弦组合可提供周期性的连续特征。这可以作为 ANN 的输入特征。

cyclic_month = sns.scatterplot(x="month_sin",y="month_cos",data=data, color="#C2C4E2")
cyclic_month.set_title("Cyclic Encoding of Month")
cyclic_month.set_ylabel("Cosine Encoded Months")
cyclic_month.set_xlabel("Sine Encoded Months")

请添加图片描述

cyclic_day = sns.scatterplot(x='day_sin',y='day_cos',data=data, color="#C2C4E2")
cyclic_day.set_title("Cyclic Encoding of Day")
cyclic_day.set_ylabel("Cosine Encoded Day")
cyclic_day.set_xlabel("Sine Encoded Day")

请添加图片描述
接下来,我将分别处理分类属性和数字属性中的缺失值

2.5 数据清理——填补缺失值

2.5.1 分类变量

用列值的众数填补缺失值

# Get list of categorical variables
s = (data.dtypes == "object")
object_cols = list(s[s].index)

print("Categorical variables:")
print(object_cols)
Categorical variables:
['Location', 'WindGustDir', 'WindDir9am', 'WindDir3pm', 'RainToday', 'RainTomorrow']

分类变量中的缺失值

# Missing values in categorical variables
for i in object_cols:
    print(i, data[i].isnull().sum())
Location 0
WindGustDir 10326
WindDir9am 10566
WindDir3pm 4228
RainToday 3261
RainTomorrow 3267

用众数填补缺失值

# Filling missing values with mode of the column in value

for i in object_cols:
    data.fillna({i: data[i].mode()[0]}, inplace=True)

2.5.2 数值变量

用列值的中位数填补缺失值

# Get list of neumeric variables
t = (data.dtypes == "float64")
num_cols = list(t[t].index)

print("Neumeric variables:")
print(num_cols)
Neumeric variables:
['MinTemp', 'MaxTemp', 'Rainfall', 'Evaporation', 'Sunshine', 'WindGustSpeed', 'WindSpeed9am', 'WindSpeed3pm', 'Humidity9am', 'Humidity3pm', 'Pressure9am', 'Pressure3pm', 'Cloud9am', 'Cloud3pm', 'Temp9am', 'Temp3pm', 'month_sin', 'month_cos', 'day_sin', 'day_cos']

数值变量中的缺失值

# Missing values in numeric variables

for i in num_cols:
    print(i, data[i].isnull().sum())
MinTemp 1485
MaxTemp 1261
Rainfall 3261
Evaporation 62790
Sunshine 69835
WindGustSpeed 10263
WindSpeed9am 1767
WindSpeed3pm 3062
Humidity9am 2654
Humidity3pm 4507
Pressure9am 15065
Pressure3pm 15028
Cloud9am 55888
Cloud3pm 59358
Temp9am 1767
Temp3pm 3609
month_sin 0
month_cos 0
day_sin 0
day_cos 0

用列值的中位数填补缺失值

# Filling missing values with median of the column in value

for i in num_cols:
    data.fillna({i: data[i].median()}, inplace=True)
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 145460 entries, 0 to 145459
Data columns (total 30 columns):
 #   Column         Non-Null Count   Dtype         
---  ------         --------------   -----         
 0   Date           145460 non-null  datetime64[ns]
 1   Location       145460 non-null  object        
 2   MinTemp        145460 non-null  float64       
 3   MaxTemp        145460 non-null  float64       
 4   Rainfall       145460 non-null  float64       
 5   Evaporation    145460 non-null  float64       
 6   Sunshine       145460 non-null  float64       
 7   WindGustDir    145460 non-null  object        
 8   WindGustSpeed  145460 non-null  float64       
 9   WindDir9am     145460 non-null  object        
 10  WindDir3pm     145460 non-null  object        
 11  WindSpeed9am   145460 non-null  float64       
 12  WindSpeed3pm   145460 non-null  float64       
 13  Humidity9am    145460 non-null  float64       
 14  Humidity3pm    145460 non-null  float64       
 15  Pressure9am    145460 non-null  float64       
 16  Pressure3pm    145460 non-null  float64       
 17  Cloud9am       145460 non-null  float64       
 18  Cloud3pm       145460 non-null  float64       
 19  Temp9am        145460 non-null  float64       
 20  Temp3pm        145460 non-null  float64       
 21  RainToday      145460 non-null  object        
 22  RainTomorrow   145460 non-null  object        
 23  year           145460 non-null  int32         
 24  month          145460 non-null  int32         
 25  month_sin      145460 non-null  float64       
 26  month_cos      145460 non-null  float64       
 27  day            145460 non-null  int32         
 28  day_sin        145460 non-null  float64       
 29  day_cos        145460 non-null  float64       
dtypes: datetime64[ns](1), float64(20), int32(3), object(6)
memory usage: 31.6+ MB

2.6 绘制历年降雨量折线图

#plotting a lineplot rainfall over years
plt.figure(figsize=(12,8))
Time_series=sns.lineplot(x=data['Date'].dt.year,y="Rainfall",data=data,color="#C2C4E2")
Time_series.set_title("Rainfall Over Years")
Time_series.set_ylabel("Rainfall")
Time_series.set_xlabel("Years")

历年降雨量折线图

2.7 推算历年阵风风速

#Evauating Wind gust speed over years
colours = ["#D0DBEE", "#C2C4E2", "#EED4E5", "#D1E6DC", "#BDE2E2"]
plt.figure(figsize=(12,8))
Days_of_week=sns.barplot(x=data['Date'].dt.year,y="WindGustSpeed",data=data, errorbar=None, palette = colours)
Days_of_week.set_title("Wind Gust Speed Over Years")
Days_of_week.set_ylabel("WindGustSpeed")
Days_of_week.set_xlabel("Year")

历年阵风风速

3. 数据预处理

3.1 对分类变量进行编码标签

# Apply label encoder to each column with categorical data
label_encoder = LabelEncoder()
for i in object_cols:
    data[i] = label_encoder.fit_transform(data[i])
# Prepairing attributes of scale data
features = data.drop(['RainTomorrow', 'Date','day', 'month'], axis=1) # dropping target and extra columns
target = data['RainTomorrow']

#Set up a standard scaler for the features
col_names = list(features.columns)
s_scaler = preprocessing.StandardScaler()
features = s_scaler.fit_transform(features)
features = pd.DataFrame(features, columns=col_names) 

features.describe().T

数据描述

3.2 观察比例特征

#Detecting outliers
#looking at the scaled features
colours = ["#D0DBEE", "#C2C4E2", "#EED4E5", "#D1E6DC", "#BDE2E2"]
plt.figure(figsize=(20,10))
sns.boxenplot(data = features,palette = colours)
plt.xticks(rotation=90)
plt.show()

比例特征

#full data for 
features["RainTomorrow"] = target

#Dropping with outlier

features = features[(features["MinTemp"]<2.3)&(features["MinTemp"]>-2.3)]
features = features[(features["MaxTemp"]<2.3)&(features["MaxTemp"]>-2)]
features = features[(features["Rainfall"]<4.5)]
features = features[(features["Evaporation"]<2.8)]
features = features[(features["Sunshine"]<2.1)]
features = features[(features["WindGustSpeed"]<4)&(features["WindGustSpeed"]>-4)]
features = features[(features["WindSpeed9am"]<4)]
features = features[(features["WindSpeed3pm"]<2.5)]
features = features[(features["Humidity9am"]>-3)]
features = features[(features["Humidity3pm"]>-2.2)]
features = features[(features["Pressure9am"]< 2)&(features["Pressure9am"]>-2.7)]
features = features[(features["Pressure3pm"]< 2)&(features["Pressure3pm"]>-2.7)]
features = features[(features["Cloud9am"]<1.8)]
features = features[(features["Cloud3pm"]<2)]
features = features[(features["Temp9am"]<2.3)&(features["Temp9am"]>-2)]
features = features[(features["Temp3pm"]<2.3)&(features["Temp3pm"]>-2)]

3.3 观察无离群值的缩放特征

#looking at the scaled features without outliers

plt.figure(figsize=(20,10))
sns.boxenplot(data = features,palette = colours)
plt.xticks(rotation=90)
plt.show()

无离群值的缩放特征
看起来不错,接下来是构建人工神经网络。

4. 模型建立

4.1 数据准备(拆分为训练集和测试集)

X = features.drop(["RainTomorrow"], axis=1)
y = features["RainTomorrow"]

# Splitting test and training sets
X_train, X_test,\
		y_train, y_test = train_test_split(
				X, y, test_size = 0.2, random_state = 42)
X.shape
(127536, 26)

4.2 模型构建

#Early stopping
early_stopping = callbacks.EarlyStopping(
    min_delta=0.001, # minimium amount of change to count as an improvement
    patience=20, # how many epochs to wait before stopping
    restore_best_weights=True,
)

# Initialising the NN
model = Sequential()

# layers
model.add(Dense(units = 64, kernel_initializer = 'uniform', activation = 'relu'))
model.add(Dense(units = 32, kernel_initializer = 'uniform', activation = 'relu'))
model.add(Dense(units = 16, kernel_initializer = 'uniform', activation = 'relu'))
model.add(Dropout(0.25))
model.add(Dense(units = 8, kernel_initializer = 'uniform', activation = 'relu'))
model.add(Dropout(0.5))
model.add(Dense(units = 1, kernel_initializer = 'uniform', activation = 'sigmoid'))

# Compiling the ANN
opt = Adam(learning_rate=0.00009)
model.compile(optimizer = opt, loss = 'binary_crossentropy', metrics = ['accuracy'])

# Train the ANN
history = model.fit(X_train, y_train, batch_size = 32, epochs = 100, callbacks=[early_stopping], validation_split=0.2)
Epoch 1/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 4s 937us/step - accuracy: 0.7821 - loss: 0.5575 - val_accuracy: 0.7860 - val_loss: 0.3896
Epoch 2/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 882us/step - accuracy: 0.8128 - loss: 0.4126 - val_accuracy: 0.8395 - val_loss: 0.3781
Epoch 3/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 875us/step - accuracy: 0.8275 - loss: 0.4005 - val_accuracy: 0.8423 - val_loss: 0.3703
Epoch 4/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 922us/step - accuracy: 0.8259 - loss: 0.3985 - val_accuracy: 0.8442 - val_loss: 0.3662
Epoch 5/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 876us/step - accuracy: 0.8303 - loss: 0.3930 - val_accuracy: 0.8439 - val_loss: 0.3643
Epoch 6/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 885us/step - accuracy: 0.8324 - loss: 0.3905 - val_accuracy: 0.8439 - val_loss: 0.3635
Epoch 7/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 918us/step - accuracy: 0.8313 - loss: 0.3906 - val_accuracy: 0.8447 - val_loss: 0.3623
Epoch 8/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 888us/step - accuracy: 0.8316 - loss: 0.3888 - val_accuracy: 0.8454 - val_loss: 0.3607
Epoch 9/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 921us/step - accuracy: 0.8333 - loss: 0.3876 - val_accuracy: 0.8458 - val_loss: 0.3602
Epoch 10/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 3s 911us/step - accuracy: 0.8344 - loss: 0.3836 - val_accuracy: 0.8442 - val_loss: 0.3607
Epoch 11/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 891us/step - accuracy: 0.8327 - loss: 0.3841 - val_accuracy: 0.8456 - val_loss: 0.3589
Epoch 12/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 881us/step - accuracy: 0.8349 - loss: 0.3838 - val_accuracy: 0.8453 - val_loss: 0.3574
Epoch 13/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 907us/step - accuracy: 0.8323 - loss: 0.3846 - val_accuracy: 0.8453 - val_loss: 0.3573
Epoch 14/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 883us/step - accuracy: 0.8352 - loss: 0.3817 - val_accuracy: 0.8448 - val_loss: 0.3568
Epoch 15/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 936us/step - accuracy: 0.8322 - loss: 0.3827 - val_accuracy: 0.8448 - val_loss: 0.3570
Epoch 16/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 878us/step - accuracy: 0.8358 - loss: 0.3803 - val_accuracy: 0.8466 - val_loss: 0.3560
Epoch 17/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 3s 928us/step - accuracy: 0.8321 - loss: 0.3809 - val_accuracy: 0.8462 - val_loss: 0.3560
Epoch 18/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 3s 938us/step - accuracy: 0.8328 - loss: 0.3832 - val_accuracy: 0.8459 - val_loss: 0.3553
Epoch 19/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 877us/step - accuracy: 0.8342 - loss: 0.3763 - val_accuracy: 0.8451 - val_loss: 0.3560
Epoch 20/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 3s 913us/step - accuracy: 0.8360 - loss: 0.3758 - val_accuracy: 0.8458 - val_loss: 0.3555
Epoch 21/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 3s 893us/step - accuracy: 0.8350 - loss: 0.3780 - val_accuracy: 0.8456 - val_loss: 0.3549
Epoch 22/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 893us/step - accuracy: 0.8327 - loss: 0.3794 - val_accuracy: 0.8466 - val_loss: 0.3547
Epoch 23/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 3s 899us/step - accuracy: 0.8331 - loss: 0.3791 - val_accuracy: 0.8460 - val_loss: 0.3550
Epoch 24/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 905us/step - accuracy: 0.8330 - loss: 0.3785 - val_accuracy: 0.8448 - val_loss: 0.3559
Epoch 25/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 880us/step - accuracy: 0.8318 - loss: 0.3790 - val_accuracy: 0.8468 - val_loss: 0.3542
Epoch 26/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 3s 904us/step - accuracy: 0.8373 - loss: 0.3709 - val_accuracy: 0.8473 - val_loss: 0.3544
Epoch 27/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 3s 948us/step - accuracy: 0.8322 - loss: 0.3800 - val_accuracy: 0.8472 - val_loss: 0.3535
Epoch 28/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 892us/step - accuracy: 0.8339 - loss: 0.3791 - val_accuracy: 0.8471 - val_loss: 0.3538
Epoch 29/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 917us/step - accuracy: 0.8339 - loss: 0.3755 - val_accuracy: 0.8460 - val_loss: 0.3541
Epoch 30/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 882us/step - accuracy: 0.8353 - loss: 0.3748 - val_accuracy: 0.8468 - val_loss: 0.3527
Epoch 31/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 3s 901us/step - accuracy: 0.8372 - loss: 0.3712 - val_accuracy: 0.8462 - val_loss: 0.3536
Epoch 32/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 901us/step - accuracy: 0.8374 - loss: 0.3741 - val_accuracy: 0.8466 - val_loss: 0.3530
Epoch 33/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 877us/step - accuracy: 0.8362 - loss: 0.3740 - val_accuracy: 0.8462 - val_loss: 0.3531
Epoch 34/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 899us/step - accuracy: 0.8356 - loss: 0.3746 - val_accuracy: 0.8470 - val_loss: 0.3529
Epoch 35/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 880us/step - accuracy: 0.8330 - loss: 0.3754 - val_accuracy: 0.8466 - val_loss: 0.3528
Epoch 36/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 875us/step - accuracy: 0.8340 - loss: 0.3767 - val_accuracy: 0.8464 - val_loss: 0.3531
Epoch 37/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 924us/step - accuracy: 0.8348 - loss: 0.3743 - val_accuracy: 0.8463 - val_loss: 0.3528
Epoch 38/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 876us/step - accuracy: 0.8354 - loss: 0.3758 - val_accuracy: 0.8457 - val_loss: 0.3526
Epoch 39/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 919us/step - accuracy: 0.8378 - loss: 0.3698 - val_accuracy: 0.8470 - val_loss: 0.3526
Epoch 40/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 886us/step - accuracy: 0.8378 - loss: 0.3741 - val_accuracy: 0.8466 - val_loss: 0.3523
Epoch 41/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 894us/step - accuracy: 0.8341 - loss: 0.3770 - val_accuracy: 0.8474 - val_loss: 0.3521
Epoch 42/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 883us/step - accuracy: 0.8371 - loss: 0.3708 - val_accuracy: 0.8473 - val_loss: 0.3529
Epoch 43/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 880us/step - accuracy: 0.8375 - loss: 0.3743 - val_accuracy: 0.8457 - val_loss: 0.3536
Epoch 44/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 3s 929us/step - accuracy: 0.8372 - loss: 0.3709 - val_accuracy: 0.8474 - val_loss: 0.3519
Epoch 45/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 3s 907us/step - accuracy: 0.8354 - loss: 0.3722 - val_accuracy: 0.8475 - val_loss: 0.3521
Epoch 46/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 3s 905us/step - accuracy: 0.8383 - loss: 0.3709 - val_accuracy: 0.8479 - val_loss: 0.3522
Epoch 47/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 893us/step - accuracy: 0.8356 - loss: 0.3752 - val_accuracy: 0.8464 - val_loss: 0.3528
Epoch 48/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 882us/step - accuracy: 0.8374 - loss: 0.3707 - val_accuracy: 0.8482 - val_loss: 0.3515
Epoch 49/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 2s 909us/step - accuracy: 0.8350 - loss: 0.3745 - val_accuracy: 0.8477 - val_loss: 0.3515
Epoch 50/100
2551/2551 ━━━━━━━━━━━━━━━━━━━━ 3s 911us/step - accuracy: 0.8377 - loss: 0.3720 - val_accuracy: 

4.3 绘制训练和验证损失的Loss曲线

history_df = pd.DataFrame(history.history)

plt.plot(history_df.loc[:, ['loss']], "#BDE2E2", label='Training loss')
plt.plot(history_df.loc[:, ['val_loss']],"#C2C4E2", label='Validation loss')
plt.title('Training and Validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend(loc="best")

plt.show()

loss曲线

4.4 绘制训练和验证的accuracy曲线

history_df = pd.DataFrame(history.history)

plt.plot(history_df.loc[:, ['accuracy']], "#BDE2E2", label='Training accuracy')
plt.plot(history_df.loc[:, ['val_accuracy']], "#C2C4E2", label='Validation accuracy')

plt.title('Training and Validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

请添加图片描述

5. 模型评估

预测测试集结果

# Predicting the test set results
y_pred = model.predict(X_test)
y_pred = (y_pred > 0.5)
798/798 ━━━━━━━━━━━━━━━━━━━━ 1s 571us/step

5.1 混淆矩阵

# confusion matrix
cmap1 = sns.diverging_palette(260,-10,s=50, l=75, n=5, as_cmap=True)
plt.subplots(figsize=(9,8))
cf_matrix = confusion_matrix(y_test, y_pred)
sns.heatmap(cf_matrix/np.sum(cf_matrix), cmap = cmap1, annot = True, annot_kws = {'size':15})

混淆矩阵

5.2 分类报告

print(classification_report(y_test, y_pred))
              precision    recall  f1-score   support

           0       0.88      0.94      0.91     20110
           1       0.70      0.50      0.59      5398

    accuracy                           0.85     25508
   macro avg       0.79      0.72      0.75     25508
weighted avg       0.84      0.85      0.84     25508

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1665023.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

alist网盘自动同步

alist网盘自动同步 alist可以设置目录定时转存到各个网盘&#xff0c;做到夸网盘&#xff0c;多备份的效果可以将自己挂载的alist 下的各个目录相互间进行同步&#xff0c;原理是采用alist原始api调用执行&#xff0c;同步原理是匹配文件名称是否相同&#xff0c;相同会跳过同…

excel表格里,可以把百分号放在数字前面吗?

在有些版本里是可以的&#xff0c;这样做&#xff1a; 选中数据&#xff0c;鼠标右键&#xff0c;点击设置单元格格式&#xff0c;切换到自定义&#xff0c;在右侧栏输入%0&#xff0c;点击确定就可以了。 这样设置的好处是&#xff0c;它仍旧是数值&#xff0c;并且数值大小没…

grid的常见使用场景

场景1&#xff1a;固定几列显示&#xff0c;显示不下会自动换行 <div id"container"><div class"item item-1">1</div><div class"item item-2">2</div><div class"item item-3">3</div>&l…

vue 孙组件调用父组件的方法

通过组件内的 传递方法名称&#xff0c;可以实现孙组件调用父组件。 代码如下&#xff1a; index.html <!DOCTYPE html> <html> <head><meta charset"utf-8"><script src"/framework/vue-2.7.16.min.js"></script>…

C++奇迹之旅:string类对象的修改操作

文章目录 &#x1f4dd;string类的常用接口&#x1f320; string类对象的修改操作&#x1f309;push_back&#x1f309;append&#x1f309;operator&#x1f309;insert&#x1f309;erase&#x1f309;replace&#x1f309; find&#x1f309; c_str &#x1f320;测试string…

如何到《新英格兰医学杂志》 NEJM查找下载文献

《新英格兰医学杂志》NEJM是世界上阅读、引用最广泛、影响力最大的综合性医学期刊之一。NEJM集团出版的期刊还包括NEJM Journal Watch、NEJM Catalyst及NEJM Evidence。NEJM是一份全科医学周刊&#xff0c;出版对生物医学科学与临床实践具有重要意义的一系列主题方面的医学研究…

服务器远程桌面局域网连接不上的解决方法

在企业网络环境中&#xff0c;服务器远程桌面局域网连接不上是一个常见且棘手的问题。这种问题可能导致工作效率下降&#xff0c;甚至影响业务运营。因此&#xff0c;我们需要采取专业的方法来解决这一问题。 服务器远程桌面局域网连接不上的解决方法&#xff1a; 1、确保服务器…

【北京迅为】《iTOP-3588从零搭建ubuntu环境手册》-第7章 安装VMwareTools

RK3588是一款低功耗、高性能的处理器&#xff0c;适用于基于arm的PC和Edge计算设备、个人移动互联网设备等数字多媒体应用&#xff0c;RK3588支持8K视频编解码&#xff0c;内置GPU可以完全兼容OpenGLES 1.1、2.0和3.2。RK3588引入了新一代完全基于硬件的最大4800万像素ISP&…

[单机]完美国际_V155_GM工具_VM虚拟机

[端游] 完美国际单机版V155一键端PC电脑网络游戏完美世界幻海凌云家园 本教程仅限学习使用&#xff0c;禁止商用&#xff0c;一切后果与本人无关&#xff0c;此声明具有法律效应&#xff01;&#xff01;&#xff01;&#xff01; 教程是本人亲自搭建成功的&#xff0c;绝对是…

2023版brupsuite专业破解安装

安装教程&#xff0c;分两部分&#xff1a; 1、安装java环境、参考链接JAVA安装配置----最详细的教程&#xff08;测试木头人&#xff09;_java安装教程详细-CSDN博客 2、安装2023.4版本brupsuite&#xff1a;参考链接 2023最新版—Brup_Suite安装配置----最详细的教程&…

云启未来:“云计算与网络运维精英交流群”与“独家资料”等你来探索“

作者简介&#xff1a;一名云计算网络运维人员、每天分享网络与运维的技术与干货。 公众号&#xff1a;网络豆云计算学堂 座右铭&#xff1a;低头赶路&#xff0c;敬事如仪 个人主页&#xff1a; 网络豆的主页​​​​​ &#x1f680; 云计算与运维精英交流群诚邀您的加入…

基于双经度模型的鱼眼图像畸变校正

文章目录 1. 简介2. 基本原理基本思路从目标图到半球面模型的投影从半球面模型到鱼眼图像的投影正交投影等距投影 3.实际效果示例论文中的原图去畸变 4. 有意思的玩法5. 对生成的鱼眼图去畸变 1. 简介 算法来自论文《基于双经度模型的鱼眼图像畸变矫正方法》 2. 基本原理 基本…

【复利思维 + 项目成功方程式】用1年,超越别人38年!

复利思维—每天进步1%。 一年后会比现在的自己优秀38倍。在做任何事情时都要考虑&#xff0c;这件事是否能随着时间不断积累扩大&#xff0c;不能积累价值的事情要及时调整和止损。 在这个过程中&#xff0c;千万不要陷入心理暗示的陷阱&#xff0c;尤其是越想得到的&#xf…

ElasticSearch集群环境

ElasticSearch集群环境 1、Linux单机 下载地址&#xff1a;LINUX X86_64 (elastic.co) 下载之后进行解压 tar -zxf elasticsearch-7.8.0-linux-x86_64.tar.gz 名字太长了改个名字改成es mv elasticsearch-7.8.0 es因为安全问题&#xff0c;Elasticsearch 不允许 root 用户…

一文读懂NVIDIA AI全景:从芯片到应用,全面解析未来科技

英伟达 NVIDIA AI 全景解析 NVIDIA 概述 公司概况 NVIDIA作为全球顶尖科技公司&#xff0c;早期深耕图形处理器设计制造&#xff0c;现已跃升为人工智能领域的领军者&#xff0c;产品和服务覆盖AI应用的全方位&#xff0c;引领科技潮流。 NVIDIA&#xff0c;1993年创立于美国…

Linux进程间通信 pipe 实现线程池 命名管道 实现打印日志 共享内存代码验证 消息队列 信号量

文章目录 前言管道匿名管道 pipe测试管道接口 --> 代码验证管道的4种情况管道的5种特征 线程池案例代码实现&#xff1a;ProcessPool.ccTask.hpp检测脚本makefile 命名管道代码演示&#xff1a;makefilenamedPipe.hppserver.ccclient.cc 实现日志Log.hpp 共享内存共享内存原…

机器人系统ros2-开发实践08-了解如何使用 tf2 来访问坐标帧转换(Python)

tf2 库允许你在 ROS 节点中查询两个帧之间的转换。这个查询可以是阻塞的&#xff0c;也可以是非阻塞的&#xff0c;取决于你的需求。下面是一个基本的 Python 示例&#xff0c;展示如何在 ROS 节点中使用 tf2 查询帧转换。 本教程假设您已完成tf2 静态广播器教程 (Python)和tf…

车辆运动模型中LQR代码实现

一、前言 最近看到关于架构和算法两者关系的一个描述&#xff0c;我觉得非常认同&#xff0c;分享给大家。 1、好架构起到两个作用&#xff1a;合理的分解功能、合理的适配算法&#xff1b; 2、好的架构是好的功能的必要条件&#xff0c;不是充分条件&#xff0c;一味追求架构…

海外云手机解决海外社交媒体运营难题

随着全球数字化浪潮的推进&#xff0c;海外社交媒体已成为外贸企业拓展市场、提升品牌影响力的重要阵地。Tiktok、Facebook、领英、twitter等平台以其庞大的用户基础和高度互动性&#xff0c;为企业提供了前所未有的营销机会。本文将介绍如何通过海外云手机&#xff0c;高效、快…

[优选算法]------滑动窗⼝——209. 长度最小的子数组

目录 1.题目 1.解法⼀&#xff08;暴⼒求解&#xff09;&#xff08;会超时&#xff09;&#xff1a; 2.解法⼆&#xff08;滑动窗⼝&#xff09;&#xff1a; 1.算法思路&#xff1a; 2.手撕图解 3.代码实现 1.C 2.C语言 1.题目 209. 长度最小的子数组 给定一个含有 n…