1 前言
2 问题背景——以图片分类为例
import torch
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
def load_data_fashion_mnist(input_batch_size):
    """Download Fashion-MNIST and wrap both splits in DataLoaders.

    Every image is converted with ``transforms.ToTensor()``. The training
    loader reshuffles each epoch; the test loader keeps dataset order.

    Args:
        input_batch_size (int): Mini-batch size used by both loaders.

    Returns:
        Tuple[DataLoader, DataLoader]: ``(train_loader, test_loader)``.
    """
    to_tensor = transforms.ToTensor()

    def _make_loader(is_train):
        # Both splits share root and transform; only the split flag and
        # whether to shuffle differ.
        split = FashionMNIST(
            root='./data', train=is_train, transform=to_tensor, download=True
        )
        return DataLoader(split, batch_size=input_batch_size, shuffle=is_train)

    return _make_loader(True), _make_loader(False)
# Global setup: one mini-batch size shared by the training and test loaders.
batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(input_batch_size=batch_size)
# Softmax-regression parameters: 784 input features (a flattened 28x28
# image) mapped to 10 class scores. Weights start as small Gaussian noise,
# the bias at zero; both are tracked by autograd.
num_inputs, num_outputs = 784, 10
W = torch.normal(
    0, 0.01, size=(num_inputs, num_outputs), requires_grad=True
)
b = torch.zeros(num_outputs, requires_grad=True)
def softmax(x):
    """Row-wise softmax of a 2-D tensor.

    The per-row maximum is subtracted before exponentiating so large scores
    cannot overflow; softmax is invariant to a constant shift per row, so the
    result is unchanged.

    Args:
        x (torch.Tensor): Scores, normalised along ``dim=1``.

    Returns:
        torch.Tensor: Same shape as ``x``; every row sums to 1.
    """
    row_max, _ = x.max(dim=1, keepdim=True)
    exps = (x - row_max).exp()
    return exps / exps.sum(dim=1, keepdim=True)
def net(x):
    """Softmax-regression forward pass.

    Flattens the input to rows of ``W.shape[0]`` features, applies the affine
    map ``x @ W + b`` with the module-level parameters ``W`` and ``b``, and
    normalises the scores with ``softmax``.

    Args:
        x (torch.Tensor): Input batch whose element count is a multiple of
            ``W.shape[0]``.

    Returns:
        torch.Tensor: Class probabilities, one row per flattened example.
    """
    flat = x.reshape(-1, W.shape[0])
    logits = flat @ W + b
    return softmax(logits)
def cross_entropy(y_hat, input_y):
    """Per-example cross-entropy loss (unreduced).

    For every row ``i`` the probability given to the true class
    ``input_y[i]`` is selected and its negative log returned. ``1e-9`` is
    added before the log so a zero probability gives a large finite loss
    rather than ``inf``.

    Args:
        y_hat (torch.Tensor): Predicted probabilities, one row per example.
        input_y (torch.Tensor): Integer class labels, one per row.

    Returns:
        torch.Tensor: 1-D tensor of losses, one per example.
    """
    row_idx = range(len(y_hat))
    true_class_prob = y_hat[row_idx, input_y]
    return -(true_class_prob + 1e-9).log()
def accuracy(y_hat, input_y):
    """Count the correctly classified examples in a batch.

    Despite its name this returns the *number* of correct predictions, not a
    ratio; divide by the number of examples to obtain an accuracy.

    Args:
        y_hat (torch.Tensor): Either per-class scores/probabilities (2-D, one
            row per example) or already-decided class indices (1-D).
        input_y (torch.Tensor): True class labels.

    Returns:
        float: Number of examples whose predicted class equals the label.
    """
    if y_hat.ndimension() > 1:
        # Scores per class: the prediction is the highest-scoring column.
        y_hat = y_hat.argmax(dim=1)
    # Cast to the label dtype so the equality test compares like with like.
    cmp = y_hat.type(input_y.dtype) == input_y
    return float(cmp.sum())
