【深度学习】基于华为MindSpore和pytorch的卷积神经网络LeNet5实现MNIST手写识别

news2024/9/28 7:25:00

1 实验内容简介

1.1 实验目的

(1)熟练掌握卷积、池化概念;

(2)熟练掌握卷积神经网络的基本原理;

(3)熟练掌握各种卷积神经网络框架单元;

(4)熟练掌握经典卷积神经网络模型。

 

1.2 实验内容及要求

请基于pytorch和mindspore平台,利用MNIST数据集,选择一个典型卷积模型,构建一个自己的卷积模型,以分类的准确度和混淆矩阵为衡量指标,分析两个模型的分类精度。

要求:pytorch可与tensorflow替换,但mindspore为必选平台。(mindspore可以在华为云ModelArts上实现);自己构建的模型必须与经典模型的网络结构有显著区分,但总体分类准确度需要在97%以上记为合格(不扣分)。

 

1.3 实验数据集介绍

1.3.1 数据集简介

MNIST数据集(Mixed National Institute of Standards and Technology Database)是一个用来训练各种图像处理系统的二进制图像数据集,广泛应用于机器学习中的训练和测试。MNIST数据集共有70000张图像,其中训练集60000张,测试集10000张。所有图像都是28×28的灰度图像,每张图像包含一个手写数字。

 

1.3.2 数据集详细信息

(1)数据量

训练集60000张图像,其中30000张来自NIST的Special Database 3,30000张来自NIST的Special Database 1。测试集10000张图像,其中5000张来自NIST的Special Database 3,5000张来自NIST的Special Database 1。

(2)标注情况

每张图像都有标注。

(3)标注类别

共10个类别,每个类别代表0~9之间的一个数字,每张图像只有一个类别。

 

1.3.3 数据集文件结构

(1)目录结构

·解压前

dataset_compressed/

├── t10k-images-idx3-ubyte.gz        #测试集图像压缩包(1648877 bytes)

├── t10k-labels-idx1-ubyte.gz         #测试集标签压缩包(4542 bytes)

├── train-images-idx3-ubyte.gz        #训练集图像压缩包(9912422 bytes)

└── train-labels-idx1-ubyte.gz         #训练集标签压缩包(28881 bytes)

·解压后

dataset_uncompressed/

├── t10k-images-idx3-ubyte                #测试集图像数据

├── t10k-labels-idx1-ubyte                 #测试集标签数据

├── train-images-idx3-ubyte                #训练集图像数据

└── train-labels-idx1-ubyte                 #训练集标签数据

 

(2)文件结构

MNIST数据集将图像和标签都以矩阵的形式存储于一种称为idx格式的二进制文件中。该数据集的4个二进制文件的存储格式分别如下:

·训练集标签数据 (train-labels-idx1-ubyte)

d4fb7f04b1ea4ef8831e349ac090f8df.png

·训练集图像数据(train-images-idx3-ubyte)

15801bad16894ee7bc7ba1a3c09aaa97.png

·测试集标签数据(t10k-labels-idx1-ubyte)

34d6499b0ac14f4ca51d7bb4bd60b511.png

·测试集图像数据 (t10k-images-idx3-ubyte)

7289befbc01a43099718072e5f8167e0.png

 

2 算法原理阐述

2.1 LeNet-5网络结构

LeNet-5由LeCun等人提出于1998年提出,是一种用于手写体字符识别的非常高效的卷积神经网络,出自论文《Gradient-Based Learning Applied to Document Recognition》。LeNet-5的网络结构如下图所示:

f47bca5776994e94b7de5c21f3257271.png

整个 LeNet-5 网络总共包括7层(不含输入层),分别是:C1、S2、C3、S4、C5、F6和OUTPUT。输入二维图像(单通道),先经过两次卷积层到池化层,再经过全连接层,最后为输出层。接受输入图像大小为32×32=1024,输出对应10个类别的得分。LeNet-5 中的每一层结构如下:

