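This article is an internal article of the 🔗365天深度学习训练营 (365-Day Deep Learning Training Camp)
Original author: K同学啊
I. Preparation
1. Importing the data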
import pandas as pd
from keras.optimizers import Adam
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from keras.models import Sequential
from keras.layers import Dense,SimpleRNN
import warnings
warnings.filterwarnings('ignore')
df = pd.read_csv('heart.csv')
2. Inspecting the data
Check for missing values
print(df.shape)
print(df.isnull().sum())
(303, 14)
age         0
sex         0
cp          0
trestbps    0
chol        0
fbs         0
restecg     0
thalach     0
exang       0
oldpeak     0
slope       0
ca          0
thal        0
target      0
dtype: int64
II. Data preprocessing
1. Splitting into training and test sets
X = df.iloc[:,:-1]
y = df.iloc[:,-1]
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.1,random_state=14)
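2. Standardizing the data
sc = StandardScaler()
# Fit the scaler on the training set only, then reuse its mean and standard deviation on the test set
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)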
X_train = X_train.reshape(X_train.shape[0],X_train.shape[1],1)
X_test = X_test.reshape(X_test.shape[0],X_test.shape[1],1)
array([[[ 1.44626869],
        [ 0.54006172],
        [ 0.62321699],
        [ 1.37686599],
        [ 0.83801861],
        [-0.48989795],
        [ 0.92069654],
        [-1.38834656],
        [ 1.34839972],
        [ 1.83944021],
        [-0.74161985],
        [ 0.18805174],
        [ 1.09773445]],

       [[-0.11901962],
        [ 0.54006172],
        [ 1.4632051 ],
        [-0.7179976 ],
        [-1.01585167],
        [-0.48989795],
        [-0.86315301],
        [ 0.77440436],
        [-0.74161985],
        [ 0.85288923],
        [-0.74161985],
        [-0.78354893],
        [ 1.09773445]],
       ...
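The standardize-then-reshape step can be illustrated without scikit-learn. The following is a minimal NumPy sketch, assuming made-up toy data (the names X_train_demo and X_test_demo are hypothetical and not from the article): the mean and standard deviation are computed on the training split only and reused for the test split, then a trailing axis is added.

```python
import numpy as np

# Toy data standing in for the heart-disease features (values are made up).
rng = np.random.default_rng(0)
X_train_demo = rng.normal(loc=50.0, scale=10.0, size=(8, 3))
X_test_demo = rng.normal(loc=50.0, scale=10.0, size=(4, 3))

# Statistics come from the training split only.
mean = X_train_demo.mean(axis=0)
std = X_train_demo.std(axis=0)  # population std (ddof=0)

X_train_scaled = (X_train_demo - mean) / std
X_test_scaled = (X_test_demo - mean) / std  # reuse the training statistics

# Add a trailing axis: each feature becomes one time step holding one value.
X_train_3d = X_train_scaled.reshape(X_train_scaled.shape[0], X_train_scaled.shape[1], 1)
X_test_3d = X_test_scaled.reshape(X_test_scaled.shape[0], X_test_scaled.shape[1], 1)

print(X_train_3d.shape, X_test_3d.shape)
```

After scaling, each training column has mean 0 and standard deviation 1; the test columns are only approximately so, because they are scaled with the training statistics.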
III. Building the RNN model
model = Sequential()
model.add(SimpleRNN(200,input_shape=(X_train.shape[1],1),activation='relu'))
model.add(Dense(100,activation='relu'))
model.add(Dense(1,activation='sigmoid'))
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 simple_rnn (SimpleRNN)      (None, 200)               40400
 dense (Dense)               (None, 100)               20100
 dense_1 (Dense)             (None, 1)                 101
=================================================================
Total params: 60,601
Trainable params: 60,601
Non-trainable params: 0
_________________________________________________________________
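The parameter counts in the summary can be checked by hand. The sketch below assumes the usual weight layout: a SimpleRNN layer has an input kernel, a recurrent kernel, and a bias, while a Dense layer has a kernel and a bias. The helper names are hypothetical.

```python
def simple_rnn_params(units, input_dim):
    # input kernel (input_dim x units) + recurrent kernel (units x units) + bias (units)
    return input_dim * units + units * units + units


def dense_params(inputs, units):
    # kernel (inputs x units) + bias (units)
    return inputs * units + units


rnn = simple_rnn_params(units=200, input_dim=1)
dense_1 = dense_params(inputs=200, units=100)
dense_2 = dense_params(inputs=100, units=1)

print(rnn, dense_1, dense_2, rnn + dense_1 + dense_2)  # 40400 20100 101 60601
```

Most of the SimpleRNN parameters (40,000 of 40,400) sit in the recurrent kernel, since each time step carries only one input value.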
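IV. Compiling the model
optimizer = Adam(learning_rate=1e-4)
# Use binary cross-entropy (binary_crossentropy) as the loss, which suits binary classification; pass in the optimizer defined above and track accuracy as the metric
model.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])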
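To make the loss concrete, here is a small NumPy sketch of binary cross-entropy. It is a hand-written illustration under stated assumptions, not the Keras implementation; the function name, the eps clipping value, and the toy labels and probabilities are all assumptions.

```python
import numpy as np


def binary_crossentropy(y_true, y_prob, eps=1e-7):
    # Clip probabilities away from 0 and 1 so the logarithms stay finite.
    y_true = np.asarray(y_true, dtype=float)
    y_prob = np.clip(np.asarray(y_prob, dtype=float), eps, 1.0 - eps)
    losses = -(y_true * np.log(y_prob) + (1.0 - y_true) * np.log(1.0 - y_prob))
    return float(np.mean(losses))


y_true_demo = [1, 0, 1, 0]
confident = [0.9, 0.1, 0.8, 0.2]   # mostly correct, fairly sure
uncertain = [0.5, 0.5, 0.5, 0.5]   # no information at all

loss_confident = binary_crossentropy(y_true_demo, confident)
loss_uncertain = binary_crossentropy(y_true_demo, uncertain)
print(loss_confident, loss_uncertain)
```

A model that always outputs 0.5 scores ln 2 ≈ 0.693, which is why the loss of an untrained binary classifier typically starts near 0.69.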
V. Training the model
epochs = 100
# Keep the History object returned by fit() so the per-epoch metrics can be read back
history = model.fit(x=X_train,y=y_train,validation_data=(X_test,y_test),verbose=1,
                    epochs=epochs,batch_size=128)
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
Epoch 1/100
3/3 [==============================] - 1s 130ms/step - loss: 0.6872 - accuracy: 0.5551 - val_loss: 0.6884 - val_accuracy: 0.5806
Epoch 2/100
3/3 [==============================] - 0s 19ms/step - loss: 0.6763 - accuracy: 0.6250 - val_loss: 0.6848 - val_accuracy: 0.6129
Epoch 3/100
3/3 [==============================] - 0s 19ms/step - loss: 0.6660 - accuracy: 0.6912 - val_loss: 0.6814 - val_accuracy: 0.6452
Epoch 4/100
3/3 [==============================] - 0s 18ms/step - loss: 0.6562 - accuracy: 0.7426 - val_loss: 0.6781 - val_accuracy: 0.6452
Epoch 5/100
3/3 [==============================] - 0s 18ms/step - loss: 0.6467 - accuracy: 0.7647 - val_loss: 0.6751 - val_accuracy: 0.6129
Epoch 6/100
3/3 [==============================] - 0s 19ms/step - loss: 0.6375 - accuracy: 0.7941 - val_loss: 0.6722 - val_accuracy: 0.6452
Epoch 7/100
3/3 [==============================] - 0s 18ms/step - loss: 0.6285 - accuracy: 0.8051 - val_loss: 0.6694 - val_accuracy: 0.6129
Epoch 8/100
3/3 [==============================] - 0s 18ms/step - loss: 0.6193 - accuracy: 0.8015 - val_loss: 0.6666 - val_accuracy: 0.6129
Epoch 9/100
3/3 [==============================] - 0s 18ms/step - loss: 0.6094 - accuracy: 0.8125 - val_loss: 0.6635 - val_accuracy: 0.5806
Epoch 10/100
3/3 [==============================] - 0s 18ms/step - loss: 0.6002 - accuracy: 0.8162 - val_loss: 0.6602 - val_accuracy: 0.6129
Epoch 11/100
3/3 [==============================] - 0s 25ms/step - loss: 0.5903 - accuracy: 0.8125 - val_loss: 0.6565 - val_accuracy: 0.5806
Epoch 12/100
3/3 [==============================] - 0s 18ms/step - loss: 0.5795 - accuracy: 0.8125 - val_loss: 0.6526 - val_accuracy: 0.5806
Epoch 13/100
3/3 [==============================] - 0s 18ms/step - loss: 0.5686 - accuracy: 0.8125 - val_loss: 0.6484 - val_accuracy: 0.6129
Epoch 14/100
3/3 [==============================] - 0s 20ms/step - loss: 0.5571 - accuracy: 0.8125 - val_loss: 0.6436 - val_accuracy: 0.6452
Epoch 15/100
3/3 [==============================] - 0s 20ms/step - loss: 0.5451 - accuracy: 0.8125 - val_loss: 0.6377 - val_accuracy: 0.6452
Epoch 16/100
3/3 [==============================] - 0s 17ms/step - loss: 0.5322 - accuracy: 0.8125 - val_loss: 0.6315 - val_accuracy: 0.6452
Epoch 17/100
3/3 [==============================] - 0s 24ms/step - loss: 0.5190 - accuracy: 0.8199 - val_loss: 0.6251 - val_accuracy: 0.6452
Epoch 18/100
3/3 [==============================] - 0s 17ms/step - loss: 0.5053 - accuracy: 0.8199 - val_loss: 0.6190 - val_accuracy: 0.6774
Epoch 19/100
3/3 [==============================] - 0s 17ms/step - loss: 0.4910 - accuracy: 0.8162 - val_loss: 0.6132 - val_accuracy: 0.6774
Epoch 20/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4765 - accuracy: 0.8199 - val_loss: 0.6076 - val_accuracy: 0.6774
Epoch 21/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4616 - accuracy: 0.8235 - val_loss: 0.6007 - val_accuracy: 0.6774
Epoch 22/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4470 - accuracy: 0.8125 - val_loss: 0.5943 - val_accuracy: 0.6774
Epoch 23/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4345 - accuracy: 0.8162 - val_loss: 0.5906 - val_accuracy: 0.6774
Epoch 24/100
3/3 [==============================] - 0s 15ms/step - loss: 0.4219 - accuracy: 0.8162 - val_loss: 0.5901 - val_accuracy: 0.7419
Epoch 25/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4116 - accuracy: 0.8162 - val_loss: 0.5921 - val_accuracy: 0.7742
Epoch 26/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4056 - accuracy: 0.8272 - val_loss: 0.5990 - val_accuracy: 0.7419
Epoch 27/100
3/3 [==============================] - 0s 15ms/step - loss: 0.3983 - accuracy: 0.8309 - val_loss: 0.5970 - val_accuracy: 0.7097
Epoch 28/100
3/3 [==============================] - 0s 15ms/step - loss: 0.3920 - accuracy: 0.8309 - val_loss: 0.5914 - val_accuracy: 0.7097
Epoch 29/100
3/3 [==============================] - 0s 15ms/step - loss: 0.3860 - accuracy: 0.8235 - val_loss: 0.5863 - val_accuracy: 0.7097
Epoch 30/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3802 - accuracy: 0.8235 - val_loss: 0.5724 - val_accuracy: 0.7097
Epoch 31/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3757 - accuracy: 0.8346 - val_loss: 0.5572 - val_accuracy: 0.7419
Epoch 32/100
3/3 [==============================] - 0s 20ms/step - loss: 0.3766 - accuracy: 0.8272 - val_loss: 0.5545 - val_accuracy: 0.7419
Epoch 33/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3706 - accuracy: 0.8272 - val_loss: 0.5608 - val_accuracy: 0.7419
Epoch 34/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3639 - accuracy: 0.8382 - val_loss: 0.5899 - val_accuracy: 0.7419
Epoch 35/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3694 - accuracy: 0.8272 - val_loss: 0.6097 - val_accuracy: 0.7742
Epoch 36/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3682 - accuracy: 0.8346 - val_loss: 0.5859 - val_accuracy: 0.7419
Epoch 37/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3567 - accuracy: 0.8309 - val_loss: 0.5680 - val_accuracy: 0.7419
Epoch 38/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3497 - accuracy: 0.8419 - val_loss: 0.5528 - val_accuracy: 0.7419
Epoch 39/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3484 - accuracy: 0.8603 - val_loss: 0.5417 - val_accuracy: 0.7742
Epoch 40/100
3/3 [==============================] - 0s 22ms/step - loss: 0.3487 - accuracy: 0.8603 - val_loss: 0.5386 - val_accuracy: 0.6774
Epoch 41/100
3/3 [==============================] - 0s 22ms/step - loss: 0.3473 - accuracy: 0.8640 - val_loss: 0.5383 - val_accuracy: 0.7097
Epoch 42/100
3/3 [==============================] - 0s 19ms/step - loss: 0.3422 - accuracy: 0.8676 - val_loss: 0.5425 - val_accuracy: 0.7742
Epoch 43/100
3/3 [==============================] - 0s 19ms/step - loss: 0.3353 - accuracy: 0.8713 - val_loss: 0.5467 - val_accuracy: 0.7419
Epoch 44/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3318 - accuracy: 0.8787 - val_loss: 0.5565 - val_accuracy: 0.7419
Epoch 45/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3289 - accuracy: 0.8750 - val_loss: 0.5572 - val_accuracy: 0.7419
Epoch 46/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3263 - accuracy: 0.8750 - val_loss: 0.5548 - val_accuracy: 0.7419
Epoch 47/100
3/3 [==============================] - 0s 19ms/step - loss: 0.3227 - accuracy: 0.8787 - val_loss: 0.5520 - val_accuracy: 0.7419
Epoch 48/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3191 - accuracy: 0.8824 - val_loss: 0.5564 - val_accuracy: 0.7419
Epoch 49/100
3/3 [==============================] - 0s 19ms/step - loss: 0.3172 - accuracy: 0.8713 - val_loss: 0.5539 - val_accuracy: 0.7419
Epoch 50/100
3/3 [==============================] - 0s 20ms/step - loss: 0.3149 - accuracy: 0.8824 - val_loss: 0.5381 - val_accuracy: 0.7419
Epoch 51/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3110 - accuracy: 0.8824 - val_loss: 0.5427 - val_accuracy: 0.7419
Epoch 52/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3084 - accuracy: 0.8787 - val_loss: 0.5510 - val_accuracy: 0.7419
Epoch 53/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3069 - accuracy: 0.8750 - val_loss: 0.5571 - val_accuracy: 0.7419
Epoch 54/100
3/3 [==============================] - 0s 19ms/step - loss: 0.3052 - accuracy: 0.8860 - val_loss: 0.5468 - val_accuracy: 0.7419
Epoch 55/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3024 - accuracy: 0.8787 - val_loss: 0.5347 - val_accuracy: 0.7419
Epoch 56/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3010 - accuracy: 0.8787 - val_loss: 0.5417 - val_accuracy: 0.7419
Epoch 57/100
3/3 [==============================] - 0s 21ms/step - loss: 0.3013 - accuracy: 0.8860 - val_loss: 0.5496 - val_accuracy: 0.7419
Epoch 58/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2975 - accuracy: 0.8824 - val_loss: 0.5355 - val_accuracy: 0.7419
Epoch 59/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2954 - accuracy: 0.8787 - val_loss: 0.5198 - val_accuracy: 0.7419
Epoch 60/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2970 - accuracy: 0.8787 - val_loss: 0.5148 - val_accuracy: 0.7419
Epoch 61/100
3/3 [==============================] - 0s 19ms/step - loss: 0.2991 - accuracy: 0.8824 - val_loss: 0.5187 - val_accuracy: 0.7419
Epoch 62/100
3/3 [==============================] - 0s 19ms/step - loss: 0.2958 - accuracy: 0.8787 - val_loss: 0.5376 - val_accuracy: 0.7419
Epoch 63/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2891 - accuracy: 0.8860 - val_loss: 0.5659 - val_accuracy: 0.7419
Epoch 64/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2923 - accuracy: 0.8824 - val_loss: 0.5777 - val_accuracy: 0.7419
Epoch 65/100
3/3 [==============================] - 0s 19ms/step - loss: 0.2892 - accuracy: 0.8824 - val_loss: 0.5560 - val_accuracy: 0.7419
Epoch 66/100
3/3 [==============================] - 0s 19ms/step - loss: 0.2848 - accuracy: 0.8934 - val_loss: 0.5405 - val_accuracy: 0.7419
Epoch 67/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2828 - accuracy: 0.8897 - val_loss: 0.5334 - val_accuracy: 0.7419
Epoch 68/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2810 - accuracy: 0.8934 - val_loss: 0.5332 - val_accuracy: 0.7419
Epoch 69/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2792 - accuracy: 0.8934 - val_loss: 0.5307 - val_accuracy: 0.7419
Epoch 70/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2780 - accuracy: 0.8934 - val_loss: 0.5370 - val_accuracy: 0.7419
Epoch 71/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2763 - accuracy: 0.8934 - val_loss: 0.5459 - val_accuracy: 0.7419
Epoch 72/100
3/3 [==============================] - 0s 21ms/step - loss: 0.2762 - accuracy: 0.8971 - val_loss: 0.5583 - val_accuracy: 0.7419
Epoch 73/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2759 - accuracy: 0.8971 - val_loss: 0.5676 - val_accuracy: 0.7419
Epoch 74/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2764 - accuracy: 0.8934 - val_loss: 0.5715 - val_accuracy: 0.7419
Epoch 75/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2747 - accuracy: 0.8934 - val_loss: 0.5540 - val_accuracy: 0.7419
Epoch 76/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2701 - accuracy: 0.8971 - val_loss: 0.5387 - val_accuracy: 0.7419
Epoch 77/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2689 - accuracy: 0.9044 - val_loss: 0.5308 - val_accuracy: 0.7419
Epoch 78/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2701 - accuracy: 0.9081 - val_loss: 0.5241 - val_accuracy: 0.7097
Epoch 79/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2716 - accuracy: 0.9007 - val_loss: 0.5241 - val_accuracy: 0.7097
Epoch 80/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2690 - accuracy: 0.9007 - val_loss: 0.5332 - val_accuracy: 0.7097
Epoch 81/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2650 - accuracy: 0.9154 - val_loss: 0.5418 - val_accuracy: 0.7419
Epoch 82/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2631 - accuracy: 0.9118 - val_loss: 0.5434 - val_accuracy: 0.7419
Epoch 83/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2620 - accuracy: 0.9154 - val_loss: 0.5406 - val_accuracy: 0.7419
Epoch 84/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2603 - accuracy: 0.9154 - val_loss: 0.5395 - val_accuracy: 0.7419
Epoch 85/100
3/3 [==============================] - 0s 26ms/step - loss: 0.2588 - accuracy: 0.9154 - val_loss: 0.5497 - val_accuracy: 0.7419
Epoch 86/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2562 - accuracy: 0.9081 - val_loss: 0.5687 - val_accuracy: 0.7419
Epoch 87/100
3/3 [==============================] - 0s 19ms/step - loss: 0.2609 - accuracy: 0.8971 - val_loss: 0.5754 - val_accuracy: 0.7419
Epoch 88/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2569 - accuracy: 0.8971 - val_loss: 0.5555 - val_accuracy: 0.7419
Epoch 89/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2532 - accuracy: 0.9081 - val_loss: 0.5399 - val_accuracy: 0.7419
Epoch 90/100
3/3 [==============================] - 0s 19ms/step - loss: 0.2545 - accuracy: 0.9191 - val_loss: 0.5361 - val_accuracy: 0.7419
Epoch 91/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2578 - accuracy: 0.9118 - val_loss: 0.5375 - val_accuracy: 0.7419
Epoch 92/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2572 - accuracy: 0.9118 - val_loss: 0.5507 - val_accuracy: 0.7419
Epoch 93/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2516 - accuracy: 0.9118 - val_loss: 0.5715 - val_accuracy: 0.7419
Epoch 94/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2487 - accuracy: 0.9118 - val_loss: 0.5705 - val_accuracy: 0.7419
Epoch 95/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2464 - accuracy: 0.9118 - val_loss: 0.5551 - val_accuracy: 0.7419
Epoch 96/100
3/3 [==============================] - 0s 20ms/step - loss: 0.2454 - accuracy: 0.9191 - val_loss: 0.5480 - val_accuracy: 0.7419
Epoch 97/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2438 - accuracy: 0.9154 - val_loss: 0.5543 - val_accuracy: 0.7419
Epoch 98/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2447 - accuracy: 0.9118 - val_loss: 0.5534 - val_accuracy: 0.7419
Epoch 99/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2446 - accuracy: 0.9118 - val_loss: 0.5425 - val_accuracy: 0.7419
Epoch 100/100
3/3 [==============================] - 0s 19ms/step - loss: 0.2434 - accuracy: 0.9118 - val_loss: 0.5213 - val_accuracy: 0.7742
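The "3/3" step counter and the coarse jumps in val_accuracy both follow from the dataset size. The arithmetic below is a sketch that assumes the test split size is rounded up (ceil) when test_size is given as a fraction; the numbers 303, 0.1, and 128 come from the code above.

```python
import math

n_samples = 303      # rows reported by df.shape
test_size = 0.1      # fraction passed to train_test_split
batch_size = 128     # batch size passed to model.fit

n_test = math.ceil(test_size * n_samples)          # assumed rounding rule: ceil
n_train = n_samples - n_test
steps_per_epoch = math.ceil(n_train / batch_size)  # last batch may be partial

print(n_train, n_test, steps_per_epoch)  # 272 31 3

# With 31 validation samples, accuracy can only move in steps of 1/31.
print(round(18 / n_test, 4), round(24 / n_test, 4))  # 0.5806 0.7742
```

With only 31 validation samples, one extra correct prediction shifts val_accuracy by about 3.2 percentage points, so small differences between epochs in the log should be read with caution.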
VI. Visualizing the results
epochs_range = range(100)
plt.figure(figsize=(14,4))
plt.subplot(1,2,1)
plt.plot(epochs_range,acc,label='training accuracy')
plt.plot(epochs_range,val_acc,label='validation accuracy')
plt.legend(loc='lower right')
plt.title('training and validation accuracy')
plt.subplot(1,2,2)
plt.plot(epochs_range,loss,label='training loss')
plt.plot(epochs_range,val_loss,label='validation loss')
plt.legend(loc='upper right')
plt.title('training and validation loss')
plt.show()
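Summary:
1. Model input requirements. RNN input format: recurrent models such as RNN and LSTM expect three-dimensional input of shape (samples, time steps, features). This layout lets the model process sequence data and learn dependencies between time steps.
2. Original data shape. After standardization, X_train and X_test have shape (samples, features). For example, if X_train has 100 samples and 10 features, its shape is (100, 10).
3. Purpose of the reshape. X_train.reshape(X_train.shape[0], X_train.shape[1], 1) changes the shape to (samples, features, 1). The original feature axis now plays the role of the time-step axis, and the trailing 1 is the number of features per time step: in this univariate case, each step holds a single value. For example, an X_train of shape (100, 10) becomes (100, 10, 1), meaning 100 samples, each with 10 time steps (the original features).
4. Fitting the model structure. After this reshape, the RNN can process the data correctly: it reads the features one step at a time, treating their order as the time axis, and can learn patterns along that sequence.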
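The reshape described above, and the way a recurrent layer then walks along the time-step axis, can be sketched in NumPy. This is a simplified illustration under stated assumptions, not the Keras SimpleRNN implementation: the weights W_x, W_h, and b are random placeholders, and units=4 is chosen only to keep the example small.

```python
import numpy as np

rng = np.random.default_rng(0)

# 100 samples with 10 features, as in the example in the summary.
X = rng.normal(size=(100, 10))
X_seq = X.reshape(X.shape[0], X.shape[1], 1)  # (samples, time steps, features per step)

units = 4                                          # hypothetical hidden size
W_x = rng.normal(scale=0.1, size=(1, units))       # input kernel
W_h = rng.normal(scale=0.1, size=(units, units))   # recurrent kernel
b = np.zeros(units)                                # bias

# Walk along the time axis: one original feature is consumed per step.
h = np.zeros((X_seq.shape[0], units))
for t in range(X_seq.shape[1]):
    x_t = X_seq[:, t, :]                           # shape (100, 1)
    h = np.maximum(0.0, x_t @ W_x + h @ W_h + b)   # ReLU, matching activation='relu'

print(X_seq.shape, h.shape)  # (100, 10, 1) (100, 4)
```

Only the final hidden state h is passed on to the Dense layers in the model above, which is why the SimpleRNN output shape in the summary is (None, 200) rather than (None, 13, 200).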