class Accumulator:
    """Keep running float sums for a fixed number of quantities.

    Each call to ``add`` adds its arguments element-wise to the stored sums.
    """

    def __init__(self, n):
        """Create ``n`` running sums, all starting at ``0.0``.

        Args:
            n (int): Number of quantities to track.
        """
        self.data = [0.0] * n

    def add(self, *args):
        """Add one value to each running sum.

        Args:
            *args: One value per tracked quantity, in order; each must be
                convertible with ``float()``.

        Raises:
            ValueError: If the number of values differs from the number of
                tracked sums.
        """
        if len(args) != len(self.data):
            # zip() would silently truncate self.data and discard sums.
            raise ValueError(
                f"expected {len(self.data)} values, got {len(args)}"
            )
        self.data = [total + float(value) for total, value in zip(self.data, args)]

    def reset(self):
        """Set every running sum back to ``0.0``."""
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        """Return the running sum at position ``idx``.

        Args:
            idx (int): Index of the tracked quantity.

        Returns:
            float: The accumulated value.
        """
        return self.data[idx]
def evaluate_accuracy(input_net, data_iter):
    """Compute the classification accuracy of a model over a dataset.

    Args:
        input_net (callable): Maps an input batch to per-class scores.
        data_iter (DataLoader): Yields ``(inputs, labels)`` batches.

    Returns:
        float: Fraction of correctly classified examples, or ``0.0`` when
        ``data_iter`` yields no examples.
    """
    metric = Accumulator(2)  # [correct predictions, examples seen]
    # Evaluation never back-propagates; without no_grad every forward pass
    # would record an autograd graph whenever the parameters require grad.
    with torch.no_grad():
        for input_X, input_y in data_iter:
            metric.add(accuracy(input_net(input_X), input_y), input_y.numel())
    return metric[0] / metric[1] if metric[1] > 0 else 0.0
def train_epoch_ch3(input_net, input_train_iter, loss, input_updater):
    """Train the model for one epoch.

    For every mini-batch: forward pass, loss summed over the batch,
    back-propagation, then ``input_updater()`` to apply the gradients.

    Args:
        input_net (callable): Maps an input batch to per-class probabilities.
        input_train_iter (DataLoader): Yields ``(inputs, labels)`` batches.
        loss (callable): Per-example loss ``loss(y_hat, labels)``.
        input_updater (callable): Zero-argument callable that updates the
            parameters from their ``.grad`` and clears those gradients, such
            as ``updater`` below.

    Returns:
        Tuple[float, float]: Average per-example loss and training accuracy;
        ``(0.0, 0.0)`` if the iterator yields no examples.
    """
    metric = Accumulator(3)  # [summed loss, correct predictions, examples]
    for input_X, input_y in input_train_iter:
        y_hat = input_net(input_X)
        loss_value = loss(y_hat, input_y).sum()
        # Populate .grad on the parameters, then let the updater step and
        # clear them before the next batch.
        loss_value.backward()
        input_updater()
        metric.add(float(loss_value), accuracy(y_hat, input_y), input_y.numel())
    if metric[2] == 0:
        return 0.0, 0.0
    return metric[0] / metric[2], metric[1] / metric[2]


# Learning rate. The loss above is summed (not averaged) over the batch, so
# the step size effectively grows with the batch size.
lr = 1e-3


def updater():
    """Apply one SGD step to the module-level parameters ``W`` and ``b``.

    Each parameter is moved in place by ``-lr * grad`` and its gradient is
    then zeroed, because autograd accumulates into ``.grad`` across
    ``backward()`` calls. Parameters without a gradient yet are skipped.
    """
    # In-place updates mutate the global tensors, so no `global` rebinding
    # is needed.
    with torch.no_grad():
        for param in (W, b):
            if param.grad is None:
                continue
            param -= lr * param.grad
            param.grad.zero_()