(1)C1层是卷积层,使用6个5×5的卷积核,得到6组大小为28×28 = 784的特征映射。因此,C1 层的神经元数量为6×784 = 4 704,可训练参数数量为6×25 + 6 = 156,连接数为156×784 = 122304(包括偏置在内,下同)。     

(2)S2层为汇聚层,采样窗口为2×2,使用平均汇聚,并使用一个非线性函数。神经元个数为6×14×14 =1176,可训练参数数量为6×(1 + 1) = 12,连接数为6×196×(4 + 1) = 5880。

(3)C3 层为卷积层。LeNet-5 中用一个连接表来定义输入和输出特征映射之间的依赖关系。共使用60个5×5的卷积核,得到16 组大小为10×10的特征映射。如果不使用连接表,则需要96个5×5的卷积核。神经元数量为16 ×100 = 1600,可训练参数数量为(60×25) + 16 = 1516,连接数为100×1516 = 151600。

(4)S4 层是一个汇聚层,采样窗口为2×2,得到16个5×5大小的特征映射,可训练参数数量为16×2 = 32,连接数为16×25×(4 + 1) = 2 000。

(5)C5 层是一个卷积层,使用120×16 = 1 920 个5×5的卷积核,得到120 组大小为1×1的特征映射。C5 层的神经元数量为120,可训练参数数量为1 920×25 + 120 = 48120,连接数为120×(16×25 + 1) = 48120。

(6)F6 层是一个全连接层,有84 个神经元,可训练参数数量为84× (120 +1) = 10164。连接数和可训练参数个数相同,为10164。

(7)输出层:输出层由10 个径向基函数(Radial Basis Function,RBF)组成。

卷积层的每一个输出特征映射都依赖于所有输入特征映射,相当于卷积层的输入和输出特征映射之间是全连接的关系。实际上,这种全连接关系不是必须的。我们可以让每一个输出特征映射都依赖于少数几个输入特征映射。定义一个连接表(Link Table)来描述输入和输出特征映射之间的连接关系。在LeNet-5中,连接表的基本设定如下图所示。C3层的第0-5个特征映射依赖于S2层的特征映射组的每3个连续子集,第6-11个特征映射依赖于S2层的特征映射组的每4个连续子集,第12-14个特征映射依赖于S2层的特征映射的每4个不连续子集,第15个特征映射依赖于S2层的所有特征映射。

fa99aca0e0be4bf4bd2a981b46697d9d.jpeg

3 实验流程及代码实现

3.1 实验平台简介

3.1.1 MindSpore

MindSpore是华为公司自研的最佳匹配昇腾AI处理器算力的全场景深度学习框架,为数据科学家和算法工程师提供设计友好、运行高效的开发体验,推动人工智能软硬件应用生态繁荣发展,目前MindSpore支持在EulerOS、Ubuntu、Windows系统上安装。

 

3.1.2 pytorch

PyTorch是一个开源的Python机器学习库,基于Torch,用于自然语言处理等应用程序。2017年1月,由Facebook人工智能研究院(FAIR)基于Torch推出了PyTorch。它是一个基于Python的可续计算包,提供两个高级功能:1、具有强大的GPU加速的张量计算(如NumPy)。2、包含自动求导系统的深度神经网络。

 

3.2 评价指标

3.2.1 混淆矩阵

混淆矩阵(Confusion Matrix)又被称为错误矩阵,通过它可以直观地观察到算法的效果。它的每一列是样本的预测分类,每一行是样本的真实分类(反过来也可以),顾名思义,它反映了分类结果的混淆程度。

bcd2b87618a34aed9433e3abe16d9a47.png

·P(Positive):代表1,表示预测为正样本;

·N(Negative):代表0,表示预测为负样本;

·T(True):代表预测正确;

·F(False):代表预测错误。

下列Positive和Negative表示模型对样本预测的结果是正样本(正例)还是负样本(负例)。True和False表示预测的结果和真实结果是否相同。

·True positives(TP)

预测为1,预测正确,即实际为1;      

·False positives(FP) 

预测为1,预测错误,即实际为0;

·False negatives(FN)

