Keras-3: Example 2, Multiclass Classification

news2024/10/7 20:29:20

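1. Multiclass classification:

1.1 Loading the Reuters dataset

The Reuters dataset was published by Reuters in 1986 and covers 46 different topics. Some topics have many samples and others few, but every topic has at least 10 samples in the training set.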

from keras.datasets import reuters

(train_data, train_labels), (test_data, test_labels) = reuters.load_data(num_words=10000)
## As with the IMDB dataset, num_words=10000 restricts the data to the 10,000 most frequently occurring words.
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/reuters.npz
2110848/2110848 [==============================] - 0s 0us/step
print(len(train_data))
print(train_data)
## There are 8,982 training samples
## train_data is stored in the same form as in IMDB, so it needs similar preprocessing later.
8982
[list([1, 2, 2, 8, 43, 10, 447, 5, 25, 207, 270, 5, 3095, 111, 16, 369, 186, 90, 67, 7, 89, 5, 19, 102, 6, 19, 124, 15, 90, 67, 84, 22, 482, 26, 7, 48, 4, 49, 8, 864, 39, 209, 154, 6, 151, 6, 83, 11, 15, 22, 155, 11, 15, 7, 48, 9, 4579, 1005, 504, 6, 258, 6, 272, 11, 15, 22, 134, 44, 11, 15, 16, 8, 197, 1245, 90, 67, 52, 29, 209, 30, 32, 132, 6, 109, 15, 17, 12])
 list([1, 3267, 699, 3434, 2295, 56, 2, 7511, 9, 56, 3906, 1073, 81, 5, 1198, 57, 366, 737, 132, 20, 4093, 7, 2, 49, 2295, 2, 1037, 3267, 699, 3434, 8, 7, 10, 241, 16, 855, 129, 231, 783, 5, 4, 587, 2295, 2, 2, 775, 7, 48, 34, 191, 44, 35, 1795, 505, 17, 12])
 list([1, 53, 12, 284, 15, 14, 272, 26, 53, 959, 32, 818, 15, 14, 272, 26, 39, 684, 70, 11, 14, 12, 3886, 18, 180, 183, 187, 70, 11, 14, 102, 32, 11, 29, 53, 44, 704, 15, 14, 19, 758, 15, 53, 959, 47, 1013, 15, 14, 19, 132, 15, 39, 965, 32, 11, 14, 147, 72, 11, 180, 183, 187, 44, 11, 14, 102, 19, 11, 123, 186, 90, 67, 960, 4, 78, 13, 68, 467, 511, 110, 59, 89, 90, 67, 1390, 55, 2678, 92, 617, 80, 1274, 46, 905, 220, 13, 4, 346, 48, 235, 629, 5, 211, 5, 1118, 7, 2, 81, 5, 187, 11, 15, 9, 1709, 201, 5, 47, 3615, 18, 478, 4514, 5, 1118, 7, 232, 2, 71, 5, 160, 63, 11, 9, 2, 81, 5, 102, 59, 11, 17, 12])
 ...
 list([1, 141, 3890, 387, 81, 8, 16, 1629, 10, 340, 1241, 850, 31, 56, 3890, 691, 9, 1241, 71, 9, 5985, 2, 2, 699, 2, 2, 2, 699, 244, 5945, 4, 49, 8, 4, 656, 850, 33, 2993, 9, 2139, 340, 3371, 1493, 9, 2, 22, 2, 1094, 687, 83, 35, 15, 257, 6, 57, 9190, 7, 4, 5956, 654, 5, 2, 6191, 1371, 4, 49, 8, 16, 369, 646, 6, 1076, 7, 124, 407, 17, 12])
 list([1, 53, 46, 957, 26, 14, 74, 132, 26, 39, 46, 258, 3614, 18, 14, 74, 134, 5131, 18, 88, 2321, 72, 11, 14, 1842, 32, 11, 123, 383, 89, 39, 46, 235, 10, 864, 728, 5, 258, 44, 11, 15, 22, 753, 9, 42, 92, 131, 728, 5, 69, 312, 11, 15, 22, 222, 2, 3237, 383, 48, 39, 74, 235, 10, 864, 276, 5, 61, 32, 11, 15, 21, 4, 211, 5, 126, 1072, 42, 92, 131, 46, 19, 352, 11, 15, 22, 710, 220, 9, 42, 92, 131, 276, 5, 59, 61, 11, 15, 22, 10, 455, 7, 1172, 137, 336, 1325, 6, 1532, 142, 971, 6463, 43, 359, 5, 4, 326, 753, 364, 17, 12])
 list([1, 227, 2406, 91, 2, 125, 2855, 21, 4, 3976, 76, 7, 4, 757, 481, 3976, 790, 5259, 5654, 9, 111, 149, 8, 7, 10, 76, 223, 51, 4, 417, 8, 1047, 91, 6917, 1688, 340, 7, 194, 9411, 6, 1894, 21, 127, 2151, 2394, 1456, 6, 3034, 4, 329, 433, 7, 65, 87, 1127, 10, 8219, 1475, 290, 9, 21, 567, 16, 1926, 24, 4, 76, 209, 30, 4033, 6655, 5654, 8, 4, 60, 8, 4, 966, 308, 40, 2575, 129, 2, 295, 277, 1071, 9, 24, 286, 2114, 234, 222, 9, 4, 906, 3994, 8519, 114, 5758, 1752, 7, 4, 113, 17, 12])]
print(len(test_data))
print(test_data)
## There are 2,246 test samples
2246
[list([1, 4, 1378, 2025, 9, 697, 4622, 111, 8, 25, 109, 29, 3650, 11, 150, 244, 364, 33, 30, 30, 1398, 333, 6, 2, 159, 9, 1084, 363, 13, 2, 71, 9, 2, 71, 117, 4, 225, 78, 206, 10, 9, 1214, 8, 4, 270, 5, 2, 7, 748, 48, 9, 2, 7, 207, 1451, 966, 1864, 793, 97, 133, 336, 7, 4, 493, 98, 273, 104, 284, 25, 39, 338, 22, 905, 220, 3465, 644, 59, 20, 6, 119, 61, 11, 15, 58, 579, 26, 10, 67, 7, 4, 738, 98, 43, 88, 333, 722, 12, 20, 6, 19, 746, 35, 15, 10, 9, 1214, 855, 129, 783, 21, 4, 2280, 244, 364, 51, 16, 299, 452, 16, 515, 4, 99, 29, 5, 4, 364, 281, 48, 10, 9, 1214, 23, 644, 47, 20, 324, 27, 56, 2, 2, 5, 192, 510, 17, 12])
 list([1, 2768, 283, 122, 7, 4, 89, 544, 463, 29, 798, 748, 40, 85, 306, 28, 19, 59, 11, 82, 84, 22, 10, 1315, 19, 12, 11, 82, 52, 29, 283, 1135, 558, 2, 265, 2, 6607, 8, 6607, 118, 371, 10, 1503, 281, 4, 143, 4811, 760, 50, 2088, 225, 139, 683, 4, 48, 193, 862, 41, 967, 1999, 30, 1086, 36, 8, 28, 602, 19, 32, 11, 82, 5, 4, 89, 544, 463, 41, 30, 6273, 13, 260, 951, 6607, 8, 69, 1749, 18, 82, 41, 30, 306, 3342, 13, 4, 37, 38, 283, 555, 649, 18, 82, 13, 1721, 282, 9, 132, 18, 82, 41, 30, 385, 21, 4, 169, 76, 36, 8, 107, 4, 106, 524, 10, 295, 3825, 2, 2476, 6, 3684, 6940, 4, 1126, 41, 263, 84, 395, 649, 18, 82, 838, 1317, 4, 572, 4, 106, 13, 25, 595, 2445, 40, 85, 7369, 518, 5, 4, 1126, 51, 115, 680, 16, 6, 719, 250, 27, 429, 6607, 8, 6940, 114, 343, 84, 142, 20, 5, 1145, 1538, 4, 65, 494, 474, 27, 69, 445, 11, 1816, 6607, 8, 109, 181, 2768, 2, 62, 1810, 6, 624, 901, 6940, 107, 4, 1126, 34, 524, 4, 6940, 1126, 41, 447, 7, 1427, 13, 69, 251, 18, 872, 876, 1539, 468, 9063, 242, 5, 646, 27, 1888, 169, 283, 87, 9, 10, 2, 260, 182, 122, 678, 306, 13, 4, 99, 216, 7, 89, 544, 64, 85, 2333, 6, 195, 7254, 6337, 268, 609, 4, 195, 41, 1017, 2765, 2, 4, 73, 706, 2, 92, 4, 91, 3917, 36, 8, 51, 144, 23, 1858, 129, 564, 13, 269, 678, 115, 55, 866, 189, 814, 604, 838, 117, 380, 595, 951, 320, 4, 398, 57, 2233, 7411, 269, 274, 87, 6607, 8, 787, 283, 34, 596, 661, 5467, 13, 2362, 1816, 90, 2, 84, 22, 2202, 1816, 54, 748, 6607, 8, 87, 62, 6154, 84, 161, 5, 1208, 480, 4, 2, 416, 6, 538, 122, 115, 55, 129, 1104, 1445, 345, 389, 31, 4, 169, 76, 36, 8, 787, 398, 7, 4, 2, 1507, 64, 8862, 22, 125, 2, 9, 2876, 172, 399, 9, 2, 5206, 9, 2, 122, 36, 8, 6642, 172, 247, 100, 97, 6940, 34, 75, 477, 541, 4, 283, 182, 4, 2, 295, 301, 2, 125, 2, 6607, 8, 77, 57, 445, 283, 1998, 217, 31, 380, 704, 51, 77, 2, 509, 5, 476, 9, 2876, 122, 115, 853, 6, 1061, 52, 10, 2, 2, 1308, 5, 4, 283, 182, 36, 8, 5296, 114, 30, 531, 6, 6376, 9, 2470, 529, 13, 2, 2, 58, 529, 7, 2148, 2, 185, 1028, 240, 5296, 1028, 949, 
657, 57, 6, 1046, 283, 36, 8, 6607, 8, 4, 2217, 34, 9177, 13, 10, 4910, 5, 4, 141, 283, 120, 50, 2877, 7, 1049, 43, 10, 181, 283, 734, 115, 55, 3356, 476, 6, 2195, 10, 73, 120, 50, 41, 6877, 169, 87, 6607, 8, 107, 144, 23, 129, 120, 169, 87, 33, 2409, 30, 1888, 1171, 161, 4, 294, 517, 23, 2, 25, 398, 9, 2060, 283, 21, 4, 236, 36, 8, 143, 169, 87, 641, 1569, 28, 69, 61, 376, 514, 90, 1249, 62, 2, 13, 4, 2217, 696, 122, 404, 2936, 22, 134, 6, 187, 514, 10, 1249, 107, 4, 96, 1043, 1569, 13, 10, 184, 28, 61, 376, 514, 268, 680, 4, 320, 6, 154, 6, 69, 160, 514, 10, 1249, 27, 4, 153, 5, 52, 29, 36, 8, 6607, 8, 612, 408, 10, 3133, 283, 76, 27, 1504, 31, 169, 951, 2, 122, 36, 8, 283, 236, 62, 641, 84, 618, 2, 22, 8417, 8409, 9, 274, 7322, 399, 7587, 51, 115, 55, 45, 4044, 31, 4, 490, 558, 36, 8, 224, 2, 115, 57, 85, 1655, 2671, 5, 283, 6, 4, 37, 38, 7, 1797, 185, 77, 4446, 4, 555, 298, 77, 240, 2, 7, 327, 652, 194, 8773, 6233, 34, 2, 5463, 4884, 1297, 6, 240, 260, 458, 87, 6, 134, 514, 10, 1249, 22, 196, 514, 4, 37, 38, 309, 213, 54, 207, 8577, 25, 134, 139, 89, 283, 494, 555, 22, 4, 2217, 6, 2172, 4278, 434, 835, 22, 3598, 3746, 434, 835, 7, 48, 6607, 8, 618, 225, 586, 333, 122, 572, 126, 2768, 1998, 62, 133, 6, 2458, 233, 28, 602, 188, 5, 4, 704, 1998, 62, 45, 885, 281, 4, 48, 193, 760, 36, 8, 115, 680, 78, 58, 109, 95, 6, 1732, 1516, 281, 4, 225, 760, 17, 12])
 list([1, 4, 309, 2276, 4759, 5, 2015, 403, 1920, 33, 1575, 1627, 1173, 87, 13, 536, 78, 6490, 399, 7, 2068, 212, 10, 634, 179, 8, 137, 5602, 7, 2775, 33, 30, 1015, 43, 33, 5602, 50, 489, 4, 403, 6, 96, 399, 7, 1953, 3587, 8427, 6603, 4132, 3669, 8180, 7163, 9, 2015, 8, 2, 2, 1683, 791, 5, 740, 220, 707, 13, 4, 634, 634, 54, 1405, 6331, 4, 361, 182, 24, 511, 972, 137, 403, 1920, 529, 6, 96, 3711, 399, 41, 30, 2776, 21, 10, 8491, 2002, 503, 5, 188, 6, 353, 26, 2474, 21, 432, 4, 4234, 23, 3288, 435, 34, 737, 6, 246, 7528, 274, 1173, 1627, 87, 13, 399, 992, 27, 274, 403, 87, 2631, 85, 480, 52, 2015, 403, 820, 13, 10, 139, 9, 115, 949, 609, 890, 819, 6, 812, 593, 7, 576, 7, 194, 2329, 216, 2, 8, 2, 8, 634, 33, 768, 2085, 593, 4, 403, 1920, 185, 9, 107, 403, 87, 2, 107, 1635, 410, 4, 682, 189, 161, 1635, 762, 274, 5319, 115, 30, 43, 389, 410, 4, 682, 107, 1635, 762, 456, 36, 8, 184, 4057, 95, 1854, 107, 403, 87, 302, 2, 8, 129, 100, 756, 7, 3288, 96, 298, 55, 370, 731, 866, 189, 115, 949, 9695, 115, 949, 343, 756, 2, 9, 115, 949, 343, 756, 2509, 36, 8, 17, 12])
 ...
 list([1, 1809, 124, 53, 653, 26, 39, 5439, 18, 14, 5893, 18, 155, 177, 53, 544, 26, 39, 19, 5121, 18, 14, 19, 6382, 18, 280, 3882, 11, 14, 3123, 32, 11, 695, 3614, 47, 11, 14, 3615, 63, 11, 430, 3259, 44, 11, 14, 61, 11, 17, 12])
 list([1, 5586, 2, 71, 8, 23, 166, 344, 10, 78, 13, 68, 80, 467, 606, 6, 261, 5, 146, 93, 124, 4, 166, 75, 3603, 2, 5907, 265, 8692, 1251, 2, 297, 1127, 195, 9, 621, 575, 1080, 5907, 7, 378, 104, 421, 648, 20, 5, 4, 49, 2, 8, 1708, 28, 4, 303, 163, 524, 10, 1220, 6, 455, 4, 326, 685, 6, 2, 422, 71, 142, 73, 863, 62, 75, 3603, 6, 4, 326, 166, 2, 34, 1652, 3603, 6, 4, 166, 4, 49, 8, 17, 12])
 list([1, 706, 209, 658, 4, 37, 38, 309, 484, 4, 1434, 6, 933, 4, 89, 709, 377, 101, 28, 4, 143, 511, 101, 5, 47, 758, 15, 90, 2388, 7, 809, 6, 444, 2035, 4, 911, 5, 709, 198, 1997, 634, 3644, 3798, 2305, 8, 1486, 6, 674, 480, 10, 990, 309, 4008, 2190, 2305, 1849, 24, 68, 583, 242, 5, 4, 143, 709, 364, 7376, 41, 30, 13, 706, 6, 837, 4, 377, 101, 6, 631, 28, 47, 758, 15, 36, 1413, 107, 4, 377, 101, 62, 47, 758, 15, 634, 114, 713, 888, 1412, 6, 343, 37, 38, 1116, 95, 1136, 269, 43, 1488, 1170, 6, 226, 2, 4, 377, 101, 136, 143, 1032, 4, 89, 709, 377, 101, 1217, 30, 478, 97, 47, 948, 15, 90, 4594, 2, 5853, 41, 30, 13, 706, 6, 455, 4, 465, 474, 6, 837, 634, 6, 2069, 4, 709, 377, 101, 28, 47, 758, 15, 7, 463, 29, 89, 1017, 97, 148, 16, 6, 47, 948, 15, 4, 48, 511, 377, 101, 23, 47, 758, 15, 161, 5, 4, 47, 12, 20, 7424, 7978, 386, 240, 2305, 2634, 24, 10, 181, 1475, 7, 194, 534, 21, 709, 364, 756, 33, 30, 4, 386, 404, 36, 118, 4, 2190, 24, 4, 911, 7, 1116, 23, 24, 4, 37, 38, 377, 101, 1976, 42, 9964, 6, 127, 122, 9, 7609, 1136, 692, 13, 37, 38, 1116, 446, 69, 4, 234, 709, 7614, 1320, 13, 126, 1006, 5, 338, 458, 2305, 8, 4, 1136, 911, 23, 4, 307, 2016, 36, 8, 634, 23, 325, 2863, 4, 820, 9, 129, 2767, 40, 836, 85, 1523, 17, 12])]
print(list(set(train_labels)))
print(list(set(test_labels)))
## Both the training and test labels cover 46 classes
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45]
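Each sample above is a list of word indices, so it can be mapped back to text for inspection. The following is a minimal sketch of that decoding, assuming the loader's default offset of 3 (indices 0, 1 and 2 are reserved for padding, start of sequence and unknown words). The vocabulary `toy_word_index` is a hypothetical stand-in; the real mapping would come from `reuters.get_word_index()`, which downloads a file and is therefore not called here.

```python
def decode_newswire(encoded, word_index, offset=3):
    """Map a list of integer word indices back to a string.

    Indices below `offset` are reserved by the Keras loader, so they
    (and any index missing from the vocabulary) are shown as "?".
    """
    reverse_index = {index: word for word, index in word_index.items()}
    return " ".join(reverse_index.get(i - offset, "?") for i in encoded)


# Hypothetical toy vocabulary, for illustration only.
toy_word_index = {"the": 1, "of": 2, "said": 5}

decoded = decode_newswire([1, 4, 5, 8, 2], toy_word_index)
print(decoded)
```

With the real vocabulary, the call would look like `decode_newswire(train_data[0], reuters.get_word_index())`.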

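1.2 Preprocessing the dataset:

As with the IMDB dataset, the sequence data is vectorized and turned into tensors.

This covers both vectorizing the samples and vectorizing the labels.

Vectorizing the data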

import numpy as np

def vectorize_sequences(sequences, dimension=10000):
    results = np.zeros((len(sequences), dimension))
    for i, sequence in enumerate(sequences):
        results[i, sequence] = 1.
    return results
## Builds an n x 10000 2D tensor; in each row, a value of 1. marks that the word at that index occurs in the sequence

## Vectorize the training and test data
x_train = vectorize_sequences(train_data)
x_test = vectorize_sequences(test_data)
print(x_test.shape)
print(x_test)
(2246, 10000)
[[0. 1. 1. ... 0. 0. 0.]
 [0. 1. 1. ... 0. 0. 0.]
 [0. 1. 1. ... 0. 0. 0.]
 ...
 [0. 1. 0. ... 0. 0. 0.]
 [0. 1. 1. ... 0. 0. 0.]
 [0. 1. 1. ... 0. 0. 0.]]
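To see what this multi-hot encoding does, it helps to run the same function on a tiny input. The sketch below repeats `vectorize_sequences` and applies it to two made-up sequences with `dimension=5`; the toy sequences are assumptions chosen only to keep the output readable.

```python
import numpy as np


def vectorize_sequences(sequences, dimension=10000):
    # One row per sequence, one column per word index.
    results = np.zeros((len(sequences), dimension))
    for i, sequence in enumerate(sequences):
        # Fancy indexing sets every listed column of row i to 1.
        results[i, sequence] = 1.0
    return results


# Two made-up sequences; index 3 appears twice in the first one.
toy = vectorize_sequences([[1, 3, 3], [0, 4]], dimension=5)
print(toy)
```

Note that the repeated index 3 still produces a single 1: the encoding records whether a word occurs, not how often or in what order.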

Vectorizing the labels (one-hot encoding)

from keras.utils import to_categorical

one_hot_train_labels = to_categorical(train_labels)
one_hot_test_labels = to_categorical(test_labels)
print(one_hot_test_labels.shape)
print(one_hot_test_labels)
(2246, 46)
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 1. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
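One-hot encoding can also be written by hand, which makes it clear what `to_categorical` produces. This is a small NumPy sketch rather than the Keras implementation; the helper name `to_one_hot` and the toy labels are assumptions for illustration.

```python
import numpy as np


def to_one_hot(labels, dimension=46):
    # One row per label; the column matching the label is set to 1.
    results = np.zeros((len(labels), dimension))
    for i, label in enumerate(labels):
        results[i, label] = 1.0
    return results


toy_labels = np.array([3, 0, 45])
encoded = to_one_hot(toy_labels)
print(encoded.shape)
print(np.argmax(encoded, axis=1))
```

Taking `argmax` along each row recovers the original integer labels, so the two representations carry the same information.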

1.3 构建网络:

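Compared with the IMDB binary classification problem, the key difference in a multiclass problem is that there are more output classes, so the output space has a higher dimension (here it grows from 2 classes in the binary case to 46 classes).

If the network kept the 16-dimensional intermediate layers (16 hidden units) used for IMDB, those layers might not carry enough information to separate 46 classes, so wider intermediate layers are needed. This post uses 32 hidden units instead of the 64 given in the book, because with 64 the model began to overfit almost as soon as training started; this may be related to newer hardware, but the cause was not verified.

## Define the model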
from keras import models
from keras import layers

model = models.Sequential()
model.add(layers.Dense(32, activation="relu", input_shape=(10000,))) 
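model.add(layers.Dense(32, activation="relu")) ## intermediate layer with 32 hidden units
model.add(layers.Dense(46, activation="softmax")) ## output layer with 46 units: each sample yields a 46-dimensional vector, one entry per output class
## The softmax activation in the last layer outputs a probability distribution over the 46 classes, summing to 1.
## Compile the model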
model.compile(optimizer="rmsprop",
              loss="categorical_crossentropy",
              metrics=["accuracy"])
## The loss function is categorical crossentropy (categorical_crossentropy)
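The softmax output and the categorical crossentropy loss can be worked through by hand on one small example. The NumPy sketch below illustrates the formulas and is not Keras's own implementation; the logits, the 3-class setup and the `eps` clipping value are assumptions made for the example.

```python
import numpy as np


def softmax(logits):
    # Subtract the row maximum first for numerical stability.
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    # Clip to avoid log(0); eps is an illustrative choice.
    y_pred = np.clip(y_pred, eps, 1.0)
    return -np.sum(y_true * np.log(y_pred), axis=-1)


logits = np.array([[2.0, 1.0, 0.1]])   # made-up scores for 3 classes
y_true = np.array([[1.0, 0.0, 0.0]])   # one-hot target: class 0

probs = softmax(logits)
loss = categorical_crossentropy(y_true, probs)

# Loss of a model that spreads probability evenly over 46 classes.
uniform_loss = float(np.log(46))

print(probs, loss, uniform_loss)
```

With a one-hot target, the loss reduces to the negative log of the probability assigned to the true class. A completely uninformed model over 46 classes would score about ln(46) ≈ 3.83, which is a useful reference point when reading the loss values printed at the very start of training.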

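1.4 Training the model:

Set aside 1,000 samples from the training set as a validation set, leaving the remaining 7,982 for training.

## Hold out a validation set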
x_val = x_train[:1000]
partial_x_train = x_train[1000:]

y_val = one_hot_train_labels[:1000]
partial_y_train = one_hot_train_labels[1000:]
## Train the model
## Train for 20 epochs with a batch size of 512
import tensorflow as tf

tf.config.run_functions_eagerly(True)
## This version of TensorFlow needs the line above; otherwise it raises:
## ValueError: Creating variables on a non-first call to a function decorated with tf.function.

history = model.fit(partial_x_train,
                    partial_y_train,
                    batch_size=512,
                    epochs=20,
                    validation_data=(x_val, y_val))
Epoch 1/20
16/16 [==============================] - 1s 44ms/step - loss: 2.9901 - accuracy: 0.4608 - val_loss: 2.2753 - val_accuracy: 0.5610
Epoch 2/20
16/16 [==============================] - 1s 32ms/step - loss: 1.9436 - accuracy: 0.6143 - val_loss: 1.7261 - val_accuracy: 0.6250
Epoch 3/20
16/16 [==============================] - 1s 32ms/step - loss: 1.5143 - accuracy: 0.6794 - val_loss: 1.4657 - val_accuracy: 0.6680
Epoch 4/20
16/16 [==============================] - 0s 30ms/step - loss: 1.2668 - accuracy: 0.7226 - val_loss: 1.3284 - val_accuracy: 0.6970
Epoch 5/20
16/16 [==============================] - 0s 30ms/step - loss: 1.0926 - accuracy: 0.7625 - val_loss: 1.2174 - val_accuracy: 0.7380
Epoch 6/20
16/16 [==============================] - 0s 30ms/step - loss: 0.9552 - accuracy: 0.7945 - val_loss: 1.1462 - val_accuracy: 0.7520
Epoch 7/20
16/16 [==============================] - 0s 30ms/step - loss: 0.8368 - accuracy: 0.8191 - val_loss: 1.0916 - val_accuracy: 0.7600
Epoch 8/20
16/16 [==============================] - 0s 30ms/step - loss: 0.7341 - accuracy: 0.8421 - val_loss: 1.0458 - val_accuracy: 0.7740
Epoch 9/20
16/16 [==============================] - 0s 30ms/step - loss: 0.6430 - accuracy: 0.8601 - val_loss: 1.0127 - val_accuracy: 0.7820
Epoch 10/20
16/16 [==============================] - 0s 30ms/step - loss: 0.5613 - accuracy: 0.8748 - val_loss: 0.9958 - val_accuracy: 0.7890
Epoch 11/20
16/16 [==============================] - 0s 30ms/step - loss: 0.4923 - accuracy: 0.8895 - val_loss: 0.9751 - val_accuracy: 0.7970
Epoch 12/20
16/16 [==============================] - 0s 29ms/step - loss: 0.4319 - accuracy: 0.9042 - val_loss: 0.9602 - val_accuracy: 0.8040
Epoch 13/20
16/16 [==============================] - 0s 30ms/step - loss: 0.3787 - accuracy: 0.9163 - val_loss: 0.9512 - val_accuracy: 0.8060
Epoch 14/20
16/16 [==============================] - 0s 30ms/step - loss: 0.3341 - accuracy: 0.9257 - val_loss: 0.9827 - val_accuracy: 0.7990
Epoch 15/20
16/16 [==============================] - 1s 32ms/step - loss: 0.2971 - accuracy: 0.9322 - val_loss: 0.9512 - val_accuracy: 0.8100
Epoch 16/20
16/16 [==============================] - 1s 33ms/step - loss: 0.2663 - accuracy: 0.9391 - val_loss: 0.9596 - val_accuracy: 0.8050
Epoch 17/20
16/16 [==============================] - 1s 32ms/step - loss: 0.2379 - accuracy: 0.9417 - val_loss: 0.9816 - val_accuracy: 0.7990
Epoch 18/20
16/16 [==============================] - 0s 30ms/step - loss: 0.2135 - accuracy: 0.9475 - val_loss: 0.9874 - val_accuracy: 0.7990
Epoch 19/20
16/16 [==============================] - 1s 33ms/step - loss: 0.1969 - accuracy: 0.9504 - val_loss: 0.9853 - val_accuracy: 0.8070
Epoch 20/20
16/16 [==============================] - 1s 32ms/step - loss: 0.1808 - accuracy: 0.9498 - val_loss: 1.0121 - val_accuracy: 0.8020

1.5 Visualizing the monitored metrics

history_dict = history.history
history_dict.keys()
dict_keys(['loss', 'accuracy', 'val_loss', 'val_accuracy'])
## Training loss and validation loss
import matplotlib.pyplot as plt

loss_values = history_dict["loss"]
val_loss_values = history_dict["val_loss"]

epochs = range(1, len(loss_values)+1)

plt.plot(epochs, loss_values, "bo", label="Training loss") ## "bo" means blue dots
plt.plot(epochs, val_loss_values, "b", label="Validation loss") ## "b" means a solid blue line
plt.title("Training and validation loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()

plt.show()

[Figure: training and validation loss versus epochs]

## Training accuracy and validation accuracy
acc_values = history_dict["accuracy"]
val_acc_values = history_dict["val_accuracy"]

plt.plot(epochs, acc_values, "bo", label="Training accuracy") ## "bo" means blue dots
plt.plot(epochs, val_acc_values, "b", label="Validation accuracy") ## "b" means a solid blue line
plt.title("Training and validation accuracy")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.legend()

plt.show()

[Figure: training and validation accuracy versus epochs]

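1.6 Retraining the model from scratch:

The intermediate layers do not use the 64 hidden units given in the book, because with 64 units the model appeared to overfit almost from the start (possibly related to newer hardware, though this was not verified). With 32 hidden units the final results are similar to those reported in the book.

Judging from the plots, the model is at its best after 13 epochs, where the validation loss bottoms out at 0.9512. The model is therefore retrained with epochs=13 and then evaluated on the test set.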
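The same choice of epoch can be read off numerically instead of from the plot. The sketch below copies the validation losses from the 20-epoch training log above into a plain list and picks the epoch with the smallest value; in a live session the list would be `history.history["val_loss"]`.

```python
# Validation loss per epoch, copied from the 20-epoch training log.
val_loss = [
    2.2753, 1.7261, 1.4657, 1.3284, 1.2174,
    1.1462, 1.0916, 1.0458, 1.0127, 0.9958,
    0.9751, 0.9602, 0.9512, 0.9827, 0.9512,
    0.9596, 0.9816, 0.9874, 0.9853, 1.0121,
]

# min() returns the first index with the smallest loss; epochs are 1-based.
best_index = min(range(len(val_loss)), key=val_loss.__getitem__)
best_epoch = best_index + 1

print(best_epoch, val_loss[best_index])
```

At the printed precision, epochs 13 and 15 tie at 0.9512; taking the first minimum gives the shorter training run.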

## Build the model
model = models.Sequential()
model.add(layers.Dense(32, activation="relu", input_shape=(10000,))) 
model.add(layers.Dense(32, activation="relu")) 
model.add(layers.Dense(46, activation="softmax")) 

## Compile the model
model.compile(optimizer="rmsprop",
              loss="categorical_crossentropy",
              metrics=["accuracy"])

## Retrain the model with epochs set to 13, then evaluate on the test set
model.fit(partial_x_train,
          partial_y_train,
          epochs=13,
          batch_size=512,
          validation_data=(x_val, y_val))
results = model.evaluate(x_test, one_hot_test_labels)
Epoch 1/13
 1/16 [>.............................] - ETA: 1s - loss: 3.8319 - accuracy: 0.0098

/opt/homebrew/Caskroom/miniforge/base/envs/tensorflow_gpu/lib/python3.10/site-packages/tensorflow/python/data/ops/structured_function.py:254: UserWarning: Even though the `tf.config.experimental_run_functions_eagerly` option is set, this option does not apply to tf.data functions. To force eager execution of tf.data functions, please use `tf.data.experimental.enable_debug_mode()`.
  warnings.warn(


16/16 [==============================] - 1s 39ms/step - loss: 3.0634 - accuracy: 0.4962 - val_loss: 2.3043 - val_accuracy: 0.6060
Epoch 2/13
16/16 [==============================] - 1s 32ms/step - loss: 1.8882 - accuracy: 0.6600 - val_loss: 1.6528 - val_accuracy: 0.6600
Epoch 3/13
16/16 [==============================] - 1s 33ms/step - loss: 1.4111 - accuracy: 0.7201 - val_loss: 1.3767 - val_accuracy: 0.7130
Epoch 4/13
16/16 [==============================] - 0s 30ms/step - loss: 1.1538 - accuracy: 0.7573 - val_loss: 1.2404 - val_accuracy: 0.7290
Epoch 5/13
16/16 [==============================] - 0s 31ms/step - loss: 0.9805 - accuracy: 0.7831 - val_loss: 1.1452 - val_accuracy: 0.7510
Epoch 6/13
16/16 [==============================] - 0s 30ms/step - loss: 0.8459 - accuracy: 0.8138 - val_loss: 1.0883 - val_accuracy: 0.7630
Epoch 7/13
16/16 [==============================] - 0s 30ms/step - loss: 0.7324 - accuracy: 0.8380 - val_loss: 1.0423 - val_accuracy: 0.7740
Epoch 8/13
16/16 [==============================] - 0s 30ms/step - loss: 0.6342 - accuracy: 0.8602 - val_loss: 1.0112 - val_accuracy: 0.7840
Epoch 9/13
16/16 [==============================] - 1s 33ms/step - loss: 0.5470 - accuracy: 0.8821 - val_loss: 0.9850 - val_accuracy: 0.7980
Epoch 10/13
16/16 [==============================] - 1s 32ms/step - loss: 0.4743 - accuracy: 0.9002 - val_loss: 0.9823 - val_accuracy: 0.8010
Epoch 11/13
16/16 [==============================] - 1s 33ms/step - loss: 0.4105 - accuracy: 0.9131 - val_loss: 0.9727 - val_accuracy: 0.8020
Epoch 12/13
16/16 [==============================] - 1s 34ms/step - loss: 0.3582 - accuracy: 0.9228 - val_loss: 0.9602 - val_accuracy: 0.8100
Epoch 13/13
16/16 [==============================] - 1s 32ms/step - loss: 0.3140 - accuracy: 0.9303 - val_loss: 0.9571 - val_accuracy: 0.8160
71/71 [==============================] - 1s 18ms/step - loss: 1.0241 - accuracy: 0.7805
print(results)
[1.0240509510040283, 0.7804986834526062]

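Compared with a random baseline (about 19% accuracy, obtained by shuffling the test labels), the model's accuracy of about 78% is quite good.

## Accuracy of random prediction
import copy
test_labels_copy = copy.copy(test_labels) ## copy test_labels so that shuffling does not modify the original (copy.copy on a NumPy array returns an independent copy)
np.random.shuffle(test_labels_copy)
hit_array = np.array(test_labels) == np.array(test_labels_copy) ## boolean array: True where the shuffled label equals the true label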
float(np.sum(hit_array)) / len(test_labels)
0.18833481745325023
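The shuffled-label baseline lands near 19% rather than 1/46 (about 2.2%) because the classes are imbalanced: a shuffled label matches the true one with probability equal to the sum of the squared class frequencies. The sketch below checks this on synthetic labels; the four classes and their proportions are made up for illustration and are not the Reuters distribution.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic, imbalanced labels (hypothetical proportions).
labels = rng.choice(4, size=20000, p=[0.4, 0.3, 0.2, 0.1])

# Baseline used in the text: compare labels with a shuffled copy.
shuffled = rng.permutation(labels)
empirical = float(np.mean(labels == shuffled))

# Expected hit rate: sum over classes of (class frequency) squared.
freqs = np.bincount(labels) / len(labels)
expected = float(np.sum(freqs ** 2))

print(empirical, expected)
```

Applying `np.bincount(test_labels)` in the same way to the real test labels would give the expected value for the Reuters baseline.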

1.7 Using the trained model to predict on new data:

predictions = model.predict(x_test) ## predict on the test set
19/71 [=======>......................] - ETA: 0s

/opt/homebrew/Caskroom/miniforge/base/envs/tensorflow_gpu/lib/python3.10/site-packages/tensorflow/python/data/ops/structured_function.py:254: UserWarning: Even though the `tf.config.experimental_run_functions_eagerly` option is set, this option does not apply to tf.data functions. To force eager execution of tf.data functions, please use `tf.data.experimental.enable_debug_mode()`.
  warnings.warn(


71/71 [==============================] - 0s 5ms/step
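## The output for each sample is a vector of length 46
print(predictions[0].shape)

## The elements of each output vector sum to 1.
print(np.sum(predictions[0]))

## For each sample, the largest element (highest probability) marks the predicted class
print(np.argmax(predictions[0]))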
print(np.argmax(predictions[0]))
(46,)
1.0000001
3
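The same `argmax` step can be applied to the whole prediction matrix at once, which also gives the accuracy directly. This is a small sketch on made-up probabilities for three samples and three classes; `toy_predictions` and `toy_true` are assumptions standing in for `predictions` and `test_labels`.

```python
import numpy as np

# Made-up softmax outputs: one row per sample, one column per class.
toy_predictions = np.array([
    [0.1, 0.7, 0.2],
    [0.5, 0.3, 0.2],
    [0.2, 0.2, 0.6],
])
toy_true = np.array([1, 2, 2])  # made-up integer labels

# axis=1 takes the most probable class within each row.
predicted_classes = np.argmax(toy_predictions, axis=1)
accuracy = float(np.mean(predicted_classes == toy_true))

print(predicted_classes, accuracy)
```

On the real data the equivalent call would be `np.argmax(predictions, axis=1)`, compared against `test_labels`.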

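1.8 Summary:

  1. If the labels are encoded as an integer tensor (each label is its class index) rather than one-hot, the loss function must be changed to "sparse_categorical_crossentropy";

  2. With N classes, the final output layer should be a Dense layer of size N (N units);

  3. For single-label, multiclass problems, the last layer should use a softmax activation, which yields a probability distribution over the N output classes (summing to 1.);

  4. The loss function for multiclass problems should be categorical crossentropy (categorical_crossentropy);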
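The relationship between the two losses in points 1 and 4 can be checked numerically: they compute the same quantity and differ only in how the labels are supplied. The sketch below is a NumPy illustration, not the Keras implementation, and the probabilities and labels are made up.

```python
import numpy as np

# Made-up softmax outputs for two samples over three classes.
probs = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.3, 0.6],
])
int_labels = np.array([0, 2])        # integer-encoded labels
one_hot = np.eye(3)[int_labels]      # the same labels, one-hot encoded

# categorical crossentropy: needs one-hot targets.
dense_loss = -np.sum(one_hot * np.log(probs), axis=1)

# sparse categorical crossentropy: index the true class directly.
sparse_loss = -np.log(probs[np.arange(len(int_labels)), int_labels])

print(dense_loss, sparse_loss)
```

In practice this means integer labels can be passed to the model unchanged as long as it is compiled with `loss="sparse_categorical_crossentropy"`, which avoids building the one-hot matrix.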
