Table of Contents
- 1. Problem Description
- 2. Our Model
- 3. Data Preparation
- 3.1 Data Convert
- 3.2 Padding Data
- 3.3 Label Convert
- 4. Bidirectional RNN
- 5. PyTorch Implementation
- 5.1 Imports
- 5.2 The make_tensors Function
- 5.3 The Name-to-Character-List Function
- 5.4 The Name Dataset Class
- 5.5 The RNN (GRU) Classifier Class
- 5.6 The Training Function
- 5.7 The Test Function
- 5.8 The Main Block
- 5.9 Full Code
- 5.10 Run Output
1. Problem Description
Problem: given a person's name, predict which country the name comes from. The input is a variable-length string of characters and the output is one of several country labels, so this is a sequence classification task.
2. Our Model
Each character of a name is mapped to an embedding vector, the embedded sequence is fed through a (bidirectional) GRU, and the final hidden state is passed to a linear layer that produces one score per country.
3. Data Preparation
3.1 Data Convert
Since the input consists entirely of English characters, we can use the ASCII table to convert each character into an integer, turning every name into a list of numbers.
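A minimal sketch of the conversion (the name "Alice" is just an illustrative example):
name = "Alice"                   # illustrative example name
codes = [ord(c) for c in name]   # ASCII code of each character
print(codes)                     # [65, 108, 105, 99, 101]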
3.2 Padding Data
The names have different lengths, so we pad the shorter sequences with zeros so that every sequence in a batch has the same length and the batch can be stacked into a single tensor.
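A minimal sketch of the padding step, using two made-up names that have already been converted to ASCII codes:
import torch

seqs = [[65, 108, 105, 99, 101], [66, 111, 98]]           # "Alice", "Bob" as ASCII codes
max_len = max(len(s) for s in seqs)
padded = torch.zeros(len(seqs), max_len, dtype=torch.long)
for i, s in enumerate(seqs):
    padded[i, :len(s)] = torch.tensor(s)                  # copy the codes, leave zeros as padding
print(padded)
# tensor([[ 65, 108, 105,  99, 101],
#         [ 66, 111,  98,   0,   0]])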
3.3 Label Convert
The country labels are also strings, so we collect the distinct country names, sort them, and assign each one an integer index; that index is the classification target.
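A minimal sketch of this mapping (the country names here are just an illustrative subset):
countries = ['Czech', 'Arabic', 'Chinese', 'Arabic']                 # illustrative labels
country_list = sorted(set(countries))                                # ['Arabic', 'Chinese', 'Czech']
country_dict = {name: idx for idx, name in enumerate(country_list)}
print(country_dict)                                                  # {'Arabic': 0, 'Chinese': 1, 'Czech': 2}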
4. Bidirectional RNN
A bidirectional RNN processes the sequence twice, once left to right and once right to left, so the model sees context from both ends of the name. With two directions the GRU returns twice as many final hidden states; in the classifier below, the last forward and backward hidden states are concatenated before the linear layer, which is why that layer's input size is hidden_size * 2.
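A quick shape check of a bidirectional GRU; the sizes mirror the hyperparameters used later but are otherwise arbitrary:
import torch

gru = torch.nn.GRU(input_size=100, hidden_size=100, num_layers=2, bidirectional=True)
x = torch.randn(10, 256, 100)          # (seq_len, batch_size, input_size)
output, hidden = gru(x)
print(output.shape)                    # torch.Size([10, 256, 200]) -- both directions concatenated
print(hidden.shape)                    # torch.Size([4, 256, 100])  -- num_layers * num_directions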
5. PyTorch Implementation
5.1 Imports
import time
import torch
import csv
import gzip
from torch.nn.utils.rnn import pack_padded_sequence
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
5.2 The make_tensors Function
make_tensors turns a batch of (name, country) pairs into the padded name tensor, the sequence lengths (sorted in descending order, as required by pack_padded_sequence), and the label tensor.
def make_tensors(names, countries):
    # Convert each name into (list of ASCII codes, length)
    sequences_and_lengths = [name2list(name) for name in names]
    name_sequences = [sl[0] for sl in sequences_and_lengths]
    seq_lengths = torch.LongTensor([sl[1] for sl in sequences_and_lengths])
    countries = countries.long()
    # Zero-pad every sequence to the length of the longest name in the batch
    seq_tensor = torch.zeros(len(name_sequences), seq_lengths.max()).long()
    for idx, (seq, seq_len) in enumerate(zip(name_sequences, seq_lengths), 0):
        seq_tensor[idx, :seq_len] = torch.LongTensor(seq)
    # Sort by length in descending order, as required by pack_padded_sequence,
    # and reorder names and labels consistently
    seq_lengths, perm_idx = seq_lengths.sort(dim=0, descending=True)
    seq_tensor = seq_tensor[perm_idx]
    countries = countries[perm_idx]
    return seq_tensor, seq_lengths, countries
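An illustrative check of what the function returns, assuming torch from Section 5.1 and name2list from Section 5.3 are already in scope (the two names and label indices are made up):
names = ('Adsit', 'Ajdrna')                        # illustrative names from a batch
labels = torch.tensor([4, 2])                      # illustrative country indices
inputs, lengths, target = make_tensors(names, labels)
print(inputs.shape)                                # torch.Size([2, 6]) -- zero-padded to the longest name
print(lengths)                                     # tensor([6, 5])     -- sorted in descending order
print(target)                                      # tensor([2, 4])     -- reordered to match the sorting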
5.3 The Name-to-Character-List Function
def name2list(name):
    # Return the name as a list of ASCII codes together with its length
    arr = [ord(c) for c in name]
    return arr, len(arr)
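For example (the name "Bob" is just an illustration):
print(name2list('Bob'))   # ([66, 111, 98], 3)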
5.4 The Name Dataset Class
class NameDataset(Dataset):
    def __init__(self, is_train_set=True):
        filename = '../dataset/names_train.csv.gz' if is_train_set else '../dataset/names_test.csv.gz'
        with gzip.open(filename, 'rt') as f:
            reader = csv.reader(f)
            rows = list(reader)
        # Each row is (name, country)
        self.names = [row[0] for row in rows]
        self.len = len(self.names)
        self.countries = [row[1] for row in rows]
        self.country_list = list(sorted(set(self.countries)))
        self.country_dict = self.getCountryDict()
        self.country_num = len(self.country_list)

    # Build the country-name -> index dictionary
    def getCountryDict(self):
        country_dict = dict()
        for idx, country_name in enumerate(self.country_list, 0):
            country_dict[country_name] = idx
        return country_dict

    # Number of distinct countries
    def getCountriesNum(self):
        return self.country_num

    # Map an index back to the country string
    def idx2country(self, index):
        return self.country_list[index]

    def __getitem__(self, index):
        return self.names[index], self.country_dict[self.countries[index]]

    def __len__(self):
        return self.len
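Illustrative usage, assuming the two .csv.gz files exist at the paths above (the sample shown in the comment is hypothetical):
train_data = NameDataset(is_train_set=True)
print(len(train_data))               # number of (name, country) samples
print(train_data[0])                 # e.g. ('Adsit', 4) -- a name and its country index
print(train_data.getCountriesNum())  # number of distinct countries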
5.5 The RNN (GRU) Classifier Class
class RNNClassifier(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size, n_layers=1, bidirectional=True):
        super(RNNClassifier, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.n_directions = 2 if bidirectional else 1
        # Embedding: ASCII code -> hidden_size-dimensional vector
        self.embedding = torch.nn.Embedding(input_size, hidden_size)
        self.gru = torch.nn.GRU(hidden_size, hidden_size, n_layers, bidirectional=bidirectional)
        # The final hidden states (both directions, if bidirectional) feed the classifier
        self.fc = torch.nn.Linear(hidden_size * self.n_directions, output_size)

    def forward(self, input, seq_len):
        input = input.t()                      # (batch, seq_len) -> (seq_len, batch)
        batch_size = input.size(1)
        hidden = self._init_hidden(batch_size)
        embedding = self.embedding(input.to(device))
        # Pack the padded batch so the GRU skips the zero padding
        gru_input = pack_padded_sequence(embedding, seq_len)
        output, hidden = self.gru(gru_input.to(device), hidden.to(device))
        if self.n_directions == 2:
            # Concatenate the final forward and backward hidden states
            hidden_cat = torch.cat([hidden[-1], hidden[-2]], dim=1)
        else:
            hidden_cat = hidden[-1]
        fc_output = self.fc(hidden_cat)
        return fc_output

    def _init_hidden(self, batch_size):
        # Zero initial hidden state: (n_layers * n_directions, batch, hidden_size)
        return torch.zeros(self.n_layers * self.n_directions, batch_size, self.hidden_size)
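A minimal shape check on the CPU. It assumes make_tensors and name2list from above are defined, and it relies on forward() reading a module-level device variable (set here to the CPU; the main block below sets it according to USE_GPU). The output size of 18 is just an illustrative number of classes:
device = torch.device('cpu')                       # forward() looks this name up as a global
model = RNNClassifier(input_size=128, hidden_size=100, output_size=18, n_layers=2)
inputs, lengths, _ = make_tensors(('Adsit', 'Ajdrna'), torch.tensor([0, 1]))
print(model(inputs, lengths).shape)                # torch.Size([2, 18]) -- one score per country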
5.6 The Training Function
trainModel runs one epoch over train_loader; epoch, classifier, criterion, optimizer, start and train_set are globals set in the main block.
def trainModel():
    total_loss = 0
    print('=' * 20, 'Epoch', epoch, '=' * 20)
    for i, (names, countries) in enumerate(train_loader, 1):
        inputs, seq_len, target = make_tensors(names, countries)
        output = classifier(inputs, seq_len)
        loss = criterion(output.to(device), target.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        if i % 10 == 0:
            print(f'[Elapsed {time.time() - start} s]', end='')
            print(f' - [{i * len(inputs)}/{len(train_set)}]', end='')
            print(f' , loss={total_loss / (i * len(inputs))}')
5.7 The Test Function
testModel evaluates the classifier on the test set (not the training set, despite the original log label) and returns the accuracy.
def testModel():
    correct = 0
    total = len(test_set)
    with torch.no_grad():
        for i, (names, countries) in enumerate(test_loader, 1):
            inputs, seq_len, target = make_tensors(names, countries)
            output = classifier(inputs, seq_len)
            pred = output.max(dim=1, keepdim=True)[1]          # index of the highest score
            correct += pred.eq(target.to(device).view_as(pred)).sum().item()
        percent = '%.2f' % (100 * correct / total)
        print(f'Evaluating on the test set: Accuracy {correct}/{total} {percent}%')
    return correct / total
5.8 The Main Block
if __name__ == '__main__':
    # Hyperparameters
    HIDDEN_SIZE = 100  # hidden state size of the GRU
    BATCH_SIZE = 256   # batch size
    N_LAYER = 2        # number of GRU layers
    N_EPOCHS = 50      # number of training epochs
    N_CHARS = 128      # size of the ASCII vocabulary (embedding input size)
    USE_GPU = True     # whether to use GPU acceleration

    # Prepare the data
    train_set = NameDataset(is_train_set=True)
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    test_set = NameDataset(is_train_set=False)
    test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

    # Number of countries (output classes)
    N_COUNTRY = train_set.getCountriesNum()

    # Build the RNN classifier
    classifier = RNNClassifier(N_CHARS, HIDDEN_SIZE, N_COUNTRY, N_LAYER)
    # Fall back to the CPU when no GPU is requested or available,
    # so that the device global used in forward() is always defined
    device = torch.device('cuda:0' if USE_GPU and torch.cuda.is_available() else 'cpu')
    classifier.to(device)

    # Loss function
    criterion = torch.nn.CrossEntropyLoss()
    # Optimizer
    optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)

    start = time.time()
    acc_list = []
    for epoch in range(1, N_EPOCHS + 1):
        trainModel()
        acc = testModel()
        acc_list.append(acc)

    # Plot test accuracy per epoch
    plt.plot(range(1, len(acc_list) + 1), acc_list)
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.show()
5.9 Full Code
The complete script, combining all of the pieces above:
import time
import torch
import csv
import gzip
from torch.nn.utils.rnn import pack_padded_sequence
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
def name2list(name):
    # Return the name as a list of ASCII codes together with its length
    arr = [ord(c) for c in name]
    return arr, len(arr)

# Build the padded, sorted tensors for a batch
def make_tensors(names, countries):
    # Convert each name into (list of ASCII codes, length)
    sequences_and_lengths = [name2list(name) for name in names]
    name_sequences = [sl[0] for sl in sequences_and_lengths]
    seq_lengths = torch.LongTensor([sl[1] for sl in sequences_and_lengths])
    countries = countries.long()
    # Zero-pad every sequence to the length of the longest name in the batch
    seq_tensor = torch.zeros(len(name_sequences), seq_lengths.max()).long()
    for idx, (seq, seq_len) in enumerate(zip(name_sequences, seq_lengths), 0):
        seq_tensor[idx, :seq_len] = torch.LongTensor(seq)
    # Sort by length in descending order, as required by pack_padded_sequence,
    # and reorder names and labels consistently
    seq_lengths, perm_idx = seq_lengths.sort(dim=0, descending=True)
    seq_tensor = seq_tensor[perm_idx]
    countries = countries[perm_idx]
    return seq_tensor, seq_lengths, countries
# Name/country dataset
class NameDataset(Dataset):
    def __init__(self, is_train_set=True):
        filename = '../dataset/names_train.csv.gz' if is_train_set else '../dataset/names_test.csv.gz'
        with gzip.open(filename, 'rt') as f:
            reader = csv.reader(f)
            rows = list(reader)
        # Each row is (name, country)
        self.names = [row[0] for row in rows]
        self.len = len(self.names)
        self.countries = [row[1] for row in rows]
        self.country_list = list(sorted(set(self.countries)))
        self.country_dict = self.getCountryDict()
        self.country_num = len(self.country_list)

    # Build the country-name -> index dictionary
    def getCountryDict(self):
        country_dict = dict()
        for idx, country_name in enumerate(self.country_list, 0):
            country_dict[country_name] = idx
        return country_dict

    # Number of distinct countries
    def getCountriesNum(self):
        return self.country_num

    # Map an index back to the country string
    def idx2country(self, index):
        return self.country_list[index]

    def __getitem__(self, index):
        return self.names[index], self.country_dict[self.countries[index]]

    def __len__(self):
        return self.len
# RNN (GRU) classifier
class RNNClassifier(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size, n_layers=1, bidirectional=True):
        super(RNNClassifier, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.n_directions = 2 if bidirectional else 1
        # Embedding: ASCII code -> hidden_size-dimensional vector
        self.embedding = torch.nn.Embedding(input_size, hidden_size)
        self.gru = torch.nn.GRU(hidden_size, hidden_size, n_layers, bidirectional=bidirectional)
        # The final hidden states (both directions, if bidirectional) feed the classifier
        self.fc = torch.nn.Linear(hidden_size * self.n_directions, output_size)

    def forward(self, input, seq_len):
        input = input.t()                      # (batch, seq_len) -> (seq_len, batch)
        batch_size = input.size(1)
        hidden = self._init_hidden(batch_size)
        embedding = self.embedding(input.to(device))
        # Pack the padded batch so the GRU skips the zero padding
        gru_input = pack_padded_sequence(embedding, seq_len)
        output, hidden = self.gru(gru_input.to(device), hidden.to(device))
        if self.n_directions == 2:
            # Concatenate the final forward and backward hidden states
            hidden_cat = torch.cat([hidden[-1], hidden[-2]], dim=1)
        else:
            hidden_cat = hidden[-1]
        fc_output = self.fc(hidden_cat)
        return fc_output

    def _init_hidden(self, batch_size):
        # Zero initial hidden state: (n_layers * n_directions, batch, hidden_size)
        return torch.zeros(self.n_layers * self.n_directions, batch_size, self.hidden_size)
# Training function: run one epoch over train_loader
def trainModel():
    total_loss = 0
    print('=' * 20, 'Epoch', epoch, '=' * 20)
    for i, (names, countries) in enumerate(train_loader, 1):
        inputs, seq_len, target = make_tensors(names, countries)
        output = classifier(inputs, seq_len)
        loss = criterion(output.to(device), target.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        if i % 10 == 0:
            print(f'[Elapsed {time.time() - start} s]', end='')
            print(f' - [{i * len(inputs)}/{len(train_set)}]', end='')
            print(f' , loss={total_loss / (i * len(inputs))}')

# Test function: evaluate on the test set and return the accuracy
def testModel():
    correct = 0
    total = len(test_set)
    with torch.no_grad():
        for i, (names, countries) in enumerate(test_loader, 1):
            inputs, seq_len, target = make_tensors(names, countries)
            output = classifier(inputs, seq_len)
            pred = output.max(dim=1, keepdim=True)[1]          # index of the highest score
            correct += pred.eq(target.to(device).view_as(pred)).sum().item()
        percent = '%.2f' % (100 * correct / total)
        print(f'Evaluating on the test set: Accuracy {correct}/{total} {percent}%')
    return correct / total
if __name__ == '__main__':
    # Hyperparameters
    HIDDEN_SIZE = 100  # hidden state size of the GRU
    BATCH_SIZE = 256   # batch size
    N_LAYER = 2        # number of GRU layers
    N_EPOCHS = 50      # number of training epochs
    N_CHARS = 128      # size of the ASCII vocabulary (embedding input size)
    USE_GPU = True     # whether to use GPU acceleration

    # Prepare the data
    train_set = NameDataset(is_train_set=True)
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    test_set = NameDataset(is_train_set=False)
    test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

    # Number of countries (output classes)
    N_COUNTRY = train_set.getCountriesNum()

    # Build the RNN classifier
    classifier = RNNClassifier(N_CHARS, HIDDEN_SIZE, N_COUNTRY, N_LAYER)
    # Fall back to the CPU when no GPU is requested or available,
    # so that the device global used in forward() is always defined
    device = torch.device('cuda:0' if USE_GPU and torch.cuda.is_available() else 'cpu')
    classifier.to(device)

    # Loss function
    criterion = torch.nn.CrossEntropyLoss()
    # Optimizer
    optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)

    start = time.time()
    acc_list = []
    for epoch in range(1, N_EPOCHS + 1):
        trainModel()
        acc = testModel()
        acc_list.append(acc)

    # Plot test accuracy per epoch
    plt.plot(range(1, len(acc_list) + 1), acc_list)
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.show()
5.10 Run Output
Plot of the model's accuracy on the test set over the training epochs:
Console output:
==================== Epoch 1 ====================
[程序已运行0.44780421257019043 秒] - [2560/13374] , loss=0.00890564052388072
[程序已运行0.6133613586425781 秒] - [5120/13374] , loss=0.007599683245643973
[程序已运行0.7878952026367188 秒] - [7680/13374] , loss=0.006917052948847413
[程序已运行0.9584388732910156 秒] - [10240/13374] , loss=0.006446240993682295
[程序已运行1.1319754123687744 秒] - [12800/13374] , loss=0.0060893167182803154
在训练集上评估模型: Accuracy 4472/6700 66.75%
==================== Epoch 2 ====================
[程序已运行1.5678095817565918 秒] - [2560/13374] , loss=0.004120685928501189
[程序已运行1.7363591194152832 秒] - [5120/13374] , loss=0.004010635381564498
[程序已运行1.9069037437438965 秒] - [7680/13374] , loss=0.00399712462288638
[程序已运行2.0754520893096924 秒] - [10240/13374] , loss=0.003917965857544914
[程序已运行2.2400126457214355 秒] - [12800/13374] , loss=0.0038193828100338578
在训练集上评估模型: Accuracy 4984/6700 74.39%
==================== Epoch 3 ====================
[程序已运行2.7569212913513184 秒] - [2560/13374] , loss=0.003344046091660857
[程序已运行2.928469181060791 秒] - [5120/13374] , loss=0.003259772143792361
[程序已运行3.1019983291625977 秒] - [7680/13374] , loss=0.003187967735963563
[程序已运行3.274538040161133 秒] - [10240/13374] , loss=0.0031107491464354097
[程序已运行3.442089319229126 秒] - [12800/13374] , loss=0.0030547570576891303
在训练集上评估模型: Accuracy 5251/6700 78.37%
==================== Epoch 4 ====================
[程序已运行3.8719401359558105 秒] - [2560/13374] , loss=0.0026993038831278683
[程序已运行4.046473503112793 秒] - [5120/13374] , loss=0.002674128650687635
[程序已运行4.216020822525024 秒] - [7680/13374] , loss=0.0026587502487624686
[程序已运行4.38556694984436 秒] - [10240/13374] , loss=0.0025898678810335695
[程序已运行4.550126552581787 秒] - [12800/13374] , loss=0.0025764494528993966
在训练集上评估模型: Accuracy 5364/6700 80.06%
==================== Epoch 5 ====================
[程序已运行4.99593448638916 秒] - [2560/13374] , loss=0.002204193570651114
[程序已运行5.161492109298706 秒] - [5120/13374] , loss=0.002289206086425111
[程序已运行5.331039667129517 秒] - [7680/13374] , loss=0.0022666183416731656
[程序已运行5.49859094619751 秒] - [10240/13374] , loss=0.002261629086569883
[程序已运行5.671129941940308 秒] - [12800/13374] , loss=0.0022295594890601933
在训练集上评估模型: Accuracy 5463/6700 81.54%
==================== Epoch 6 ====================
[程序已运行6.233642578125 秒] - [2560/13374] , loss=0.0019896522513590752
[程序已运行6.456540107727051 秒] - [5120/13374] , loss=0.001986617426155135
[程序已运行6.664837121963501 秒] - [7680/13374] , loss=0.0019825143235114714
[程序已运行6.833385467529297 秒] - [10240/13374] , loss=0.0020075739128515126
[程序已运行7.032860040664673 秒] - [12800/13374] , loss=0.0020189793314784764
在训练集上评估模型: Accuracy 5509/6700 82.22%
==================== Epoch 7 ====================
[程序已运行7.519379138946533 秒] - [2560/13374] , loss=0.0018858694704249502
[程序已运行7.713961362838745 秒] - [5120/13374] , loss=0.0019061228667851537
[程序已运行7.924145698547363 秒] - [7680/13374] , loss=0.0018647492048330604
[程序已运行8.105660200119019 秒] - [10240/13374] , loss=0.0018422425084281713
[程序已运行8.272214412689209 秒] - [12800/13374] , loss=0.001829312415793538
在训练集上评估模型: Accuracy 5554/6700 82.90%
==================== Epoch 8 ====================
[程序已运行8.776322841644287 秒] - [2560/13374] , loss=0.0017062094528228044
[程序已运行8.956839084625244 秒] - [5120/13374] , loss=0.0017181114875711502
[程序已运行9.128380060195923 秒] - [7680/13374] , loss=0.0016691674555962285
[程序已运行9.299922227859497 秒] - [10240/13374] , loss=0.001663037418620661
[程序已运行9.480438709259033 秒] - [12800/13374] , loss=0.0016579296253621577
在训练集上评估模型: Accuracy 5622/6700 83.91%
==================== Epoch 9 ====================
[程序已运行10.042935371398926 秒] - [2560/13374] , loss=0.00138091582339257
[程序已运行10.261351108551025 秒] - [5120/13374] , loss=0.0014375038270372897
[程序已运行10.564762353897095 秒] - [7680/13374] , loss=0.0014615616644732654
[程序已运行10.782928466796875 秒] - [10240/13374] , loss=0.0014871433144435287
[程序已运行10.983396530151367 秒] - [12800/13374] , loss=0.001485794959589839
在训练集上评估模型: Accuracy 5639/6700 84.16%
==================== Epoch 10 ====================
[程序已运行11.527957201004028 秒] - [2560/13374] , loss=0.0013533846475183963
[程序已运行11.727681636810303 秒] - [5120/13374] , loss=0.0013366769824642688
[程序已运行11.922213315963745 秒] - [7680/13374] , loss=0.0013372401415836066
[程序已运行12.146097660064697 秒] - [10240/13374] , loss=0.0013334597359062172
[程序已运行12.360528945922852 秒] - [12800/13374] , loss=0.0013430006837006658
在训练集上评估模型: Accuracy 5671/6700 84.64%
==================== Epoch 11 ====================
[程序已运行12.899601221084595 秒] - [2560/13374] , loss=0.0012242367956787348
[程序已运行13.09657597541809 秒] - [5120/13374] , loss=0.0012018651730613783
[程序已运行13.297226905822754 秒] - [7680/13374] , loss=0.0011939232140624275
[程序已运行13.494091510772705 秒] - [10240/13374] , loss=0.0012114218538044953
[程序已运行13.68766736984253 秒] - [12800/13374] , loss=0.0012003930215723812
在训练集上评估模型: Accuracy 5694/6700 84.99%
==================== Epoch 12 ====================
[程序已运行14.210429906845093 秒] - [2560/13374] , loss=0.0009987134893890471
[程序已运行14.412419557571411 秒] - [5120/13374] , loss=0.001071632924140431
[程序已运行14.66275668144226 秒] - [7680/13374] , loss=0.001076526817632839
[程序已运行14.856532573699951 秒] - [10240/13374] , loss=0.0010921183042228223
[程序已运行15.086088180541992 秒] - [12800/13374] , loss=0.0010934956604614853
在训练集上评估模型: Accuracy 5662/6700 84.51%
==================== Epoch 13 ====================
[程序已运行15.603407859802246 秒] - [2560/13374] , loss=0.0009401162387803197
[程序已运行15.79611325263977 秒] - [5120/13374] , loss=0.000941314865485765
[程序已运行15.991790294647217 秒] - [7680/13374] , loss=0.0009610804088879376
[程序已运行16.179932117462158 秒] - [10240/13374] , loss=0.0009680362229119055
[程序已运行16.40265130996704 秒] - [12800/13374] , loss=0.000982134909136221
在训练集上评估模型: Accuracy 5667/6700 84.58%
==================== Epoch 14 ====================
[程序已运行16.909828424453735 秒] - [2560/13374] , loss=0.0007867097272537648
[程序已运行17.106796979904175 秒] - [5120/13374] , loss=0.000815554044675082
[程序已运行17.292539358139038 秒] - [7680/13374] , loss=0.0008367548192230364
[程序已运行17.472803592681885 秒] - [10240/13374] , loss=0.0008467563966405578
[程序已运行17.639357805252075 秒] - [12800/13374] , loss=0.0008704598969779909
在训练集上评估模型: Accuracy 5668/6700 84.60%
==================== Epoch 15 ====================
[程序已运行18.066723585128784 秒] - [2560/13374] , loss=0.0007090812985552474
[程序已运行18.24724054336548 秒] - [5120/13374] , loss=0.0007125684132915922
[程序已运行18.403822898864746 秒] - [7680/13374] , loss=0.0007363049614165599
[程序已运行18.56738543510437 秒] - [10240/13374] , loss=0.0007670420709473547
[程序已运行18.725961923599243 秒] - [12800/13374] , loss=0.0007662988040829078
在训练集上评估模型: Accuracy 5691/6700 84.94%
==================== Epoch 16 ====================
[程序已运行19.231609582901 秒] - [2560/13374] , loss=0.0006500366900581867
[程序已运行19.412126302719116 秒] - [5120/13374] , loss=0.0006256612250581384
[程序已运行19.586659908294678 秒] - [7680/13374] , loss=0.0006333356761994461
[程序已运行19.75919795036316 秒] - [10240/13374] , loss=0.0006514595239423216
[程序已运行19.923759937286377 秒] - [12800/13374] , loss=0.0006744604662526399
在训练集上评估模型: Accuracy 5667/6700 84.58%
==================== Epoch 17 ====================
[程序已运行20.44287872314453 秒] - [2560/13374] , loss=0.0005562761740293354
[程序已运行20.62125062942505 秒] - [5120/13374] , loss=0.0005988578370306641
[程序已运行20.796781301498413 秒] - [7680/13374] , loss=0.0005926105329611649
[程序已运行20.967325687408447 秒] - [10240/13374] , loss=0.000587682421610225
[程序已运行21.13288402557373 秒] - [12800/13374] , loss=0.0005947937141172587
在训练集上评估模型: Accuracy 5690/6700 84.93%
==================== Epoch 18 ====================
[程序已运行21.58068537712097 秒] - [2560/13374] , loss=0.0004808117635548115
[程序已运行21.750232458114624 秒] - [5120/13374] , loss=0.0004957637938787229
[程序已运行21.91977834701538 秒] - [7680/13374] , loss=0.0005180548294447362
[程序已运行22.09730362892151 秒] - [10240/13374] , loss=0.0005155952792847529
[程序已运行22.261864185333252 秒] - [12800/13374] , loss=0.000532688939711079
在训练集上评估模型: Accuracy 5659/6700 84.46%
==================== Epoch 19 ====================
[程序已运行22.743576526641846 秒] - [2560/13374] , loss=0.00045145115873310713
[程序已运行22.916114330291748 秒] - [5120/13374] , loss=0.0004681079401052557
[程序已运行23.086658239364624 秒] - [7680/13374] , loss=0.0004730896607118969
[程序已运行23.260194301605225 秒] - [10240/13374] , loss=0.00047693301239633
[程序已运行23.420764923095703 秒] - [12800/13374] , loss=0.00048351394420024005
在训练集上评估模型: Accuracy 5686/6700 84.87%
==================== Epoch 20 ====================
[程序已运行23.879538536071777 秒] - [2560/13374] , loss=0.000395867633051239
[程序已运行24.058061122894287 秒] - [5120/13374] , loss=0.00040292235498782245
[程序已运行24.22611951828003 秒] - [7680/13374] , loss=0.00041214636488196753
[程序已运行24.40364646911621 秒] - [10240/13374] , loss=0.0004200873590889387
[程序已运行24.599121570587158 秒] - [12800/13374] , loss=0.00042729771870654077
在训练集上评估模型: Accuracy 5683/6700 84.82%
==================== Epoch 21 ====================
[程序已运行25.093798637390137 秒] - [2560/13374] , loss=0.0003232844435842708
[程序已运行25.266338109970093 秒] - [5120/13374] , loss=0.0003470184303296264
[程序已运行25.42990016937256 秒] - [7680/13374] , loss=0.00036114812683081255
[程序已运行25.605431079864502 秒] - [10240/13374] , loss=0.00036963203929190057
[程序已运行25.77896738052368 秒] - [12800/13374] , loss=0.00038419733085902406
在训练集上评估模型: Accuracy 5656/6700 84.42%
==================== Epoch 22 ====================
[程序已运行26.23973536491394 秒] - [2560/13374] , loss=0.00035485800908645614
[程序已运行26.42025327682495 秒] - [5120/13374] , loss=0.000331850739166839
[程序已运行26.59478497505188 秒] - [7680/13374] , loss=0.00033587128127692266
[程序已运行26.76333522796631 秒] - [10240/13374] , loss=0.0003520488495269092
[程序已运行26.939862489700317 秒] - [12800/13374] , loss=0.00036097451025852935
在训练集上评估模型: Accuracy 5633/6700 84.07%
==================== Epoch 23 ====================
[程序已运行27.467453002929688 秒] - [2560/13374] , loss=0.0003026763326488435
[程序已运行27.70481777191162 秒] - [5120/13374] , loss=0.0003143761219689623
[程序已运行27.911266088485718 秒] - [7680/13374] , loss=0.0003264661878347397
[程序已运行28.186530113220215 秒] - [10240/13374] , loss=0.0003306483005871996
[程序已运行28.46631908416748 秒] - [12800/13374] , loss=0.0003335028060246259
在训练集上评估模型: Accuracy 5673/6700 84.67%
==================== Epoch 24 ====================
[程序已运行29.179603815078735 秒] - [2560/13374] , loss=0.00027135475975228476
[程序已运行29.430758237838745 秒] - [5120/13374] , loss=0.00028713325373246333
[程序已运行29.679874897003174 秒] - [7680/13374] , loss=0.0003023453294493568
[程序已运行29.934645414352417 秒] - [10240/13374] , loss=0.000312889136330341
[程序已运行30.228501319885254 秒] - [12800/13374] , loss=0.0003050918216467835
在训练集上评估模型: Accuracy 5648/6700 84.30%
==================== Epoch 25 ====================
[程序已运行30.873358488082886 秒] - [2560/13374] , loss=0.00023736636503599585
[程序已运行31.12848401069641 秒] - [5120/13374] , loss=0.000262370355630992
[程序已运行31.37047266960144 秒] - [7680/13374] , loss=0.00028469724299308533
[程序已运行31.587894439697266 秒] - [10240/13374] , loss=0.00029130332404747605
[程序已运行31.81080985069275 秒] - [12800/13374] , loss=0.000298141545499675
在训练集上评估模型: Accuracy 5650/6700 84.33%
==================== Epoch 26 ====================
[程序已运行32.386191606521606 秒] - [2560/13374] , loss=0.0002229748701211065
[程序已运行32.59639310836792 秒] - [5120/13374] , loss=0.00024768941948423163
[程序已运行32.80773115158081 秒] - [7680/13374] , loss=0.0002617782947102872
[程序已运行33.03980994224548 秒] - [10240/13374] , loss=0.0002768368027318502
[程序已运行33.29946231842041 秒] - [12800/13374] , loss=0.0002899381099268794
在训练集上评估模型: Accuracy 5641/6700 84.19%
==================== Epoch 27 ====================
[程序已运行33.88727021217346 秒] - [2560/13374] , loss=0.00024374632048420608
[程序已运行34.09471583366394 秒] - [5120/13374] , loss=0.00025897946034092456
[程序已运行34.289196252822876 秒] - [7680/13374] , loss=0.0002538432978326455
[程序已运行34.47469925880432 秒] - [10240/13374] , loss=0.000259564047519234
[程序已运行34.67316937446594 秒] - [12800/13374] , loss=0.00027740672609070314
在训练集上评估模型: Accuracy 5646/6700 84.27%
==================== Epoch 28 ====================
[程序已运行35.17283344268799 秒] - [2560/13374] , loss=0.00024596738221589477
[程序已运行35.37529134750366 秒] - [5120/13374] , loss=0.00025625340931583195
[程序已运行35.612366676330566 秒] - [7680/13374] , loss=0.00025634167104726657
[程序已运行35.83150577545166 秒] - [10240/13374] , loss=0.00026004306309914684
[程序已运行36.03250527381897 秒] - [12800/13374] , loss=0.00026288528344593944
在训练集上评估模型: Accuracy 5638/6700 84.15%
==================== Epoch 29 ====================
[程序已运行36.57138133049011 秒] - [2560/13374] , loss=0.00021622761414619163
[程序已运行36.77258634567261 秒] - [5120/13374] , loss=0.00022565958206541837
[程序已运行36.98577690124512 秒] - [7680/13374] , loss=0.0002272359988031288
[程序已运行37.19248104095459 秒] - [10240/13374] , loss=0.000239311121913488
[程序已运行37.42102932929993 秒] - [12800/13374] , loss=0.00026318065298255535
在训练集上评估模型: Accuracy 5660/6700 84.48%
==================== Epoch 30 ====================
[程序已运行37.98303580284119 秒] - [2560/13374] , loss=0.0002419873373582959
[程序已运行38.19469690322876 秒] - [5120/13374] , loss=0.00024032749206526205
[程序已运行38.433921813964844 秒] - [7680/13374] , loss=0.00023429312568623573
[程序已运行38.63934087753296 秒] - [10240/13374] , loss=0.000236004658654565
[程序已运行38.855393171310425 秒] - [12800/13374] , loss=0.00024584917380707336
在训练集上评估模型: Accuracy 5652/6700 84.36%
==================== Epoch 31 ====================
[程序已运行39.43484330177307 秒] - [2560/13374] , loss=0.00018982139372383244
[程序已运行39.639156103134155 秒] - [5120/13374] , loss=0.0001874874855275266
[程序已运行39.86286163330078 秒] - [7680/13374] , loss=0.0002010059698174397
[程序已运行40.10820508003235 秒] - [10240/13374] , loss=0.00022044739162083716
[程序已运行40.321282148361206 秒] - [12800/13374] , loss=0.00023711820482276381
在训练集上评估模型: Accuracy 5650/6700 84.33%
==================== Epoch 32 ====================
[程序已运行40.8946692943573 秒] - [2560/13374] , loss=0.00015288349750335327
[程序已运行41.10710620880127 秒] - [5120/13374] , loss=0.0001839709879277507
[程序已运行41.33549523353577 秒] - [7680/13374] , loss=0.00019845828088970545
[程序已运行41.57286047935486 秒] - [10240/13374] , loss=0.00021456888480315684
[程序已运行41.8036105632782 秒] - [12800/13374] , loss=0.0002269831285229884
在训练集上评估模型: Accuracy 5666/6700 84.57%
==================== Epoch 33 ====================
[程序已运行42.351365089416504 秒] - [2560/13374] , loss=0.00020009858708363025
[程序已运行42.55391049385071 秒] - [5120/13374] , loss=0.00021211478597251698
[程序已运行42.762874364852905 秒] - [7680/13374] , loss=0.00021397633778785045
[程序已运行42.9832968711853 秒] - [10240/13374] , loss=0.0002237053031421965
[程序已运行43.203397274017334 秒] - [12800/13374] , loss=0.00022588698309846221
在训练集上评估模型: Accuracy 5663/6700 84.52%
==================== Epoch 34 ====================
[程序已运行43.75180983543396 秒] - [2560/13374] , loss=0.000175720240076771
[程序已运行43.96774888038635 秒] - [5120/13374] , loss=0.0001822654259740375
[程序已运行44.17185306549072 秒] - [7680/13374] , loss=0.00019087508820424167
[程序已运行44.37160062789917 秒] - [10240/13374] , loss=0.0002025523937845719
[程序已运行44.58753275871277 秒] - [12800/13374] , loss=0.0002117827812617179
在训练集上评估模型: Accuracy 5661/6700 84.49%
==================== Epoch 35 ====================
[程序已运行45.14763021469116 秒] - [2560/13374] , loss=0.00018244399325340055
[程序已运行45.38248586654663 秒] - [5120/13374] , loss=0.0001909391281515127
[程序已运行45.58654308319092 秒] - [7680/13374] , loss=0.00019028932996055422
[程序已运行45.78802943229675 秒] - [10240/13374] , loss=0.00019807684529951075
[程序已运行45.99347996711731 秒] - [12800/13374] , loss=0.00020920663388096728
在训练集上评估模型: Accuracy 5680/6700 84.78%
==================== Epoch 36 ====================
[程序已运行46.53318977355957 秒] - [2560/13374] , loss=0.0002301990520209074
[程序已运行46.713706970214844 秒] - [5120/13374] , loss=0.00019428682790021413
[程序已运行46.90220332145691 秒] - [7680/13374] , loss=0.00020380564043686415
[程序已运行47.091697454452515 秒] - [10240/13374] , loss=0.00020663877294282428
[程序已运行47.291162967681885 秒] - [12800/13374] , loss=0.0002089952616370283
在训练集上评估模型: Accuracy 5676/6700 84.72%
==================== Epoch 37 ====================
[程序已运行47.806785345077515 秒] - [2560/13374] , loss=0.0001326974183029961
[程序已运行47.99926948547363 秒] - [5120/13374] , loss=0.00016608303412795066
[程序已运行48.17978763580322 秒] - [7680/13374] , loss=0.00018860103252033394
[程序已运行48.39720582962036 秒] - [10240/13374] , loss=0.00019829184930131305
[程序已运行48.613649129867554 秒] - [12800/13374] , loss=0.00020632889354601502
在训练集上评估模型: Accuracy 5659/6700 84.46%
==================== Epoch 38 ====================
[程序已运行49.16527962684631 秒] - [2560/13374] , loss=0.00017558708641445263
[程序已运行49.377071380615234 秒] - [5120/13374] , loss=0.0001803442672098754
[程序已运行49.61643576622009 秒] - [7680/13374] , loss=0.00019150104708387516
[程序已运行49.82301330566406 秒] - [10240/13374] , loss=0.00020083691142644967
[程序已运行50.0354540348053 秒] - [12800/13374] , loss=0.00020664259296609088
在训练集上评估模型: Accuracy 5665/6700 84.55%
==================== Epoch 39 ====================
[程序已运行50.57412886619568 秒] - [2560/13374] , loss=0.0001647219163714908
[程序已运行50.781656980514526 秒] - [5120/13374] , loss=0.00017860122316051274
[程序已运行50.985915660858154 秒] - [7680/13374] , loss=0.00018149100748511652
[程序已运行51.19026255607605 秒] - [10240/13374] , loss=0.00018654396444617305
[程序已运行51.396902322769165 秒] - [12800/13374] , loss=0.00019622410647571087
在训练集上评估模型: Accuracy 5674/6700 84.69%
==================== Epoch 40 ====================
[程序已运行51.950305223464966 秒] - [2560/13374] , loss=0.0001813760813092813
[程序已运行52.168620586395264 秒] - [5120/13374] , loss=0.0002006525617616717
[程序已运行52.388041496276855 秒] - [7680/13374] , loss=0.00020146279275650157
[程序已运行52.609692335128784 秒] - [10240/13374] , loss=0.00020502966435742564
[程序已运行52.82223105430603 秒] - [12800/13374] , loss=0.00020634495449485258
在训练集上评估模型: Accuracy 5660/6700 84.48%
==================== Epoch 41 ====================
[程序已运行53.38243508338928 秒] - [2560/13374] , loss=0.00015071030429680832
[程序已运行53.587127923965454 秒] - [5120/13374] , loss=0.00017189887112181168
[程序已运行53.806541204452515 秒] - [7680/13374] , loss=0.00017305417253131358
[程序已运行54.02480912208557 秒] - [10240/13374] , loss=0.00019079587727901525
[程序已运行54.2585232257843 秒] - [12800/13374] , loss=0.00020163424400379882
在训练集上评估模型: Accuracy 5648/6700 84.30%
==================== Epoch 42 ====================
[程序已运行54.82281470298767 秒] - [2560/13374] , loss=0.00013891297894588205
[程序已运行55.025827407836914 秒] - [5120/13374] , loss=0.00015851605949137592
[程序已运行55.231240034103394 秒] - [7680/13374] , loss=0.00016577779958121635
[程序已运行55.46313452720642 秒] - [10240/13374] , loss=0.00018027596570391324
[程序已运行55.67412066459656 秒] - [12800/13374] , loss=0.0001966867937153438
在训练集上评估模型: Accuracy 5665/6700 84.55%
==================== Epoch 43 ====================
[程序已运行56.20661902427673 秒] - [2560/13374] , loss=0.00015996379297575914
[程序已运行56.41858148574829 秒] - [5120/13374] , loss=0.00016856673537404276
[程序已运行56.639917612075806 秒] - [7680/13374] , loss=0.00018089073128066958
[程序已运行56.8443865776062 秒] - [10240/13374] , loss=0.0001840785967942793
[程序已运行57.06223940849304 秒] - [12800/13374] , loss=0.00019384760729735718
在训练集上评估模型: Accuracy 5645/6700 84.25%
==================== Epoch 44 ====================
[程序已运行57.572874546051025 秒] - [2560/13374] , loss=0.00015061586382216773
[程序已运行57.78630328178406 秒] - [5120/13374] , loss=0.0001388985794619657
[程序已运行57.9957435131073 秒] - [7680/13374] , loss=0.00016478789621032774
[程序已运行58.19321537017822 秒] - [10240/13374] , loss=0.00017915765802172244
[程序已运行58.42410707473755 秒] - [12800/13374] , loss=0.00018903483942267485
在训练集上评估模型: Accuracy 5671/6700 84.64%
==================== Epoch 45 ====================
[程序已运行58.94371771812439 秒] - [2560/13374] , loss=0.00015125685386010445
[程序已运行59.1222403049469 秒] - [5120/13374] , loss=0.0001521789192338474
[程序已运行59.306747913360596 秒] - [7680/13374] , loss=0.00016950705078973746
[程序已运行59.48726511001587 秒] - [10240/13374] , loss=0.00017524105423945003
[程序已运行59.67775559425354 秒] - [12800/13374] , loss=0.0001805883324414026
在训练集上评估模型: Accuracy 5667/6700 84.58%
==================== Epoch 46 ====================
[程序已运行60.20154929161072 秒] - [2560/13374] , loss=0.00014828720522928052
[程序已运行60.40001821517944 秒] - [5120/13374] , loss=0.00015882042025623378
[程序已运行60.58452534675598 秒] - [7680/13374] , loss=0.0001613501580626083
[程序已运行60.77501559257507 秒] - [10240/13374] , loss=0.00016973716010397765
[程序已运行60.948551416397095 秒] - [12800/13374] , loss=0.0001806799619225785
在训练集上评估模型: Accuracy 5659/6700 84.46%
==================== Epoch 47 ====================
[程序已运行61.3913676738739 秒] - [2560/13374] , loss=0.00015815104998182506
[程序已运行61.57487750053406 秒] - [5120/13374] , loss=0.00016703655710443855
[程序已运行61.75738835334778 秒] - [7680/13374] , loss=0.00017824470463286465
[程序已运行61.92992830276489 秒] - [10240/13374] , loss=0.00017852899200079264
[程序已运行62.09648156166077 秒] - [12800/13374] , loss=0.00018316529472940603
在训练集上评估模型: Accuracy 5659/6700 84.46%
==================== Epoch 48 ====================
[程序已运行62.557249307632446 秒] - [2560/13374] , loss=0.00015278960127034223
[程序已运行62.73379874229431 秒] - [5120/13374] , loss=0.00015682993143855127
[程序已运行62.920299768447876 秒] - [7680/13374] , loss=0.00016720300166828868
[程序已运行63.097824573516846 秒] - [10240/13374] , loss=0.00017679908323771089
[程序已运行63.27834177017212 秒] - [12800/13374] , loss=0.00018408266638289205
在训练集上评估模型: Accuracy 5632/6700 84.06%
==================== Epoch 49 ====================
[程序已运行63.809921741485596 秒] - [2560/13374] , loss=0.00013651168810611126
[程序已运行63.99043846130371 秒] - [5120/13374] , loss=0.00015391943106806137
[程序已运行64.16397380828857 秒] - [7680/13374] , loss=0.00016998272246079675
[程序已运行64.34748315811157 秒] - [10240/13374] , loss=0.0001738338686664065
[程序已运行64.52999496459961 秒] - [12800/13374] , loss=0.0001867358752497239
在训练集上评估模型: Accuracy 5660/6700 84.48%
==================== Epoch 50 ====================
[程序已运行64.99973917007446 秒] - [2560/13374] , loss=0.00014780706333112904
[程序已运行65.21316909790039 秒] - [5120/13374] , loss=0.00015933765134832357
[程序已运行65.40266251564026 秒] - [7680/13374] , loss=0.0001664065877169681
[程序已运行65.57220840454102 秒] - [10240/13374] , loss=0.00017008764552883804
[程序已运行65.75272583961487 秒] - [12800/13374] , loss=0.00018112511475919745
在训练集上评估模型: Accuracy 5646/6700 84.27%