预测为0,预测错误,即实际为1;

·True negatives(TN)

预测为0,预测正确,即实际为0。

 

3.2.2 准确率

准确率(Accuracy)衡量的是分类正确的比例。

 

3.3 实验流程

3.3.1 基于MindSporeLeNet-5网络

首先导入实验需要的模块。mindspore中context模块用于配置当前执行环境,包括执行模式等特性;vision.c_transforms模块是处理图像增强的高性能模块,用于数据增强图像数据改进训练模型。同时需要设置MindSpore的执行设备和模式。

import os
import mindspore as ms
# 导入mindspore中context模块,用于配置当前执行环境,包括执行模式等特性。
import mindspore.context as context
# c_transforms模块提供常用操作,包括OneHotOp和TypeCast
import mindspore.dataset.transforms as C
# vision.c_transforms模块是处理图像增强的高性能模块,用于数据增强图像数据改进训练模型。
import mindspore.dataset.vision as CV
import numpy as np
from mindspore import nn
from mindspore.nn import Accuracy
from mindspore.train import Model
from mindspore.train.callback import LossMonitor
import matplotlib.pyplot as plt
# 设置MindSpore的执行模式和设备
context.set_context(mode=context.GRAPH_MODE, device_target='CPU') # Ascend, CPU, GPU

 对数据进行预处理:

def create_dataset(data_dir, training=True, batch_size=32, resize=(32, 32),
                   rescale=1/(255*0.3081), shift=-0.1307/0.3081, buffer_size=64):
    data_train = os.path.join(data_dir, 'train') # 训练集信息
    data_test = os.path.join(data_dir, 'test') # 测试集信息
    ds = ms.dataset.MnistDataset(data_train if training else data_test)
    ds = ds.map(input_columns=["image"], operations=[CV.Resize(resize), CV.Rescale(rescale, shift), CV.HWC2CHW()])
    ds = ds.map(input_columns=["label"], operations=C.TypeCast(ms.int32))
    ds = ds.shuffle(buffer_size=buffer_size).batch(batch_size, drop_remainder=True)
    return ds

MindSpore的model_zoo中提供了多种常见的模型,可以直接使用。我们构建LeNet-5模型:

class LeNet5(nn.Cell):
    def __init__(self):
        super(LeNet5, self).__init__()
        #设置卷积网络(输入输出通道数,卷积核尺寸,步长,填充方式)
        self.conv1 = nn.Conv2d(1, 6, 5, stride=1, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, stride=1, pad_mode='valid')
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Dense(400, 120)
        self.fc2 = nn.Dense(120, 84)
        self.fc3 = nn.Dense(84, 10)
    #构建网络
    def construct(self, x):
        x = self.relu(self.conv1(x))
        x = self.pool(x)
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

使用MNIST数据集对上述定义的LeNet5模型进行训练。训练策略如下表所示,损失函数使用交叉熵损失函数。

batch size

number of epochs

learning rate

optimizer

32

5

0.01

Momentum 0.9

def train(data_dir, lr=0.01, momentum=0.9, num_epochs=5):
    ds_train = create_dataset(data_dir)
    ds_eval = create_dataset(data_dir, training=False)
    net = LeNet5()
    #计算softmax交叉熵。
    loss = nn.loss.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    #设置Momentum优化器
    opt = nn.Momentum(net.trainable_params(), lr, momentum)
    loss_cb = LossMonitor(per_print_times=ds_train.get_dataset_size())
    metrics = {"Accuracy": Accuracy(), "Confusion_matrix": nn.ConfusionMatrix(num_classes=10)}
    model = Model(net, loss, opt, metrics)
    model.train(num_epochs, ds_train, callbacks=[loss_cb], dataset_sink_mode=True)
    metrics_result = model.eval(ds_eval)
    res = metrics_result["Confusion_matrix"]
    print('Accuracy:',metrics_result["Accuracy"])
    print('Confusion_matrix:', res)
    return res

 

3.3.2 基于pytorchLeNet-5网络