def train_ch3(input_net, input_train_iter, input_test_iter, loss,
              input_num_epochs, input_updater):
    """Train the model for several epochs, printing progress after each.

    After every epoch the average training loss, training accuracy and test
    accuracy are printed on one line.

    Args:
        input_net (callable): Maps an input batch to per-class probabilities.
        input_train_iter (DataLoader): Training data iterator.
        input_test_iter (DataLoader): Test data iterator.
        loss (callable): Per-example loss function.
        input_num_epochs (int): Number of training epochs.
        input_updater (callable): Parameter update function.
    """
    for epoch in range(input_num_epochs):
        train_loss, train_acc = train_epoch_ch3(
            input_net, input_train_iter, loss, input_updater
        )
        test_acc = evaluate_accuracy(input_net, input_test_iter)
        print(
            f'epoch {epoch + 1}, loss {train_loss:.3f}, '
            f'train acc {train_acc:.3f}, test acc {test_acc:.3f}'
        )
def show_images(images, num_rows, num_cols, titles=None, scale=1.5):
    """Plot images in grayscale on a ``num_rows`` x ``num_cols`` grid.

    Args:
        images (List[torch.Tensor]): Images to draw (2-D tensors or arrays).
        num_rows (int): Number of grid rows.
        num_cols (int): Number of grid columns.
        titles (List[str], optional): Per-image titles; images beyond the
            end of this list get no title. Defaults to None.
        scale (float, optional): Size in inches of each grid cell.
            Defaults to 1.5.

    Returns:
        numpy.ndarray: Flattened array of the grid's Axes.
    """
    figure_size = (num_cols * scale, num_rows * scale)
    # squeeze=False always returns a 2-D array of Axes (even for a 1x1 grid),
    # so .flatten() is safe.
    _, axes = plt.subplots(num_rows, num_cols, figsize=figure_size,
                           squeeze=False)
    axes = axes.flatten()
    for i, (img, ax) in enumerate(zip(images, axes)):
        if isinstance(img, torch.Tensor):
            # detach()/cpu() so grad-tracking or GPU tensors convert too.
            img = img.detach().cpu().numpy()
        ax.imshow(img, cmap='gray')
        if titles and i < len(titles):
            ax.set_title(titles[i])
    return axes
# Module-level placeholders for the batch that predict_ch3 assigns through
# its `global X, y` statement.
X = y = None
def predict_ch3(input_net, input_test_iter, n=6):
    """Show the first ``n`` test images titled with true and predicted labels.

    Only the first mini-batch of ``input_test_iter`` is used. Each title shows
    the true class index above the predicted class index. The displayed batch
    is also stored in the module-level ``X`` and ``y``.

    Args:
        input_net (callable): Maps an input batch to per-class scores.
        input_test_iter (DataLoader): Test data iterator.
        n (int, optional): Number of images to display, capped at the batch
            size. Defaults to 6.
    """
    global X, y
    batch = next(iter(input_test_iter), None)
    if batch is None:
        return
    X, y = batch
    n = min(n, len(y))
    if n <= 0:
        return
    # One forward pass for the whole batch; no gradients are needed here.
    with torch.no_grad():
        predicted = input_net(X).argmax(dim=1)
    trues = [str(label.item()) for label in y[:n]]
    predictions = [str(label.item()) for label in predicted[:n]]
    titles = [true + '\n' + pred for true, pred in zip(trues, predictions)]
    # Drop any channel dimension, keeping the trailing (height, width).
    show_images(X[0:n].reshape((n, *X.shape[-2:])), 1, n, titles=titles)
