因为图片识别很多代码都包装在d2l库里了,直接调用就行了
完整代码:
import torch
from torch import nn
from d2l import torch as d2l
"获取训练集&获取检测集"
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10)) # nn.Flatten()将28*28展平成784
"初始化w,b后者不操作默认初始化"
def init_weights(m):
if type(m) == nn.Linear:
nn.init.normal_(m.weight, std = 0.01)
net.apply(init_weights) # 给到所有模型
loss = nn.CrossEntropyLoss()
trainer = torch.optim.SGD(net.parameters(), lr=0.1) # net.parameters()将net中数据整合w,b给SGD
if __name__ == '__main__':
num_epochs = 10
cnt = 1
for i in range(num_epochs):
X, Y = d2l.train_epoch_ch3(net, train_iter, loss, trainer)
print("训练次数: " + str(cnt))
cnt += 1
print("训练损失: {:.4f}".format(X))
print("训练精度: {:.4f}".format(Y))
print(".................................")
画图功能不兼容pycharm,所以还是朴素的用输出函数吧