首先导入实验需要的模块:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt
from sklearn.metrics import confusion_matrix
from torchvision import datasets, transforms

接下来构建LeNet-5网络:

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, 1 )
        self.conv2 = nn.Conv2d(6, 16, 5, 1)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(self.conv1(x), 2, 2)
        x = F.max_pool2d(self.conv2(x), 2, 2)
        x = x.view(-1, 16 * 4 * 4)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)

训练和测试模块,采用训练策略与MindSpore保持一致,损失函数使用交叉熵损失函数。

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        pred = model(data)
        loss = F.nll_loss(pred, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if idx % 100 == 0:
            print("Train Epoch: {}, iterantion: {}, Loss: {}".format(epoch, idx, loss.item()))

def test(model, device, test_loader,):
    model.eval()
    total_loss = 0.
    correct = 0.
    predict=[]
    true=[]
    with torch.no_grad():
        for idx, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += F.nll_loss(output, target, reduction="sum").item()
            pred = output.argmax(dim=1)
            true.append(target.tolist())
            predict.append(pred.tolist())
            correct += pred.eq(target.view_as(pred)).sum().item()

        total_loss /= len(test_loader.dataset)
        acc = correct / len(test_loader.dataset) * 100
        print("Test loss: {}, Accuracy: {}".format(total_loss, acc))
    predict = [n for a in predict for n in a]
    true=[n for a in true for n in a]
    return predict,true

4 实验结果及分析

4.1 实验结果

4.1.1 基于MindSporeLeNet-5网络

4.1.1.1 准确率

9aed6bdb69ca4cb1a316562c115cd6d6.png

 

可见经过5轮训练,准确率已达到98.267%。

 

4.1.1.2 混淆矩阵

混淆矩阵及其可视化效果如下:

67a990b8503b404db7c77a4355d945e9.png

03d355886a2846519dad49fa5494a2f0.png

4.1.2 基于pytorchLeNet-5网络

4.1.2.1 准确率

cf128df7f1fc4a23aefeeaa4495d8f89.png

可见经过5轮训练,准确率已超过97%。

4.1.2.2 混淆矩阵

aa37c80bc7f04514a79fcdf73a1378dd.png

4.2 结果分析与对比

用于手写数字识别(图像分类)的LeNet-5是卷积神经网络(CNN)的开山之作, AlexNet、ResNet等都是在其基础上发展而来的。CNN提出了三个创新点:局部感受野(Local Receptive Fields)、共享权重(Shared Weights)、时空下采样(Spatial or Temporal Subsampling)。

局部感受野(Local Receptive Fields):用来表示网络内部的不同位置的神经元对原图像的感受范围的大小,对应于CNN中的卷积核,可以抽取图像初级的特征,如边、转角等,这些特征会在后来的层中通过各种联合的方式来检测高级别特征。

共享权重(Shared Weights):在卷积过程中,每个卷积核所对应的窗口会以一定的步长在输入矩阵(图像)上不断滑动并进行卷积操作,最后,每个卷积核会生成一个对应的feature map(也就是卷积核的输出),一个feature map中的每一个单元都是由相同的权重(也就是对应的卷积核内的数值)计算得到的,这就是共享权重。

时空下采样(Spatial or Temporal Subsampling):对应于CNN中的池化操作,也就是从卷积得到的feature map中提取出重要的部分,此操作可以降低模型对与图像平移和扭曲的敏感程度。

下表列出了各模型的准确率对比情况:

模型

准确率

基于MindSpore的LeNet-5网络

98.267%

基于pytorch的LeNet-5网络

97.35%

可见LeNet在不同平台的手写数字识别方面均展现出了优异的性能,无愧为卷积神经网络的开山鼻祖。

 

源代码

MSLeNet手写识别.py

import os
import mindspore as ms
# 导入mindspore中context模块,用于配置当前执行环境,包括执行模式等特性。
import mindspore.context as context
# c_transforms模块提供常用操作,包括OneHotOp和TypeCast
import mindspore.dataset.transforms as C
# vision.c_transforms模块是处理图像增强的高性能模块,用于数据增强图像数据改进训练模型。
import mindspore.dataset.vision as CV
import numpy as np
from mindspore import nn
from mindspore.nn import Accuracy
from mindspore.train import Model
from mindspore.train.callback import LossMonitor
import matplotlib.pyplot as plt
# 设置MindSpore的执行模式和设备
context.set_context(mode=context.GRAPH_MODE, device_target='CPU') # Ascend, CPU, GPU

def create_dataset(data_dir, training=True, batch_size=32, resize=(32, 32),
                   rescale=1/(255*0.3081), shift=-0.1307/0.3081, buffer_size=64):
    data_train = os.path.join(data_dir, 'train') # 训练集信息
    data_test = os.path.join(data_dir, 'test') # 测试集信息
    ds = ms.dataset.MnistDataset(data_train if training else data_test)
    ds = ds.map(input_columns=["image"], operations=[CV.Resize(resize), CV.Rescale(rescale, shift), CV.HWC2CHW()])
    ds = ds.map(input_columns=["label"], operations=C.TypeCast(ms.int32))
    ds = ds.shuffle(buffer_size=buffer_size).batch(batch_size, drop_remainder=True)
    return ds

ds = create_dataset('D:/Dataset/MNIST', training=False)
data = ds.create_dict_iterator().get_next()
images = data['image'].asnumpy()
labels = data['label'].asnumpy()
#显示前4张图片以及对应标签
for i in range(1, 5):
    plt.subplot(2, 2, i)
    plt.imshow(images[i][0])
    plt.title('Number: %s' % labels[i])
    plt.xticks([])
plt.show()

#定义LeNet5模型
class LeNet5(nn.Cell):
    def __init__(self):
        super(LeNet5, self).__init__()
        #设置卷积网络(输入输出通道数,卷积核尺寸,步长,填充方式)
        self.conv1 = nn.Conv2d(1, 6, 5, stride=1, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, stride=1, pad_mode='valid')
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Dense(400, 120)
        self.fc2 = nn.Dense(120, 84)
        self.fc3 = nn.Dense(84, 10)
    #构建网络
    def construct(self, x):
        x = self.relu(self.conv1(x))
        x = self.pool(x)
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

def train(data_dir, lr=0.01, momentum=0.9, num_epochs=5):
    ds_train = create_dataset(data_dir)
    ds_eval = create_dataset(data_dir, training=False)
    net = LeNet5()
    #计算softmax交叉熵。
    loss = nn.loss.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    #设置Momentum优化器
    opt = nn.Momentum(net.trainable_params(), lr, momentum)
    loss_cb = LossMonitor(per_print_times=ds_train.get_dataset_size())
    metrics = {"Accuracy": Accuracy(), "Confusion_matrix": nn.ConfusionMatrix(num_classes=10)}
    model = Model(net, loss, opt, metrics)
    model.train(num_epochs, ds_train, callbacks=[loss_cb], dataset_sink_mode=True)
    metrics_result = model.eval(ds_eval)
    res = metrics_result["Confusion_matrix"]
    print('Accuracy:',metrics_result["Accuracy"])
    print('Confusion_matrix:', res)
    return res

# 绘制混淆矩阵
def plot_confusion_matrix(cm, title='Confusion Matrix'):
    classes = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
    plt.figure(figsize=(12, 8), dpi=100)
    np.set_printoptions(precision=2)
    # 混淆矩阵中每格的值
    ind_array = np.arange(len(classes))
    x, y = np.meshgrid(ind_array, ind_array)
    for x_val, y_val in zip(x.flatten(), y.flatten()):
        c = cm[y_val][x_val]
        if c > 0.001:
            plt.text(x_val, y_val, "%0.2f" % (c,), color='#EE3B3B', fontsize=10, va='center', ha='center')
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.binary)
    plt.title(title)
    plt.colorbar()
    xlocations = np.array(range(len(classes)))
    plt.xticks(xlocations, classes, rotation=90)
    plt.yticks(xlocations, classes)
    plt.ylabel('Actual Label')
    plt.xlabel('Predict Label')
    tick_marks = np.array(range(len(classes))) + 0.5
    plt.gca().set_xticks(tick_marks, minor=True)
    plt.gca().set_yticks(tick_marks, minor=True)
    plt.gca().xaxis.set_ticks_position('none')
    plt.gca().yaxis.set_ticks_position('none')
    plt.grid(True, which='minor', linestyle='-')
    plt.gcf().subplots_adjust(bottom=0.15)
    plt.show()

