MNIST数据集介绍
MNIST数据集包含7w张带标签的手写数字图片。每次有新的分类算法出现时,常常会在改数据集测试效果。
from sklearn.datasets import fetch_openml
# 获取的mnist是一个字典
mnist = fetch_openml('mnist_784', version=1)
print(mnist.keys())
# dict_keys(['data', 'target', 'frame', 'categories', 'feature_names', 'target_names', 'DESCR', 'details', 'url'])
# DESCR是描述信息,data是数据集,target是标签
X, y = mnist["data"], mnist["target"]
print(X.shape)
print(y.shape)
# (70000, 784)
# (70000,)
X.shape表示一共7w张图,每个有784个特征。每个特征是28 x 28 像素中的一个点的数值,在0(白)~ 255(黑)之间。
查看其中一个图:
import matplotlib.pyplot as plt
print(y[5])
print(type(y[5]))
some_digit = X[5]
some_digit_image = some_digit.reshape(28, 28)
# cmp表示颜色映射,即实数值通过什么方法转成RGB图像。常用的还有'viridis'(很多颜色细节)、
# 'gray'(适合灰度图像)
plt.imshow(some_digit_image, cmap="binary")
plt.axis("off")
plt.show()
# 输出: 2
# <class 'str'>
注意,这段代码很可能报错如下:
Traceback (most recent call last):
File "D:\Program Files\Anaconda3\lib\site-packages\pandas\core\indexes\base.py", line 3621, in get_loc
return self._engine.get_loc(casted_key)
File "pandas\_libs\index.pyx", line 136, in pandas._libs.index.IndexEngine.get_loc
File "pandas\_libs\index.pyx", line 163, in pandas._libs.index.IndexEngine.get_loc
File "pandas\_libs\hashtable_class_helper.pxi", line 5198, in pandas._libs.hashtable.PyObjectHashTable.get_item
File "pandas\_libs\hashtable_class_helper.pxi", line 5206, in pandas._libs.hashtable.PyObjectHashTable.get_item
KeyError: 0
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "C:\Users\xxxx\Desktop\study\classification.py", line 24, in <module>
plt.imshow(X[0], cmap="gray")
File "D:\Program Files\Anaconda3\lib\site-packages\pandas\core\frame.py", line 3505, in __getitem__
indexer = self.columns.get_loc(key)
File "D:\Program Files\Anaconda3\lib\site-packages\pandas\core\indexes\base.py", line 3623, in get_loc
raise KeyError(key) from err
KeyError: 0
解决:获取数据集时添加参数 as_frame=False。 这个表示以原格式返回。
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
保存数据集到本地&导入
pickle库可以序列化任何Python对象,所以可以用它保存数据集到本地。
from sklearn.datasets import fetch_openml
import pickle
# 获取数据集
# False表示以原始格式返回,每个特征是一个单独的数组。True表示返回Pandas
# DataFrame对象
# 自0.24.0(2020 年 12 月)以来,as_frame参数为auto(而不是之前的False默认选项)
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
# 保存数据集到本地
with open('mnist_data.pkl', 'wb') as f:
pickle.dump(mnist, f)
# 从本地导入数据集
with open('mnist_data.pkl', 'rb') as f:
mnist = pickle.load(f)
X, y = mnist["data"], mnist["target"]
划分训练集和测试集
MNIST已经划分好了,前6w个是训练集,后1w个是测试集。而且已经打乱过顺序了。
通过第一节我们已经知道,y的所有元素都是字符串。因为很多算法的预测结果都是数字,将标签转为数字也有助于计算error,所以使用astype(np.uint8)将y里所有元素转为8位无符号整数。
# 将数组中所有元素转为8位无符号整数
y = y.astype(np.uint8)
# 划分训练集和测试集
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
Binary Classifier
Binary Classifier就是分两类。比如以数字 2 为例,我们训练一个分类器,将图片分成是2的和不是2的。
这里使用一个Stochastic Gradient Descent (SGD,随机梯度下降)分类器。这个适合高效处理较大的数据集,而且每个训练实例是单独处理的,一次一个,所以可以online learning。
y_train_2 = (y_train == 2)
y_test_2 = (y_test == 2)
from sklearn.linear_model import SGDClassifier
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_2)
some_digit = X[100]
print(sgd_clf.predict([some_digit]))
print(y[100])
some_digit_image = some_digit.reshape(28, 28)
plt.imshow(some_digit_image, cmap="binary")
plt.axis("off")
plt.show()
其中,为了使结果可以复现,设置了random_state=42。
输出为:
[False]
5
因此可以知道,分类正确。