[Deep Learning] PyTorch Deep Learning Practice - Lecture_13_RNN_Classifier


Contents

  • 1. Problem Description
  • 2. Our Model
  • 3. Preparing the Data
    • 3.1 Data Convert
    • 3.2 Padding Data
    • 3.3 Label Convert
  • 4. Bidirectional RNN
  • 5. PyTorch Implementation
    • 5.1 Imports
    • 5.2 Tensor Creation Function
    • 5.3 Name-to-Character-List Function
    • 5.4 Name Dataset Class
    • 5.5 RNN (GRU) Classifier Class
    • 5.6 Training Function
    • 5.7 Test Function
    • 5.8 Main Block
    • 5.9 Full Code
    • 5.10 Run Output


1. Problem Description

Problem: given a person's name, predict which country it belongs to.
(figure omitted)

2. Our Model

(figures omitted)

3. Preparing the Data

3.1 Data Convert

Since the input consists entirely of English characters, each character can be converted to its ASCII code, turning the string data into numeric values.
(figure omitted)
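As a quick sketch of this step (the name "Mclean" here is just a hypothetical example), the conversion is simply Python's built-in `ord()`:

```python
# Convert a name to a list of ASCII codes (same idea as name2list below)
def to_ascii(name):
    return [ord(c) for c in name]

print(to_ascii("Mclean"))  # -> [77, 99, 108, 101, 97, 110]
```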

3.2 Padding Data

Because the names vary in length, we pad the shorter ones with zeros so that all sequences in a batch have the same length.
(figure omitted)
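A minimal stdlib-only sketch of the padding step (in the actual `make_tensors` below, the same effect is achieved by copying each sequence into a zero tensor):

```python
def pad_sequences(seqs, pad_value=0):
    # Right-pad every sequence with pad_value to the length of the longest one
    max_len = max(len(s) for s in seqs)
    return [s + [pad_value] * (max_len - len(s)) for s in seqs]

batch = [[77, 99], [77, 99, 108, 101], [65]]
print(pad_sequences(batch))  # -> [[77, 99, 0, 0], [77, 99, 108, 101], [65, 0, 0, 0]]
```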

3.3 Label Convert

Each country name is converted to an integer index so it can be used as the classification target.
(figure omitted)
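The label conversion can be sketched the same way `getCountryDict` does it below: collect the unique country names, sort them, and assign each one an index (the country names here are just a hypothetical subset):

```python
countries = ["Scottish", "Japanese", "Arabic", "Japanese"]  # hypothetical labels
country_list = sorted(set(countries))
country_dict = {name: idx for idx, name in enumerate(country_list)}
print(country_dict)  # -> {'Arabic': 0, 'Japanese': 1, 'Scottish': 2}
```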

4. Bidirectional RNN

A bidirectional RNN processes the sequence in both directions and concatenates the final hidden states of the forward and backward passes, so the resulting feature vector is twice the hidden size.
(figure omitted)
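To make the two-direction idea concrete without a framework, here is a toy recurrence (deliberately order-sensitive) run once left-to-right and once right-to-left; concatenating the two final states doubles the feature size, which is why the `Linear` layer below takes `hidden_size * n_directions` inputs:

```python
def run_toy_rnn(seq, reverse=False):
    # Toy recurrence h_t = 2*h_{t-1} + x_t; returns the final hidden state
    h = 0
    for x in (reversed(seq) if reverse else seq):
        h = 2 * h + x
    return h

seq = [1, 2, 3]
h_forward = run_toy_rnn(seq)          # left-to-right pass
h_backward = run_toy_rnn(seq, True)   # right-to-left pass
hidden_cat = [h_forward, h_backward]  # concatenated final states
print(hidden_cat)  # -> [11, 17]
```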

5. PyTorch Implementation

5.1 Imports

import time
import torch
import csv
import gzip
from torch.nn.utils.rnn import pack_padded_sequence
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt

5.2 Tensor Creation Function

def make_tensors(names, countries):
    sequences_and_lengths = [name2list(name) for name in names]
    name_sequences = [sl[0] for sl in sequences_and_lengths]
    seq_lengths = torch.LongTensor([sl[1] for sl in sequences_and_lengths])
    countries = countries.long()

    seq_tensor = torch.zeros(len(name_sequences), seq_lengths.max()).long()
    for idx, (seq, seq_len) in enumerate(zip(name_sequences, seq_lengths), 0):
        seq_tensor[idx, :seq_len] = torch.LongTensor(seq)

    # Sort by sequence length (descending), as required by pack_padded_sequence
    seq_lengths, perm_idx = seq_lengths.sort(dim=0, descending=True)
    seq_tensor = seq_tensor[perm_idx]
    countries = countries[perm_idx]
    return seq_tensor, seq_lengths, countries
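make_tensors sorts the batch by length in descending order because `pack_padded_sequence` expects sequences ordered longest-first. The permutation it computes can be sketched with plain Python:

```python
lengths = [2, 5, 3]  # hypothetical name lengths in a batch
# Indices that sort the batch by length, longest first
# (analogous to seq_lengths.sort(descending=True))
perm = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
print(perm)                        # -> [1, 2, 0]
print([lengths[i] for i in perm])  # -> [5, 3, 2]
```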

5.3 Name-to-Character-List Function

def name2list(name):
    # Convert each character to its ASCII code; return the list and its length
    arr = [ord(c) for c in name]
    return arr, len(arr)

5.4 Name Dataset Class

class NameDataset(Dataset):
    def __init__(self, is_train_set=True):
        filename = '../dataset/names_train.csv.gz' if is_train_set else '../dataset/names_test.csv.gz'
        with gzip.open(filename, 'rt') as f:
            reader = csv.reader(f)
            rows = list(reader)
            self.names = [row[0] for row in rows]
            self.len = len(self.names)
            self.countries = [row[1] for row in rows]
            self.country_list = list(sorted(set(self.countries)))
            self.country_dict = self.getCountryDict()
            self.country_num = len(self.country_list)

    # Build the country-name -> index dictionary
    def getCountryDict(self):
        country_dict = dict()
        for idx, country_name in enumerate(self.country_list, 0):
            country_dict[country_name] = idx
        return country_dict

    # Number of countries
    def getCountriesNum(self):
        return self.country_num

    # Return the country name for a given index
    def idx2country(self, index):
        return self.country_list[index]

    def __getitem__(self, index):
        return self.names[index], self.country_dict[self.countries[index]]

    def __len__(self):
        return self.len

5.5 RNN (GRU) Classifier Class

class RNNClassifier(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size, n_layers=1, bidirectional=True):
        super(RNNClassifier, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.n_directions = 2 if bidirectional else 1

        self.embedding = torch.nn.Embedding(input_size, hidden_size)
        self.gru = torch.nn.GRU(hidden_size, hidden_size, n_layers, bidirectional=bidirectional)

        self.fc = torch.nn.Linear(hidden_size * self.n_directions, output_size)

    def forward(self, input, seq_len):
        input = input.t()  # transpose: (batch, seq_len) -> (seq_len, batch)
        batch_size = input.size(1)

        hidden = self._init_hidden(batch_size)
        embedding = self.embedding(input.to(device))

        gru_input = pack_padded_sequence(embedding, seq_len)  # pack so the GRU skips padding

        output, hidden = self.gru(gru_input.to(device), hidden.to(device))
        if self.n_directions == 2:
            hidden_cat = torch.cat([hidden[-1], hidden[-2]], dim=1)
        else:
            hidden_cat = hidden[-1]

        fc_output = self.fc(hidden_cat)
        return fc_output

    def _init_hidden(self, batch_size):
        # One initial hidden state per (layer x direction)
        return torch.zeros(self.n_layers * self.n_directions, batch_size, self.hidden_size)

5.6 Training Function

def trainModel():
    total_loss = 0
    print('=' * 20, 'Epoch', epoch, '=' * 20)
    for i, (names, countries) in enumerate(train_loader, 1):
        inputs, seq_len, target = make_tensors(names, countries)
        output = classifier(inputs, seq_len)
        loss = criterion(output.to(device), target.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        if i % 10 == 0:
            print(f'[elapsed {time.time() - start}s]', end='')
            print(f' - [{i * len(inputs)}/{len(train_set)}]', end='')
            print(f' , loss={total_loss / (i * len(inputs))}')

5.7 Test Function

def tttModel():
    correct = 0
    total = len(test_set)
    with torch.no_grad():
        for i, (names, countries) in enumerate(test_loader, 1):
            inputs, seq_len, target = make_tensors(names, countries)
            output = classifier(inputs, seq_len)
            pred = output.max(dim=1, keepdim=True)[1]
            correct += pred.eq(target.to(device).view_as(pred)).sum().item()
        percent = '%.2f' % (100 * correct / total)
        print(f'Test set evaluation: Accuracy {correct}/{total} {percent}%')
    return correct / total
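The accuracy computation above (argmax over the class scores, then compare with the target) can be sketched in plain Python; the score rows here are hypothetical:

```python
def accuracy(outputs, targets):
    # Fraction of rows whose argmax matches the target index
    correct = 0
    for scores, target in zip(outputs, targets):
        pred = scores.index(max(scores))  # argmax, like output.max(dim=1)
        correct += (pred == target)
    return correct / len(targets)

print(accuracy([[0.1, 2.0, 0.3], [1.5, 0.2, 0.1]], [1, 0]))  # -> 1.0
```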

5.8 Main Block

if __name__ == '__main__':
    # Hyperparameters
    HIDDEN_SIZE = 100  # hidden state size
    BATCH_SIZE = 256  # batch size
    N_LAYER = 2  # number of GRU layers
    N_EPOCHS = 50  # number of training epochs
    N_CHARS = 128  # vocabulary size (ASCII)
    USE_GPU = True  # whether to use GPU acceleration

    # Prepare data
    train_set = NameDataset(is_train_set=True)
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    test_set = NameDataset(is_train_set=False)
    test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

    # Number of countries
    N_COUNTRY = train_set.getCountriesNum()

    # Instantiate the RNN model
    classifier = RNNClassifier(N_CHARS, HIDDEN_SIZE, N_COUNTRY, N_LAYER)
    device = torch.device("cuda:0" if USE_GPU and torch.cuda.is_available() else "cpu")
    classifier.to(device)
    # Loss function
    criterion = torch.nn.CrossEntropyLoss()
    # Optimizer
    optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)

    start = time.time()
    acc_list = []
    for epoch in range(1, N_EPOCHS + 1):
        trainModel()
        acc = tttModel()
        acc_list.append(acc)
    # Plot the accuracy curve
    plt.plot(range(1, len(acc_list) + 1), acc_list)
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.show()

5.9 Full Code

import time
import torch
import csv
import gzip
from torch.nn.utils.rnn import pack_padded_sequence
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt


def name2list(name):
    # Convert each character to its ASCII code; return the list and its length
    arr = [ord(c) for c in name]
    return arr, len(arr)


# Build padded, sorted tensors from a batch of names
def make_tensors(names, countries):
    sequences_and_lengths = [name2list(name) for name in names]
    name_sequences = [sl[0] for sl in sequences_and_lengths]
    seq_lengths = torch.LongTensor([sl[1] for sl in sequences_and_lengths])
    countries = countries.long()

    seq_tensor = torch.zeros(len(name_sequences), seq_lengths.max()).long()
    for idx, (seq, seq_len) in enumerate(zip(name_sequences, seq_lengths), 0):
        seq_tensor[idx, :seq_len] = torch.LongTensor(seq)

    # Sort by sequence length (descending), as required by pack_padded_sequence
    seq_lengths, perm_idx = seq_lengths.sort(dim=0, descending=True)
    seq_tensor = seq_tensor[perm_idx]
    countries = countries[perm_idx]
    return seq_tensor, seq_lengths, countries


# Name/country dataset
class NameDataset(Dataset):
    def __init__(self, is_train_set=True):
        filename = '../dataset/names_train.csv.gz' if is_train_set else '../dataset/names_test.csv.gz'
        with gzip.open(filename, 'rt') as f:
            reader = csv.reader(f)
            rows = list(reader)
            self.names = [row[0] for row in rows]
            self.len = len(self.names)
            self.countries = [row[1] for row in rows]
            self.country_list = list(sorted(set(self.countries)))
            self.country_dict = self.getCountryDict()
            self.country_num = len(self.country_list)

    # Build the country-name -> index dictionary
    def getCountryDict(self):
        country_dict = dict()
        for idx, country_name in enumerate(self.country_list, 0):
            country_dict[country_name] = idx
        return country_dict

    # Number of countries
    def getCountriesNum(self):
        return self.country_num

    # Return the country name for a given index
    def idx2country(self, index):
        return self.country_list[index]

    def __getitem__(self, index):
        return self.names[index], self.country_dict[self.countries[index]]

    def __len__(self):
        return self.len


# RNN (GRU) classifier
class RNNClassifier(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size, n_layers=1, bidirectional=True):
        super(RNNClassifier, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.n_directions = 2 if bidirectional else 1

        self.embedding = torch.nn.Embedding(input_size, hidden_size)
        self.gru = torch.nn.GRU(hidden_size, hidden_size, n_layers, bidirectional=bidirectional)

        self.fc = torch.nn.Linear(hidden_size * self.n_directions, output_size)

    def forward(self, input, seq_len):
        input = input.t()  # transpose: (batch, seq_len) -> (seq_len, batch)
        batch_size = input.size(1)

        hidden = self._init_hidden(batch_size)
        embedding = self.embedding(input.to(device))

        gru_input = pack_padded_sequence(embedding, seq_len)  # pack so the GRU skips padding

        output, hidden = self.gru(gru_input.to(device), hidden.to(device))
        if self.n_directions == 2:
            hidden_cat = torch.cat([hidden[-1], hidden[-2]], dim=1)
        else:
            hidden_cat = hidden[-1]

        fc_output = self.fc(hidden_cat)
        return fc_output

    def _init_hidden(self, batch_size):
        # One initial hidden state per (layer x direction)
        return torch.zeros(self.n_layers * self.n_directions, batch_size, self.hidden_size)


# Training function
def trainModel():
    total_loss = 0
    print('=' * 20, 'Epoch', epoch, '=' * 20)
    for i, (names, countries) in enumerate(train_loader, 1):
        inputs, seq_len, target = make_tensors(names, countries)
        output = classifier(inputs, seq_len)
        loss = criterion(output.to(device), target.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        if i % 10 == 0:
            print(f'[elapsed {time.time() - start}s]', end='')
            print(f' - [{i * len(inputs)}/{len(train_set)}]', end='')
            print(f' , loss={total_loss / (i * len(inputs))}')


# Test function
def tttModel():
    correct = 0
    total = len(test_set)
    with torch.no_grad():
        for i, (names, countries) in enumerate(test_loader, 1):
            inputs, seq_len, target = make_tensors(names, countries)
            output = classifier(inputs, seq_len)
            pred = output.max(dim=1, keepdim=True)[1]
            correct += pred.eq(target.to(device).view_as(pred)).sum().item()
        percent = '%.2f' % (100 * correct / total)
        print(f'Test set evaluation: Accuracy {correct}/{total} {percent}%')
    return correct / total


if __name__ == '__main__':
    # Hyperparameters
    HIDDEN_SIZE = 100  # hidden state size
    BATCH_SIZE = 256  # batch size
    N_LAYER = 2  # number of GRU layers
    N_EPOCHS = 50  # number of training epochs
    N_CHARS = 128  # vocabulary size (ASCII)
    USE_GPU = True  # whether to use GPU acceleration

    # Prepare data
    train_set = NameDataset(is_train_set=True)
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    test_set = NameDataset(is_train_set=False)
    test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

    # Number of countries
    N_COUNTRY = train_set.getCountriesNum()

    # Instantiate the RNN model
    classifier = RNNClassifier(N_CHARS, HIDDEN_SIZE, N_COUNTRY, N_LAYER)
    device = torch.device("cuda:0" if USE_GPU and torch.cuda.is_available() else "cpu")
    classifier.to(device)
    # Loss function
    criterion = torch.nn.CrossEntropyLoss()
    # Optimizer
    optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)

    start = time.time()
    acc_list = []
    for epoch in range(1, N_EPOCHS + 1):
        trainModel()
        acc = tttModel()
        acc_list.append(acc)
    # Plot the accuracy curve
    plt.plot(range(1, len(acc_list) + 1), acc_list)
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.show()

5.10 Run Output

Test-set accuracy over training epochs:
(figure omitted)
Console output:

==================== Epoch 1 ====================
[elapsed 0.44780421257019043s] - [2560/13374] , loss=0.00890564052388072
[elapsed 0.6133613586425781s] - [5120/13374] , loss=0.007599683245643973
[elapsed 0.7878952026367188s] - [7680/13374] , loss=0.006917052948847413
[elapsed 0.9584388732910156s] - [10240/13374] , loss=0.006446240993682295
[elapsed 1.1319754123687744s] - [12800/13374] , loss=0.0060893167182803154
Test set evaluation: Accuracy 4472/6700 66.75%
==================== Epoch 2 ====================
[elapsed 1.5678095817565918s] - [2560/13374] , loss=0.004120685928501189
[elapsed 1.7363591194152832s] - [5120/13374] , loss=0.004010635381564498
[elapsed 1.9069037437438965s] - [7680/13374] , loss=0.00399712462288638
[elapsed 2.0754520893096924s] - [10240/13374] , loss=0.003917965857544914
[elapsed 2.2400126457214355s] - [12800/13374] , loss=0.0038193828100338578
Test set evaluation: Accuracy 4984/6700 74.39%
==================== Epoch 3 ====================
[elapsed 2.7569212913513184s] - [2560/13374] , loss=0.003344046091660857
[elapsed 2.928469181060791s] - [5120/13374] , loss=0.003259772143792361
[elapsed 3.1019983291625977s] - [7680/13374] , loss=0.003187967735963563
[elapsed 3.274538040161133s] - [10240/13374] , loss=0.0031107491464354097
[elapsed 3.442089319229126s] - [12800/13374] , loss=0.0030547570576891303
Test set evaluation: Accuracy 5251/6700 78.37%
==================== Epoch 4 ====================
[elapsed 3.8719401359558105s] - [2560/13374] , loss=0.0026993038831278683
[elapsed 4.046473503112793s] - [5120/13374] , loss=0.002674128650687635
[elapsed 4.216020822525024s] - [7680/13374] , loss=0.0026587502487624686
[elapsed 4.38556694984436s] - [10240/13374] , loss=0.0025898678810335695
[elapsed 4.550126552581787s] - [12800/13374] , loss=0.0025764494528993966
Test set evaluation: Accuracy 5364/6700 80.06%
==================== Epoch 5 ====================
[elapsed 4.99593448638916s] - [2560/13374] , loss=0.002204193570651114
[elapsed 5.161492109298706s] - [5120/13374] , loss=0.002289206086425111
[elapsed 5.331039667129517s] - [7680/13374] , loss=0.0022666183416731656
[elapsed 5.49859094619751s] - [10240/13374] , loss=0.002261629086569883
[elapsed 5.671129941940308s] - [12800/13374] , loss=0.0022295594890601933
Test set evaluation: Accuracy 5463/6700 81.54%
==================== Epoch 6 ====================
[elapsed 6.233642578125s] - [2560/13374] , loss=0.0019896522513590752
[elapsed 6.456540107727051s] - [5120/13374] , loss=0.001986617426155135
[elapsed 6.664837121963501s] - [7680/13374] , loss=0.0019825143235114714
[elapsed 6.833385467529297s] - [10240/13374] , loss=0.0020075739128515126
[elapsed 7.032860040664673s] - [12800/13374] , loss=0.0020189793314784764
Test set evaluation: Accuracy 5509/6700 82.22%
==================== Epoch 7 ====================
[elapsed 7.519379138946533s] - [2560/13374] , loss=0.0018858694704249502
[elapsed 7.713961362838745s] - [5120/13374] , loss=0.0019061228667851537
[elapsed 7.924145698547363s] - [7680/13374] , loss=0.0018647492048330604
[elapsed 8.105660200119019s] - [10240/13374] , loss=0.0018422425084281713
[elapsed 8.272214412689209s] - [12800/13374] , loss=0.001829312415793538
Test set evaluation: Accuracy 5554/6700 82.90%
==================== Epoch 8 ====================
[elapsed 8.776322841644287s] - [2560/13374] , loss=0.0017062094528228044
[elapsed 8.956839084625244s] - [5120/13374] , loss=0.0017181114875711502
[elapsed 9.128380060195923s] - [7680/13374] , loss=0.0016691674555962285
[elapsed 9.299922227859497s] - [10240/13374] , loss=0.001663037418620661
[elapsed 9.480438709259033s] - [12800/13374] , loss=0.0016579296253621577
Test set evaluation: Accuracy 5622/6700 83.91%
==================== Epoch 9 ====================
[elapsed 10.042935371398926s] - [2560/13374] , loss=0.00138091582339257
[elapsed 10.261351108551025s] - [5120/13374] , loss=0.0014375038270372897
[elapsed 10.564762353897095s] - [7680/13374] , loss=0.0014615616644732654
[elapsed 10.782928466796875s] - [10240/13374] , loss=0.0014871433144435287
[elapsed 10.983396530151367s] - [12800/13374] , loss=0.001485794959589839
Test set evaluation: Accuracy 5639/6700 84.16%
==================== Epoch 10 ====================
[elapsed 11.527957201004028s] - [2560/13374] , loss=0.0013533846475183963
[elapsed 11.727681636810303s] - [5120/13374] , loss=0.0013366769824642688
[elapsed 11.922213315963745s] - [7680/13374] , loss=0.0013372401415836066
[elapsed 12.146097660064697s] - [10240/13374] , loss=0.0013334597359062172
[elapsed 12.360528945922852s] - [12800/13374] , loss=0.0013430006837006658
Test set evaluation: Accuracy 5671/6700 84.64%
==================== Epoch 11 ====================
[elapsed 12.899601221084595s] - [2560/13374] , loss=0.0012242367956787348
[elapsed 13.09657597541809s] - [5120/13374] , loss=0.0012018651730613783
[elapsed 13.297226905822754s] - [7680/13374] , loss=0.0011939232140624275
[elapsed 13.494091510772705s] - [10240/13374] , loss=0.0012114218538044953
[elapsed 13.68766736984253s] - [12800/13374] , loss=0.0012003930215723812
Test set evaluation: Accuracy 5694/6700 84.99%
==================== Epoch 12 ====================
[elapsed 14.210429906845093s] - [2560/13374] , loss=0.0009987134893890471
[elapsed 14.412419557571411s] - [5120/13374] , loss=0.001071632924140431
[elapsed 14.66275668144226s] - [7680/13374] , loss=0.001076526817632839
[elapsed 14.856532573699951s] - [10240/13374] , loss=0.0010921183042228223
[elapsed 15.086088180541992s] - [12800/13374] , loss=0.0010934956604614853
Test set evaluation: Accuracy 5662/6700 84.51%
==================== Epoch 13 ====================
[elapsed 15.603407859802246s] - [2560/13374] , loss=0.0009401162387803197
[elapsed 15.79611325263977s] - [5120/13374] , loss=0.000941314865485765
[elapsed 15.991790294647217s] - [7680/13374] , loss=0.0009610804088879376
[elapsed 16.179932117462158s] - [10240/13374] , loss=0.0009680362229119055
[elapsed 16.40265130996704s] - [12800/13374] , loss=0.000982134909136221
Test set evaluation: Accuracy 5667/6700 84.58%
==================== Epoch 14 ====================
[elapsed 16.909828424453735s] - [2560/13374] , loss=0.0007867097272537648
[elapsed 17.106796979904175s] - [5120/13374] , loss=0.000815554044675082
[elapsed 17.292539358139038s] - [7680/13374] , loss=0.0008367548192230364
[elapsed 17.472803592681885s] - [10240/13374] , loss=0.0008467563966405578
[elapsed 17.639357805252075s] - [12800/13374] , loss=0.0008704598969779909
Test set evaluation: Accuracy 5668/6700 84.60%
==================== Epoch 15 ====================
[elapsed 18.066723585128784s] - [2560/13374] , loss=0.0007090812985552474
[elapsed 18.24724054336548s] - [5120/13374] , loss=0.0007125684132915922
[elapsed 18.403822898864746s] - [7680/13374] , loss=0.0007363049614165599
[elapsed 18.56738543510437s] - [10240/13374] , loss=0.0007670420709473547
[elapsed 18.725961923599243s] - [12800/13374] , loss=0.0007662988040829078
Test set evaluation: Accuracy 5691/6700 84.94%
==================== Epoch 16 ====================
[elapsed 19.231609582901s] - [2560/13374] , loss=0.0006500366900581867
[elapsed 19.412126302719116s] - [5120/13374] , loss=0.0006256612250581384
[elapsed 19.586659908294678s] - [7680/13374] , loss=0.0006333356761994461
[elapsed 19.75919795036316s] - [10240/13374] , loss=0.0006514595239423216
[elapsed 19.923759937286377s] - [12800/13374] , loss=0.0006744604662526399
Test set evaluation: Accuracy 5667/6700 84.58%
==================== Epoch 17 ====================
[elapsed 20.44287872314453s] - [2560/13374] , loss=0.0005562761740293354
[elapsed 20.62125062942505s] - [5120/13374] , loss=0.0005988578370306641
[elapsed 20.796781301498413s] - [7680/13374] , loss=0.0005926105329611649
[elapsed 20.967325687408447s] - [10240/13374] , loss=0.000587682421610225
[elapsed 21.13288402557373s] - [12800/13374] , loss=0.0005947937141172587
Test set evaluation: Accuracy 5690/6700 84.93%
==================== Epoch 18 ====================
[elapsed 21.58068537712097s] - [2560/13374] , loss=0.0004808117635548115
[elapsed 21.750232458114624s] - [5120/13374] , loss=0.0004957637938787229
[elapsed 21.91977834701538s] - [7680/13374] , loss=0.0005180548294447362
[elapsed 22.09730362892151s] - [10240/13374] , loss=0.0005155952792847529
[elapsed 22.261864185333252s] - [12800/13374] , loss=0.000532688939711079
Test set evaluation: Accuracy 5659/6700 84.46%
==================== Epoch 19 ====================
[elapsed 22.743576526641846s] - [2560/13374] , loss=0.00045145115873310713
[elapsed 22.916114330291748s] - [5120/13374] , loss=0.0004681079401052557
[elapsed 23.086658239364624s] - [7680/13374] , loss=0.0004730896607118969
[elapsed 23.260194301605225s] - [10240/13374] , loss=0.00047693301239633
[elapsed 23.420764923095703s] - [12800/13374] , loss=0.00048351394420024005
Test set evaluation: Accuracy 5686/6700 84.87%
==================== Epoch 20 ====================
[elapsed 23.879538536071777s] - [2560/13374] , loss=0.000395867633051239
[elapsed 24.058061122894287s] - [5120/13374] , loss=0.00040292235498782245
[elapsed 24.22611951828003s] - [7680/13374] , loss=0.00041214636488196753
[elapsed 24.40364646911621s] - [10240/13374] , loss=0.0004200873590889387
[elapsed 24.599121570587158s] - [12800/13374] , loss=0.00042729771870654077
Test set evaluation: Accuracy 5683/6700 84.82%
==================== Epoch 21 ====================
[elapsed 25.093798637390137s] - [2560/13374] , loss=0.0003232844435842708
[elapsed 25.266338109970093s] - [5120/13374] , loss=0.0003470184303296264
[elapsed 25.42990016937256s] - [7680/13374] , loss=0.00036114812683081255
[elapsed 25.605431079864502s] - [10240/13374] , loss=0.00036963203929190057
[elapsed 25.77896738052368s] - [12800/13374] , loss=0.00038419733085902406
Test set evaluation: Accuracy 5656/6700 84.42%
==================== Epoch 22 ====================
[elapsed 26.23973536491394s] - [2560/13374] , loss=0.00035485800908645614
[elapsed 26.42025327682495s] - [5120/13374] , loss=0.000331850739166839
[elapsed 26.59478497505188s] - [7680/13374] , loss=0.00033587128127692266
[elapsed 26.76333522796631s] - [10240/13374] , loss=0.0003520488495269092
[elapsed 26.939862489700317s] - [12800/13374] , loss=0.00036097451025852935
Test set evaluation: Accuracy 5633/6700 84.07%
==================== Epoch 23 ====================
[elapsed 27.467453002929688s] - [2560/13374] , loss=0.0003026763326488435
[elapsed 27.70481777191162s] - [5120/13374] , loss=0.0003143761219689623
[elapsed 27.911266088485718s] - [7680/13374] , loss=0.0003264661878347397
[elapsed 28.186530113220215s] - [10240/13374] , loss=0.0003306483005871996
[elapsed 28.46631908416748s] - [12800/13374] , loss=0.0003335028060246259
Test set evaluation: Accuracy 5673/6700 84.67%
==================== Epoch 24 ====================
[elapsed 29.179603815078735s] - [2560/13374] , loss=0.00027135475975228476
[elapsed 29.430758237838745s] - [5120/13374] , loss=0.00028713325373246333
[elapsed 29.679874897003174s] - [7680/13374] , loss=0.0003023453294493568
[elapsed 29.934645414352417s] - [10240/13374] , loss=0.000312889136330341
[elapsed 30.228501319885254s] - [12800/13374] , loss=0.0003050918216467835
Test set evaluation: Accuracy 5648/6700 84.30%
==================== Epoch 25 ====================
[elapsed 30.873358488082886s] - [2560/13374] , loss=0.00023736636503599585
[elapsed 31.12848401069641s] - [5120/13374] , loss=0.000262370355630992
[elapsed 31.37047266960144s] - [7680/13374] , loss=0.00028469724299308533
[elapsed 31.587894439697266s] - [10240/13374] , loss=0.00029130332404747605
[elapsed 31.81080985069275s] - [12800/13374] , loss=0.000298141545499675
Test set evaluation: Accuracy 5650/6700 84.33%
==================== Epoch 26 ====================
[elapsed 32.386191606521606s] - [2560/13374] , loss=0.0002229748701211065
[elapsed 32.59639310836792s] - [5120/13374] , loss=0.00024768941948423163
[elapsed 32.80773115158081s] - [7680/13374] , loss=0.0002617782947102872
[elapsed 33.03980994224548s] - [10240/13374] , loss=0.0002768368027318502
[elapsed 33.29946231842041s] - [12800/13374] , loss=0.0002899381099268794
Test set evaluation: Accuracy 5641/6700 84.19%
==================== Epoch 27 ====================
[elapsed 33.88727021217346s] - [2560/13374] , loss=0.00024374632048420608
[elapsed 34.09471583366394s] - [5120/13374] , loss=0.00025897946034092456
[elapsed 34.289196252822876s] - [7680/13374] , loss=0.0002538432978326455
[elapsed 34.47469925880432s] - [10240/13374] , loss=0.000259564047519234
[elapsed 34.67316937446594s] - [12800/13374] , loss=0.00027740672609070314
Test set evaluation: Accuracy 5646/6700 84.27%
==================== Epoch 28 ====================
[elapsed 35.17283344268799s] - [2560/13374] , loss=0.00024596738221589477
[elapsed 35.37529134750366s] - [5120/13374] , loss=0.00025625340931583195
[elapsed 35.612366676330566s] - [7680/13374] , loss=0.00025634167104726657
[elapsed 35.83150577545166s] - [10240/13374] , loss=0.00026004306309914684
[elapsed 36.03250527381897s] - [12800/13374] , loss=0.00026288528344593944
Test set evaluation: Accuracy 5638/6700 84.15%
==================== Epoch 29 ====================
[elapsed 36.57138133049011s] - [2560/13374] , loss=0.00021622761414619163
[elapsed 36.77258634567261s] - [5120/13374] , loss=0.00022565958206541837
[elapsed 36.98577690124512s] - [7680/13374] , loss=0.0002272359988031288
[elapsed 37.19248104095459s] - [10240/13374] , loss=0.000239311121913488
[elapsed 37.42102932929993s] - [12800/13374] , loss=0.00026318065298255535
Test set evaluation: Accuracy 5660/6700 84.48%
==================== Epoch 30 ====================
[elapsed 37.98303580284119s] - [2560/13374] , loss=0.0002419873373582959
[elapsed 38.19469690322876s] - [5120/13374] , loss=0.00024032749206526205
[elapsed 38.433921813964844s] - [7680/13374] , loss=0.00023429312568623573
[elapsed 38.63934087753296s] - [10240/13374] , loss=0.000236004658654565
[elapsed 38.855393171310425s] - [12800/13374] , loss=0.00024584917380707336
Test set evaluation: Accuracy 5652/6700 84.36%
==================== Epoch 31 ====================
[elapsed 39.43484330177307s] - [2560/13374] , loss=0.00018982139372383244
[elapsed 39.639156103134155s] - [5120/13374] , loss=0.0001874874855275266
[elapsed 39.86286163330078s] - [7680/13374] , loss=0.0002010059698174397
[elapsed 40.10820508003235s] - [10240/13374] , loss=0.00022044739162083716
[elapsed 40.321282148361206s] - [12800/13374] , loss=0.00023711820482276381
Test set evaluation: Accuracy 5650/6700 84.33%
==================== Epoch 32 ====================
[elapsed 40.8946692943573s] - [2560/13374] , loss=0.00015288349750335327
[elapsed 41.10710620880127s] - [5120/13374] , loss=0.0001839709879277507
[elapsed 41.33549523353577s] - [7680/13374] , loss=0.00019845828088970545
[elapsed 41.57286047935486s] - [10240/13374] , loss=0.00021456888480315684
[elapsed 41.8036105632782s] - [12800/13374] , loss=0.0002269831285229884
Test set evaluation: Accuracy 5666/6700 84.57%
==================== Epoch 33 ====================
[elapsed 42.351365089416504s] - [2560/13374] , loss=0.00020009858708363025
[elapsed 42.55391049385071s] - [5120/13374] , loss=0.00021211478597251698
[elapsed 42.762874364852905s] - [7680/13374] , loss=0.00021397633778785045
[elapsed 42.9832968711853s] - [10240/13374] , loss=0.0002237053031421965
[elapsed 43.203397274017334s] - [12800/13374] , loss=0.00022588698309846221
Test set evaluation: Accuracy 5663/6700 84.52%
==================== Epoch 34 ====================
[elapsed 43.75180983543396s] - [2560/13374] , loss=0.000175720240076771
[elapsed 43.96774888038635s] - [5120/13374] , loss=0.0001822654259740375
[elapsed 44.17185306549072s] - [7680/13374] , loss=0.00019087508820424167
[elapsed 44.37160062789917s] - [10240/13374] , loss=0.0002025523937845719
[elapsed 44.58753275871277s] - [12800/13374] , loss=0.0002117827812617179
Test set evaluation: Accuracy 5661/6700 84.49%
==================== Epoch 35 ====================
[elapsed 45.14763021469116s] - [2560/13374] , loss=0.00018244399325340055
[elapsed 45.38248586654663s] - [5120/13374] , loss=0.0001909391281515127
[elapsed 45.58654308319092s] - [7680/13374] , loss=0.00019028932996055422
[elapsed 45.78802943229675s] - [10240/13374] , loss=0.00019807684529951075
[elapsed 45.99347996711731s] - [12800/13374] , loss=0.00020920663388096728
Test set evaluation: Accuracy 5680/6700 84.78%
==================== Epoch 36 ====================
[elapsed 46.53318977355957s] - [2560/13374] , loss=0.0002301990520209074
[elapsed 46.713706970214844s] - [5120/13374] , loss=0.00019428682790021413
[elapsed 46.90220332145691s] - [7680/13374] , loss=0.00020380564043686415
[elapsed 47.091697454452515s] - [10240/13374] , loss=0.00020663877294282428
[elapsed 47.291162967681885s] - [12800/13374] , loss=0.0002089952616370283
Test set evaluation: Accuracy 5676/6700 84.72%
==================== Epoch 37 ====================
[elapsed 47.806785345077515s] - [2560/13374] , loss=0.0001326974183029961
[elapsed 47.99926948547363s] - [5120/13374] , loss=0.00016608303412795066
[elapsed 48.17978763580322s] - [7680/13374] , loss=0.00018860103252033394
[elapsed 48.39720582962036s] - [10240/13374] , loss=0.00019829184930131305
[elapsed 48.613649129867554s] - [12800/13374] , loss=0.00020632889354601502
Test set evaluation: Accuracy 5659/6700 84.46%
==================== Epoch 38 ====================
[elapsed 49.16527962684631s] - [2560/13374] , loss=0.00017558708641445263
[elapsed 49.377071380615234s] - [5120/13374] , loss=0.0001803442672098754
[elapsed 49.61643576622009s] - [7680/13374] , loss=0.00019150104708387516
[elapsed 49.82301330566406s] - [10240/13374] , loss=0.00020083691142644967
[elapsed 50.0354540348053s] - [12800/13374] , loss=0.00020664259296609088
Test set evaluation: Accuracy 5665/6700 84.55%
==================== Epoch 39 ====================
[elapsed 50.57412886619568s] - [2560/13374] , loss=0.0001647219163714908
[elapsed 50.781656980514526s] - [5120/13374] , loss=0.00017860122316051274
[elapsed 50.985915660858154s] - [7680/13374] , loss=0.00018149100748511652
[elapsed 51.19026255607605s] - [10240/13374] , loss=0.00018654396444617305
[elapsed 51.396902322769165s] - [12800/13374] , loss=0.00019622410647571087
Test set evaluation: Accuracy 5674/6700 84.69%
==================== Epoch 40 ====================
[elapsed 51.950305223464966s] - [2560/13374] , loss=0.0001813760813092813
[elapsed 52.168620586395264s] - [5120/13374] , loss=0.0002006525617616717
[elapsed 52.388041496276855s] - [7680/13374] , loss=0.00020146279275650157
[elapsed 52.609692335128784s] - [10240/13374] , loss=0.00020502966435742564
[elapsed 52.82223105430603s] - [12800/13374] , loss=0.00020634495449485258
Test set evaluation: Accuracy 5660/6700 84.48%
==================== Epoch 41 ====================
[elapsed 53.38243508338928s] - [2560/13374] , loss=0.00015071030429680832
[elapsed 53.587127923965454s] - [5120/13374] , loss=0.00017189887112181168
[elapsed 53.806541204452515s] - [7680/13374] , loss=0.00017305417253131358
[elapsed 54.02480912208557s] - [10240/13374] , loss=0.00019079587727901525
[elapsed 54.2585232257843s] - [12800/13374] , loss=0.00020163424400379882
Test set evaluation: Accuracy 5648/6700 84.30%
==================== Epoch 42 ====================
[elapsed 54.82281470298767s] - [2560/13374] , loss=0.00013891297894588205
[elapsed 55.025827407836914s] - [5120/13374] , loss=0.00015851605949137592
[elapsed 55.231240034103394s] - [7680/13374] , loss=0.00016577779958121635
[elapsed 55.46313452720642s] - [10240/13374] , loss=0.00018027596570391324
[elapsed 55.67412066459656s] - [12800/13374] , loss=0.0001966867937153438
Test set evaluation: Accuracy 5665/6700 84.55%
==================== Epoch 43 ====================
[elapsed 56.20661902427673s] - [2560/13374] , loss=0.00015996379297575914
[elapsed 56.41858148574829s] - [5120/13374] , loss=0.00016856673537404276
[elapsed 56.639917612075806s] - [7680/13374] , loss=0.00018089073128066958
[elapsed 56.8443865776062s] - [10240/13374] , loss=0.0001840785967942793
[elapsed 57.06223940849304s] - [12800/13374] , loss=0.00019384760729735718
Test set evaluation: Accuracy 5645/6700 84.25%
==================== Epoch 44 ====================
[elapsed 57.572874546051025s] - [2560/13374] , loss=0.00015061586382216773
[elapsed 57.78630328178406s] - [5120/13374] , loss=0.0001388985794619657
[elapsed 57.9957435131073s] - [7680/13374] , loss=0.00016478789621032774
[elapsed 58.19321537017822s] - [10240/13374] , loss=0.00017915765802172244
[elapsed 58.42410707473755s] - [12800/13374] , loss=0.00018903483942267485
Test set evaluation: Accuracy 5671/6700 84.64%
==================== Epoch 45 ====================
[elapsed 58.94371771812439s] - [2560/13374] , loss=0.00015125685386010445
[elapsed 59.1222403049469s] - [5120/13374] , loss=0.0001521789192338474
[elapsed 59.306747913360596s] - [7680/13374] , loss=0.00016950705078973746
[elapsed 59.48726511001587s] - [10240/13374] , loss=0.00017524105423945003
[elapsed 59.67775559425354s] - [12800/13374] , loss=0.0001805883324414026
Test set evaluation: Accuracy 5667/6700 84.58%
==================== Epoch 46 ====================
[elapsed 60.20154929161072s] - [2560/13374] , loss=0.00014828720522928052
[elapsed 60.40001821517944s] - [5120/13374] , loss=0.00015882042025623378
[elapsed 60.58452534675598s] - [7680/13374] , loss=0.0001613501580626083
[elapsed 60.77501559257507s] - [10240/13374] , loss=0.00016973716010397765
[elapsed 60.948551416397095s] - [12800/13374] , loss=0.0001806799619225785
Test set evaluation: Accuracy 5659/6700 84.46%
==================== Epoch 47 ====================
[elapsed 61.3913676738739s] - [2560/13374] , loss=0.00015815104998182506
[elapsed 61.57487750053406s] - [5120/13374] , loss=0.00016703655710443855
[elapsed 61.75738835334778s] - [7680/13374] , loss=0.00017824470463286465
[elapsed 61.92992830276489s] - [10240/13374] , loss=0.00017852899200079264
[elapsed 62.09648156166077s] - [12800/13374] , loss=0.00018316529472940603
Test set evaluation: Accuracy 5659/6700 84.46%
==================== Epoch 48 ====================
[elapsed 62.557249307632446s] - [2560/13374] , loss=0.00015278960127034223
[elapsed 62.73379874229431s] - [5120/13374] , loss=0.00015682993143855127
[elapsed 62.920299768447876s] - [7680/13374] , loss=0.00016720300166828868
[elapsed 63.097824573516846s] - [10240/13374] , loss=0.00017679908323771089
[elapsed 63.27834177017212s] - [12800/13374] , loss=0.00018408266638289205
Test set evaluation: Accuracy 5632/6700 84.06%
==================== Epoch 49 ====================
[elapsed 63.809921741485596s] - [2560/13374] , loss=0.00013651168810611126
[elapsed 63.99043846130371s] - [5120/13374] , loss=0.00015391943106806137
[elapsed 64.16397380828857s] - [7680/13374] , loss=0.00016998272246079675
[elapsed 64.34748315811157s] - [10240/13374] , loss=0.0001738338686664065
[elapsed 64.52999496459961s] - [12800/13374] , loss=0.0001867358752497239
Test set evaluation: Accuracy 5660/6700 84.48%
==================== Epoch 50 ====================
[elapsed 64.99973917007446s] - [2560/13374] , loss=0.00014780706333112904
[elapsed 65.21316909790039s] - [5120/13374] , loss=0.00015933765134832357
[elapsed 65.40266251564026s] - [7680/13374] , loss=0.0001664065877169681
[程序已运行65.57220840454102] - [10240/13374] , loss=0.00017008764552883804
[程序已运行65.75272583961487] - [12800/13374] , loss=0.00018112511475919745
在训练集上评估模型: Accuracy 5646/6700 84.27%
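每行日志开头的 `[程序已运行…]` 耗时前缀,可以用 `time.time()` 记录的起始时间戳来生成。下面是一个最小示意(假设在训练循环前捕获一次 `start`;完整代码中的实际辅助函数可能与此不同):

```python
import time

start = time.time()  # 在训练循环开始前捕获一次

def log_progress(processed, total, loss_sum):
    # 按日志格式输出:[程序已运行<秒数>] - [<已处理>/<总数>] , loss=<平均损失>
    elapsed = time.time() - start
    print(f'[程序已运行{elapsed}] - [{processed}/{total}] , loss={loss_sum / processed}')

log_progress(2560, 13374, 0.38)
```

其中 `log_progress`、`loss_sum` 等名称均为示意用的假设命名;日志中的平均损失即累计损失除以已处理样本数。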