confusion = train('D:/Dataset/MNIST')
plot_confusion_matrix(confusion,title='Confusion Matrix')

 

pytorchLeNet手写识别.py

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt
from sklearn.metrics import confusion_matrix
from torchvision import datasets, transforms

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, 1 )
        self.conv2 = nn.Conv2d(6, 16, 5, 1)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(self.conv1(x), 2, 2)
        x = F.max_pool2d(self.conv2(x), 2, 2)
        x = x.view(-1, 16 * 4 * 4)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)


def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        pred = model(data)
        loss = F.nll_loss(pred, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # if idx % 100 == 0:
        #     print("Train Epoch: {}, iterantion: {}, Loss: {}".format(epoch, idx, loss.item()))

def test(model, device, test_loader,):
    model.eval()
    total_loss = 0.
    correct = 0.
    predict=[]
    true=[]
    with torch.no_grad():
        for idx, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += F.nll_loss(output, target, reduction="sum").item()
            pred = output.argmax(dim=1)
            true.append(target.tolist())
            predict.append(pred.tolist())
            correct += pred.eq(target.view_as(pred)).sum().item()

        total_loss /= len(test_loader.dataset)
        acc = correct / len(test_loader.dataset) * 100
        print("Test loss: {}, Accuracy: {}".format(total_loss, acc))
    predict = [n for a in predict for n in a]
    true=[n for a in true for n in a]
    return predict,true

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 32
train_dataloader = torch.utils.data.DataLoader(
    datasets.MNIST("D:/Dataset/pytorch/", train=True, download=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))

                   ])),
    batch_size=batch_size, shuffle=True,
    num_workers=1, pin_memory=True  # True加快训练
)
test_dataloader = torch.utils.data.DataLoader(
    datasets.MNIST("D:/Dataset/pytorch/", train=False, download=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True,
    num_workers=1, pin_memory=True
)

# 绘制混淆矩阵
def plot_confusion_matrix(cm, title='Confusion Matrix'):
    classes = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
    plt.figure(figsize=(12, 8), dpi=100)
    np.set_printoptions(precision=2)
    # 混淆矩阵中每格的值
    ind_array = np.arange(len(classes))
    x, y = np.meshgrid(ind_array, ind_array)
    for x_val, y_val in zip(x.flatten(), y.flatten()):
        c = cm[y_val][x_val]
        if c > 0.001:
            plt.text(x_val, y_val, "%0.2f" % (c,), color='#EE3B3B', fontsize=10, va='center', ha='center')
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.binary)
    plt.title(title)
    plt.colorbar()
    xlocations = np.array(range(len(classes)))
    plt.xticks(xlocations, classes, rotation=90)
    plt.yticks(xlocations, classes)
    plt.ylabel('Actual Label')
    plt.xlabel('Predict Label')
    tick_marks = np.array(range(len(classes))) + 0.5
    plt.gca().set_xticks(tick_marks, minor=True)
    plt.gca().set_yticks(tick_marks, minor=True)
    plt.gca().xaxis.set_ticks_position('none')
    plt.gca().yaxis.set_ticks_position('none')
    plt.grid(True, which='minor', linestyle='-')
    plt.gcf().subplots_adjust(bottom=0.15)
    plt.show()

