Python PyTorch 获取、保存与显示 MNIST 数据
- 1 PyTorch 获取 MNIST 数据
- 2 PyTorch 保存 MNIST 数据
- 3 PyTorch 显示 MNIST 数据
1 PyTorch 获取 MNIST 数据
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
def mnist_get(root='./data'):
    """Download MNIST (if needed) and return it as NumPy arrays.

    Args:
        root: Directory torchvision uses to store/cache the MNIST files.
            Defaults to ``'./data'``.

    Returns:
        A tuple ``(train_image, train_label, test_image, test_label)``.
        The image arrays come from each dataset's ``.data`` attribute and the
        label arrays from ``.targets``; they are raw, unnormalized values.
    """
    print(torch.__version__)
    # `.data` and `.targets` are read directly, bypassing __getitem__, so a
    # `transform` passed to the dataset would never be applied to these
    # arrays. No transform is configured for that reason.
    # train=True selects the training split; train=False the test split.
    train_data = datasets.MNIST(root=root, train=True, download=True)
    test_data = datasets.MNIST(root=root, train=False, download=True)
    train_image = train_data.data.numpy()
    train_label = train_data.targets.numpy()
    test_image = test_data.data.numpy()
    test_label = test_data.targets.numpy()
    return train_image, train_label, test_image, test_label
2 PyTorch 保存 MNIST 数据
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
def mnist_save(mnist_path, root='./data'):
    """Download MNIST (if needed) and save all four arrays to one ``.npz`` file.

    The archive stores the keys ``train_data``, ``train_label``,
    ``test_data`` and ``test_label``. The image arrays come from each
    dataset's ``.data`` attribute and the labels from ``.targets`` (raw,
    unnormalized values).

    Args:
        mnist_path: Destination passed to ``np.savez``. NumPy appends the
            ``.npz`` extension when a string path does not already end with it.
        root: Directory torchvision uses to store/cache the MNIST files.
            Defaults to ``'./data'``.
    """
    print(torch.__version__)
    # `.data` and `.targets` bypass __getitem__, so a dataset `transform`
    # would never reach the saved arrays; none is configured.
    # train=True selects the training split; train=False the test split.
    train_data = datasets.MNIST(root=root, train=True, download=True)
    test_data = datasets.MNIST(root=root, train=False, download=True)
    train_image = train_data.data.numpy()
    train_label = train_data.targets.numpy()
    test_image = test_data.data.numpy()
    test_label = test_data.targets.numpy()
    np.savez(mnist_path, train_data=train_image, train_label=train_label,
             test_data=test_image, test_label=test_label)
# Where the combined archive is written; change this path for your machine.
mnist_path = 'C:\\Users\\Hyacinth\\Desktop\\mnist.npz'
mnist_save(mnist_path=mnist_path)
3 PyTorch 显示 MNIST 数据
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
def mnist_show(mnist_path, count=100):
    """Plot the first training images of an MNIST ``.npz`` archive in a grid.

    The archive must contain ``train_data`` and ``train_label`` arrays (the
    keys written by ``mnist_save``). For every displayed image, its index and
    label are printed to stdout and the label is used as the subplot title.

    Args:
        mnist_path: Path of the ``.npz`` archive to read.
        count: Maximum number of images to display. Defaults to 100, which is
            laid out as a 10 x 10 grid. Fewer are shown if the archive holds
            fewer training images.

    Raises:
        ValueError: If ``count`` is not positive.
    """
    if count <= 0:
        raise ValueError('count must be positive, got %r' % (count,))
    # np.load on an .npz returns an NpzFile that keeps the file open; the
    # context manager closes it once the needed arrays have been read.
    with np.load(mnist_path) as data:
        image = data['train_data'][:count]
        label = data['train_label'].reshape(-1)
    n = min(count, len(image))
    if n == 0:
        return
    # Near-square grid; count=100 gives the original 10 x 10 layout.
    cols = int(np.ceil(np.sqrt(n)))
    rows = -(-n // cols)  # ceiling division
    plt.figure(figsize=(10, 10))
    for i in range(n):
        # Index and label are integers, so format them with %d rather than %f.
        print('%d, %d' % (i, label[i]))
        plt.subplot(rows, cols, i + 1)
        # The digits are single-channel; without an explicit cmap matplotlib
        # would colour them with its default colour map.
        plt.imshow(image[i], cmap='gray')
        plt.title(str(label[i]), fontsize=8)
        plt.axis('off')
    plt.show()
# Archive produced by mnist_save above; change this path for your machine.
mnist_path = 'C:\\Users\\Hyacinth\\Desktop\\mnist.npz'
mnist_show(mnist_path=mnist_path)