Dataset download
Link (Baidu Netdisk):
https://pan.baidu.com/s/1qpzrSFhmyrdGmbSScN_ZXg?pwd=d1ws
Extraction code: d1ws
Loading the dataset
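from pathlib import Path

import requests

DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"
PATH.mkdir(parents=True, exist_ok=True)

# Original download location. It may no longer be reachable; if so, place
# mnist.pkl.gz from the link above into data/mnist/ by hand.
URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"

# Download only if the file is not already present
if not (PATH / FILENAME).exists():
    response = requests.get(URL + FILENAME)
    response.raise_for_status()  # fail loudly instead of saving an error page
    (PATH / FILENAME).write_bytes(response.content)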
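If the download URL no longer serves the archive, what ends up on disk under the `.gz` name can be an HTML page rather than gzip data, and the problem only surfaces later inside `gzip.open` or `pickle.load`. A minimal sketch of an early sanity check, using a hypothetical helper named `looks_like_gzip`, tests for the two magic bytes that start every gzip stream (`0x1f 0x8b`, per RFC 1952):

```python
from pathlib import Path

GZIP_MAGIC = b"\x1f\x8b"  # first two bytes of every gzip stream (RFC 1952)


def looks_like_gzip(path: Path) -> bool:
    """Return True if the file at `path` exists and starts with the gzip magic bytes.

    Hypothetical helper, not part of the original code.
    """
    if not path.is_file():
        return False
    with path.open("rb") as f:
        return f.read(2) == GZIP_MAGIC


# Possible use with the paths defined above:
# if not looks_like_gzip(PATH / FILENAME):
#     (PATH / FILENAME).unlink()  # delete the bad file and fetch it manually
```

The check reads only two bytes, so it is cheap enough to run before every load.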
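import gzip
import pickle

# The file is a gzip-compressed pickle holding the (train, valid, test) splits;
# the test split is discarded with `_`.
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")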
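The unpacking line above relies on the file's layout: a gzip-compressed pickle of a 3-tuple of `(images, labels)` pairs for the training, validation and test splits. The `encoding="latin-1"` argument lets Python 3 decode the Python 2 byte strings stored in the original pickle. The sketch below reproduces that nesting with a tiny synthetic file so the structure can be inspected without the real download; both helper names are hypothetical, and the random data only mimics the shapes (784 features per row, labels 0 to 9).

```python
import gzip
import pickle
from pathlib import Path

import numpy as np


def write_fake_mnist_pkl(path: Path, n_train=5, n_valid=3, n_test=3) -> None:
    """Write a tiny file with the same nesting as mnist.pkl.gz (hypothetical helper)."""
    rng = np.random.default_rng(0)

    def split(n):
        x = rng.random((n, 784), dtype=np.float32)  # n flattened 28 x 28 images
        y = rng.integers(0, 10, size=n)            # n digit labels in 0..9
        return x, y

    data = (split(n_train), split(n_valid), split(n_test))
    with gzip.open(path, "wb") as f:
        pickle.dump(data, f)


def load_mnist_pkl(path: Path):
    """Unpack all three (x, y) splits, keeping the test split this time."""
    with gzip.open(path, "rb") as f:
        (x_train, y_train), (x_valid, y_valid), (x_test, y_test) = pickle.load(
            f, encoding="latin-1"
        )
    return x_train, y_train, x_valid, y_valid, x_test, y_test
```

Replacing the `_` in the original unpacking with `(x_test, y_test)`, as `load_mnist_pkl` does, is all it takes to keep the third split.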
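import numpy as np

side = int(np.sqrt(x_valid.shape[1]))  # 784 features per row -> 28 x 28 pixels
print(f"Training set - X shape: {x_train.shape}, Y shape: {y_train.shape}\n"
      f"Validation set - X shape: {x_valid.shape}, Y shape: {y_valid.shape}\n"
      f"Training samples: {y_train.shape[0]}\n"
      f"Validation samples: {y_valid.shape[0]}\n"
      f"Image size: {(side, side)}")
Output:
Training set - X shape: (50000, 784), Y shape: (50000,)
Validation set - X shape: (10000, 784), Y shape: (10000,)
Training samples: 50000
Validation samples: 10000
Image size: (28, 28)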
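Each row of `x_train` is one 28 × 28 image flattened row by row, which is why the side length comes from the square root of 784 and why `reshape(28, 28)` recovers a single image in the plotting code below. A minimal sketch of the same idea applied to a whole batch, with a hypothetical helper `unflatten_images` that uses `math.isqrt` to avoid floating-point rounding in the square root:

```python
import math

import numpy as np


def unflatten_images(x: np.ndarray) -> np.ndarray:
    """Reshape an (N, side*side) array of flattened square images to (N, side, side)."""
    n, d = x.shape
    side = math.isqrt(d)  # exact integer square root, e.g. isqrt(784) == 28
    if side * side != d:
        raise ValueError(f"feature length {d} is not a perfect square")
    # Row-major reshape: pixel (r, c) of image i comes from column r * side + c
    return x.reshape(n, side, side)


batch = np.arange(2 * 784, dtype=np.float32).reshape(2, 784)
images = unflatten_images(batch)
print(images.shape)  # (2, 28, 28)
```

Applied to the real data, `unflatten_images(x_train)` would give an array of shape `(50000, 28, 28)` without copying the pixels.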
Data type
print(f"Data type of the dataset: {type(x_train)}")
Output:
Data type of the dataset: <class 'numpy.ndarray'>
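`type()` only reports the container. The element dtype, the pixel value range and the number of samples per digit are usually what matters before training. A minimal sketch, using a hypothetical helper `describe_split` and assuming labels are non-negative integers (as `np.bincount` requires), collects those in one dictionary:

```python
import numpy as np


def describe_split(x: np.ndarray, y: np.ndarray, num_classes: int = 10) -> dict:
    """Summarize element types, pixel range and per-class label counts of one split."""
    return {
        "x_dtype": str(x.dtype),
        "y_dtype": str(y.dtype),
        "x_min": float(x.min()),
        "x_max": float(x.max()),
        # np.bincount counts how often each non-negative integer label occurs
        "class_counts": np.bincount(y, minlength=num_classes).tolist(),
    }
```

Calling `describe_split(x_train, y_train)` shows at a glance whether the pixels are already scaled (for example to `[0, 1]`) and whether the ten classes are roughly balanced.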
Training set: image display in color (Matplotlib's default colormap, i.e. pseudocolor, not RGB)
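import matplotlib.pyplot as plt

# Show the first 16 training images in a 4 x 4 grid. A 2-D array is drawn
# through the default colormap, so the colors are pseudocolor, not RGB.
fig1 = plt.figure(figsize=(4, 4))
for i in range(16):
    ax = fig1.add_subplot(4, 4, i + 1)
    ax.imshow(x_train[i].reshape(28, 28))  # 784 pixels -> 28 x 28 image
    ax.set_xticks([])  # hide axis ticks
    ax.set_yticks([])
plt.tight_layout()
plt.show()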
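The colors in the first grid do not come from the data: each MNIST image is a single-channel 2-D array, and `imshow` rescales it to its min/max range and then looks every value up in a colormap (Matplotlib's default `image.cmap` is `viridis`). A minimal sketch of that lookup on a synthetic gradient image, assuming Matplotlib 3.5 or newer for the `matplotlib.colormaps` registry:

```python
import numpy as np
from matplotlib import colormaps

# A synthetic 28 x 28 "image" whose values run from 0.0 to 1.0
img = np.linspace(0.0, 1.0, 28 * 28, dtype=np.float32).reshape(28, 28)

# Looking up each scalar in the viridis colormap yields one RGBA color per pixel;
# imshow does the same after normalizing the data to its own min/max.
rgba = colormaps["viridis"](img)
print(img.shape, rgba.shape)  # (28, 28) (28, 28, 4)
```

Passing `cmap="gray"`, as the next grid does, swaps the lookup table for a black-to-white ramp, which matches how the digits were originally recorded.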
Validation set: image display in grayscale
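# Same 4 x 4 grid for the first 16 validation images, drawn with a gray colormap
fig2 = plt.figure(figsize=(4, 4))
for i in range(16):
    ax = fig2.add_subplot(4, 4, i + 1)
    ax.imshow(x_valid[i].reshape(28, 28), cmap="gray")
    ax.set_xticks([])
    ax.set_yticks([])
plt.tight_layout()
plt.show()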
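The two grids differ only in the data and the colormap, so they can share one function. A minimal sketch with a hypothetical helper `show_digit_grid`, which also puts each image's label in its title so the pictures can be checked against `y_train` / `y_valid`:

```python
import matplotlib.pyplot as plt
import numpy as np


def show_digit_grid(x, y, n=16, cols=4, cmap="gray"):
    """Plot the first n flattened 28 x 28 images in a grid, titled with their labels."""
    rows = (n + cols - 1) // cols  # ceiling division
    # squeeze=False keeps `axes` 2-D even for a single row or column
    fig, axes = plt.subplots(rows, cols, figsize=(cols, rows), squeeze=False)
    for i, ax in enumerate(axes.flat):
        if i < n:
            ax.imshow(np.asarray(x[i]).reshape(28, 28), cmap=cmap)
            ax.set_title(str(int(y[i])), fontsize=8)  # digit label as the title
            ax.set_xticks([])
            ax.set_yticks([])
        else:
            ax.axis("off")  # hide cells left over when n is not a multiple of cols
    fig.tight_layout()
    return fig
```

With the arrays loaded above, `show_digit_grid(x_train, y_train, cmap="viridis")` reproduces the first grid and `show_digit_grid(x_valid, y_valid)` the second; call `plt.show()` afterwards outside a notebook.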