if __name__ == '__main__':
    lr = 0.01
    momentum = 0.9
    model = LeNet().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    num_eopchs = 5
    for eopch in range(num_eopchs):
        train(model, device, train_dataloader, optimizer, eopch)
        y_pre,y=test(model, device, test_dataloader)
        plot_confusion_matrix(confusion_matrix(y, y_pre), title='Confusion Matrix')

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/503886.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Java笔试强训 32】

🎉🎉🎉点进来你就是我的人了博主主页:🙈🙈🙈戳一戳,欢迎大佬指点! 欢迎志同道合的朋友一起加油喔🤺🤺🤺 目录 一、选择题 二、编程题 🔥淘宝网店…

初识Vue-事件 样式

目录 事件 事件修饰符 自定义事件 双向绑定 样式style class style Scoped(防止样式污染) 事件 <button click"add($event, 2)">add</button> add(event, num) {event.stopPropagation()// 防止冒泡this.count num; // this在方法里指向当前Vu…

YOLOv5改进系列(1)——添加SE注意力机制

前言 从这篇开始我们进入YOLOv5改进系列。那就先从最简单的添加注意力机制开始吧&#xff01;&#xff08;&#xffe3;︶&#xffe3;&#xff09;↗ 【YOLOv5改进系列】前期回顾&#xff1a; YOLOv5改进系列&#xff08;0&#xff09;——重要性能指标与训练结果评价及分…

Tomcat线程池扩展总结

Tomcat 线程池 前言 最近在压测接口&#xff0c;调线程池参数大小&#xff0c;感觉和预计的不一样。 看了看 tomcat 线程池源码 对比 JDK 线程池源码 源码 构造代码 这里是提交任务 可以看到 这里重写了 TaskQueue 类 可以看到 重写了offer方法 增加了这一行判断&#x…

ePWM模块-时基模块(2)

ePWM模块(2) 时基模块的使用 TBPRD:周期寄存器 (设置的时钟周期存入此,可通过阴影寄存器缓冲后写入,也可通过活动寄存器立即写入) TBCTR:时基计数变值寄存器 (时基当前所计数的值存入,用于和所设定周期值比较) TBPHS:时基相位寄存器 TBSTS:时基状态寄存器 …

Transformer 论文精读——Attention Is All You Need

https://www.bilibili.com/video/BV1pu411o7BE 摘要 序列转录模型是从一个序列生成另一个序列。Transformer 仅利用注意力机制&#xff08;attention&#xff09;&#xff0c;并且在机器翻译领域取得很好的成功。 结论 Transformer 重要贡献之一提出&#xff1a;multi-head…

C++实现二分查找(力扣题目704)

题目要求&#xff1a;给定一个n个元素的&#xff08;升序&#xff09;整型数组nums和一个目标值target&#xff0c;写一个函数搜索nums中的target&#xff0c;如果目标值存在返回下标&#xff0c;否则返回-1 示例 输入: nums [-1,0,3,5,9,12] target 9 输出: 4 解释: 9 出现在…

Coursera—Andrew Ng机器学习—课程笔记 Lecture 3_Linear Algebra Review

3.1 矩阵和向量   参考视频: 3 - 1 - Matrices and Vectors (9 min).mkv 3.2 加法和标量乘法 参考视频: 3 - 2 - Addition and Scalar Multiplication (7 min).mkv 3.3 矩阵向量乘法 参考视频: 3 - 3 - Matrix Vector Multiplication (14 min).mkv 3.4 矩阵乘法 参考视频:…

[Pandas] 读取Excel文件

练习数据准备 demo.xlsx demo.xlsx中的工作表Sheet1显示上述的数据&#xff0c;Sheet2没有数据 我们可以使用Pandas中的read_excel()方法读取Excel格式的数据文件&#xff0c;生成DataFrame数据框进行数据分析处理 基本语法格式 import pandas as pd pd.read_excel(io, sheet…

JVM内存模型介绍

JVM(Java Virtual Machine)又被分为三大子系统&#xff0c;类加载子系统&#xff0c;运行时数据区&#xff0c;执行引擎。在这里我们主要讲解一下JVM的运行时数据区&#xff0c;也就是我们常说的JVM存储数据的内存模型。在这里提一点&#xff0c;平常我们常说内存模型&#xff…

行业趣闻 | 在施工现场“打灰”,挺好的?

房地产市场的不景气对土木行业的冲击、某某大学土木工程专业招不到人、某央企施工人员因吐槽土木行业现状而被辞退…… 面对互联网上诸多对土木行业的调侃和流言&#xff0c;许多土木工程专业的同学变得迷茫了。 这个行业的实际情况究竟是怎样的&#xff1f; 图源网络 2018年10…

Auto-GPT:揭示 ChatGPT、GPT-4 和开源 AI 之间的联系

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、什么是Auto-GPT&#xff1f;二、Auto-GPT 是如何工作的&#xff1f;三、Auto-GPT 能做什么&#xff1f;四、谁制造了 Auto-GPT&#xff1f;五、ChatGPT 或 …

DPDK抓包工具dpdk-dumpcap的使用

在进行网络开发中&#xff0c;我们经常会通过抓包来定位分析问题&#xff0c;在不使用DPDK的情况下&#xff0c;Linux系统通常用tcpdump&#xff0c;windows用wireshark&#xff0c;但是如果我们使用了DPDK来收包&#xff0c;就无法用这两个工具来抓包了。 这个时候我们需要用D…

Linux新字符设备驱动实验

1、新字符设备驱动原理 一、分配和释放设备号 使用 register_chrdev 函数注册字符设备的时候只需要给定一个主设备号即可&#xff0c;但是这样会 带来两个问题&#xff1a; ①、需要我们事先确定好哪些主设备号没有使用。 ②、会将一个主设备号下的所有次设备号都使用掉&#…

多线程(线程同步和互斥+线程安全+条件变量)

线程互斥 线程互斥&#xff1a; 任何时刻&#xff0c;保证只有一个执行流进入临界区访问临界资源&#xff0c;通常对临界资源起到保护作用 相关概念 临界资源&#xff1a; 一次仅允许一个进程使用的共享资源临界区&#xff1a; 每个线程内部&#xff0c;访问临界资源的代码&am…

信息抽取与命名实体识别:从原理到实现

❤️觉得内容不错的话&#xff0c;欢迎点赞收藏加关注&#x1f60a;&#x1f60a;&#x1f60a;&#xff0c;后续会继续输入更多优质内容❤️ &#x1f449;有问题欢迎大家加关注私戳或者评论&#xff08;包括但不限于NLP算法相关&#xff0c;linux学习相关&#xff0c;读研读博…

STM32-江科大

新建工程 引入启动文件 Start中是启动文件&#xff0c;是STM32中最基本的文件&#xff0c;不需要修改&#xff0c;添加即可。 启动文件包含很多类型&#xff0c;要根据芯片型号来进行选择&#xff1a; 如果是选择超值系列&#xff0c;那就使用带 VL 的启动文件&#xff0c;…

多元统计分析-主成分分析的原理与实现

目录 一、什么是主成分分析&#xff1f; 二、主成分分析的原理 三、主成分分析的应用 四、使用sklearn实现主成分分析 五、总结 一、什么是主成分分析&#xff1f; 主成分分析&#xff08;Principal Component Analysis&#xff0c;PCA&#xff09;是一种常用的多元统计分…

Docker部署FAST OS DOCKER容器管理工具

Docker部署FAST OS DOCKER容器管理工具 一、FAST OS DOCKER介绍1. FAST OS DOCKER简介2. FAST OS DOCKER特点 二、本次实践介绍1. 本次实践简介2. 本次实践环境 三、本地环境检查1.检查Docker服务状态2. 检查Docker版本 四、下载FAST OS DOCKER镜像五、部署FAST OS DOCKER1. 创…

理解控制变量、内生变量、外生变量、工具变量

文章目录 前言一、控制变量二、内生变量、外生变量三、工具变量&#xff08;IV&#xff09; 前言 1.解释变量&#xff08;或自变量&#xff09;&#xff1a;解释变量是指作为研究对象&#xff0c;用于解释某个现象或行为模式的变量。其中有些解释变量是直接影响被解释变量的&a…