# Train for a fixed number of epochs, then visualise some test predictions.
num_epochs = 50
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)
predict_ch3(net, test_iter)
# Needed to actually display the figure when run as a plain script with a
# non-interactive Matplotlib session.
plt.show()
epoch 1, loss 0.9313, train acc 0.7327, test acc 0.7920
epoch 2, loss 0.6130, train acc 0.8031, test acc 0.8216
epoch 3, loss 0.5490, train acc 0.8192, test acc 0.8270
epoch 4, loss 0.5478, train acc 0.8197, test acc 0.8272
epoch 5, loss 0.5111, train acc 0.8322, test acc 0.8342
epoch 6, loss 0.5039, train acc 0.8317, test acc 0.8363
epoch 7, loss 0.5025, train acc 0.8354, test acc 0.8331
epoch 8, loss 0.4808, train acc 0.8394, test acc 0.8350
epoch 9, loss 0.4834, train acc 0.8387, test acc 0.8353
epoch 10, loss 0.4827, train acc 0.8396, test acc 0.8139
epoch 11, loss 0.4742, train acc 0.8426, test acc 0.8231
epoch 12, loss 0.4672, train acc 0.8442, test acc 0.8383
epoch 13, loss 0.4695, train acc 0.8427, test acc 0.8405
epoch 14, loss 0.4808, train acc 0.8400, test acc 0.8451
epoch 15, loss 0.4707, train acc 0.8431, test acc 0.8358
epoch 16, loss 0.4502, train acc 0.8472, test acc 0.8409
epoch 17, loss 0.4768, train acc 0.8425, test acc 0.8412
epoch 18, loss 0.4563, train acc 0.8477, test acc 0.8443
epoch 19, loss 0.4609, train acc 0.8459, test acc 0.8376
epoch 20, loss 0.4637, train acc 0.8459, test acc 0.8427
epoch 21, loss 0.4531, train acc 0.8479, test acc 0.8283
epoch 22, loss 0.4575, train acc 0.8472, test acc 0.8395
epoch 23, loss 0.4539, train acc 0.8466, test acc 0.8340
epoch 24, loss 0.4527, train acc 0.8477, test acc 0.8442
epoch 25, loss 0.4463, train acc 0.8503, test acc 0.8399
epoch 26, loss 0.4394, train acc 0.8517, test acc 0.8372
epoch 27, loss 0.4524, train acc 0.8491, test acc 0.8344
epoch 28, loss 0.4317, train acc 0.8532, test acc 0.8425
epoch 29, loss 0.4610, train acc 0.8473, test acc 0.8398
epoch 30, loss 0.4411, train acc 0.8509, test acc 0.8435
epoch 31, loss 0.4308, train acc 0.8536, test acc 0.8354
epoch 32, loss 0.4416, train acc 0.8517, test acc 0.8344
epoch 33, loss 0.4391, train acc 0.8519, test acc 0.8381
epoch 34, loss 0.4342, train acc 0.8537, test acc 0.8437
epoch 35, loss 0.4317, train acc 0.8540, test acc 0.8366
epoch 36, loss 0.4319, train acc 0.8530, test acc 0.8441
epoch 37, loss 0.4373, train acc 0.8524, test acc 0.8285
epoch 38, loss 0.4438, train acc 0.8502, test acc 0.8389
epoch 39, loss 0.4339, train acc 0.8527, test acc 0.8409
epoch 40, loss 0.4311, train acc 0.8539, test acc 0.8440
epoch 41, loss 0.4358, train acc 0.8530, test acc 0.8406
epoch 42, loss 0.4334, train acc 0.8540, test acc 0.8440
epoch 43, loss 0.4374, train acc 0.8527, test acc 0.8428
epoch 44, loss 0.4329, train acc 0.8535, test acc 0.8364
epoch 45, loss 0.4348, train acc 0.8527, test acc 0.8379
epoch 46, loss 0.4293, train acc 0.8541, test acc 0.8390
epoch 47, loss 0.4265, train acc 0.8554, test acc 0.8439
epoch 48, loss 0.4253, train acc 0.8554, test acc 0.8457
epoch 49, loss 0.4276, train acc 0.8553, test acc 0.8438
epoch 50, loss 0.4301, train acc 0.8542, test acc 0.8427
3 结果讨论
epoch 1, loss 0.9046, train acc 0.7268, test acc 0.7712
epoch 2, loss 0.6373, train acc 0.7970, test acc 0.7953
epoch 3, loss 0.5789, train acc 0.8131, test acc 0.8047
epoch 4, loss 0.5475, train acc 0.8220, test acc 0.8140
epoch 5, loss 0.5267, train acc 0.8266, test acc 0.8170
epoch 6, loss 0.5116, train acc 0.8316, test acc 0.8211
epoch 7, loss 0.5003, train acc 0.8340, test acc 0.8164
epoch 8, loss 0.4913, train acc 0.8355, test acc 0.8250
epoch 9, loss 0.4835, train acc 0.8389, test acc 0.8218
epoch 10, loss 0.4770, train acc 0.8403, test acc 0.8262
epoch 11, loss 0.4715, train acc 0.8419, test acc 0.8280
epoch 12, loss 0.4668, train acc 0.8432, test acc 0.8302
epoch 13, loss 0.4625, train acc 0.8451, test acc 0.8315
epoch 14, loss 0.4583, train acc 0.8463, test acc 0.8321
epoch 15, loss 0.4549, train acc 0.8465, test acc 0.8333
epoch 16, loss 0.4515, train acc 0.8477, test acc 0.8338
epoch 17, loss 0.4494, train acc 0.8481, test acc 0.8312
epoch 18, loss 0.4462, train acc 0.8492, test acc 0.8342
epoch 19, loss 0.4436, train acc 0.8501, test acc 0.8352
epoch 20, loss 0.4415, train acc 0.8513, test acc 0.8346
epoch 21, loss 0.4397, train acc 0.8511, test acc 0.8349
epoch 22, loss 0.4375, train acc 0.8515, test acc 0.8355
epoch 23, loss 0.4359, train acc 0.8522, test acc 0.8374
epoch 24, loss 0.4341, train acc 0.8525, test acc 0.8351
epoch 25, loss 0.4326, train acc 0.8528, test acc 0.8369
epoch 26, loss 0.4309, train acc 0.8539, test acc 0.8367
epoch 27, loss 0.4296, train acc 0.8539, test acc 0.8376
epoch 28, loss 0.4282, train acc 0.8549, test acc 0.8339
epoch 29, loss 0.4269, train acc 0.8542, test acc 0.8389
epoch 30, loss 0.4258, train acc 0.8554, test acc 0.8384
epoch 31, loss 0.4244, train acc 0.8552, test acc 0.8401
epoch 32, loss 0.4236, train acc 0.8559, test acc 0.8401
epoch 33, loss 0.4221, train acc 0.8559, test acc 0.8393
epoch 34, loss 0.4211, train acc 0.8571, test acc 0.8400
epoch 35, loss 0.4206, train acc 0.8560, test acc 0.8395
epoch 36, loss 0.4190, train acc 0.8576, test acc 0.8410
epoch 37, loss 0.4182, train acc 0.8577, test acc 0.8405
epoch 38, loss 0.4175, train acc 0.8584, test acc 0.8397
epoch 39, loss 0.4167, train acc 0.8579, test acc 0.8401
epoch 40, loss 0.4159, train acc 0.8583, test acc 0.8407
epoch 41, loss 0.4150, train acc 0.8583, test acc 0.8384
epoch 42, loss 0.4143, train acc 0.8590, test acc 0.8415
epoch 43, loss 0.4138, train acc 0.8590, test acc 0.8416
epoch 44, loss 0.4135, train acc 0.8590, test acc 0.8410
epoch 45, loss 0.4120, train acc 0.8600, test acc 0.8410
epoch 46, loss 0.4116, train acc 0.8600, test acc 0.8412
epoch 47, loss 0.4107, train acc 0.8594, test acc 0.8430
epoch 48, loss 0.4102, train acc 0.8600, test acc 0.8412
epoch 49, loss 0.4099, train acc 0.8596, test acc 0.8421
epoch 50, loss 0.4089, train acc 0.8610, test acc 0.8398
由损失（Loss）曲线可以看出，该曲线实际上还能继续收敛到更小的损失值（日志中最后几个 epoch 的 loss 仍在缓慢下降）；并且与之前使用较大学习率时的损失曲线相比，这一条更加平滑。由于这只是一个示例，epoch 数只设置为 50。此外，训练准确率略高于测试准确率，可能存在一定的过拟合，但本例暂不展开这方面的讨论，后续会再补充。分类的可视化结果：