基于卷积神经网络的手写字体识别(详细笔记)

news2024/11/23 13:43:55

主要参考博客:
1、 基于卷积神经网络的手写数字识别(附数据集+完整代码+操作说明)
2、用PyTorch实现MNIST手写数字识别(最新,非常详细)

基于卷积神经网络的手写字体识别——目录

  • 1 前言
    • 1.1 实现效果
    • 1.2 学习背景
    • 1.3 关于MNIST手写数据集
  • 2 残差网络
    • 2.1 深度学习中的退化问题
    • 2.2 残差网络结构
    • 2.3 残差块的网络模型
  • 3 实现步骤&代码详析
    • 3.1 训练模型
      • 3.1.1 数据集选择
      • 3.1.2 网络构建
      • 3.1.3 测试网络
      • 3.1.4 训练网络
    • 3.2 手写数字识别
      • 3.2.1 图像预处理
      • 3.2.2 图像识别
  • 4 附录(完整程序,复制即可使用)
    • 4.1 hand_write_train.py
    • 4.2 IsBlackGround.py
    • 4.3 mainBoard.py
    • 4.4 MainWidget.py
    • 4.5 PaintBoard.py
    • 4.6 Pre_treatment.py
    • 4.7 predict.py
    • 4.8 predictPhoto.py

1 前言

1.1 实现效果

在这里插入图片描述
运行程序mainBoard.py,生成一个手写板,在手写板上画图即可识别大部分正常的手写体数字。(无法识别过于抽象的书写,鲁棒性不是很强。识别数字9的时候也有点问题,作者没有做数字增强)

1.2 学习背景

手写数字识别,作为机器视觉入门项目,无论是基于传统的OpenCV方法还是基于目前火热的深度学习、神经网络的方法都有这不错的训练效果。当然,这个项目也常常被作为大学/研究生阶段的课程实验。通常作为接触深度学习的第一个小项目。
作者本人也是刚刚开始接触深度学习,而对象就是手写数字识别。网上存在大量案例,就不必要自己去动手手搓神经网络,而可以借助网上成熟的项目案例,快速接触成功的项目,分析这些项目的代码学习如何去搭建一个神经网络。避免自己陷入重复造轮子而收获很少的陷阱中。
这个项目大概花费3-4天就可以让个人熟悉残差网络(ResNet)的基本框架。有助于新手了解了解神经网络的网络结构,作为日后进一步学习的垫脚石。

1.3 关于MNIST手写数据集

MNIST 数据集(手写数字数据集)是一个公开的公共数据集,任何人都可以免费获取它。目前,它已经是一个作为机器学习入门的通用性特别强的数据集之一,所以对于想要学习机器学习分类的、深度神经网络分类的、图像识别与处理的小伙伴,都可以选择MNIST数据集入门。
在这里插入图片描述
该数据集包含60 000个用于训练的示例10 000个用于测试的示例。数据集包含了0-9共10类手写数字图片,每张图片都做了尺寸归一化,都是28x28大小的灰度图

内容概括文件名称文件大小包含内容
训练集图像train-images-idx3-ubyte.gz9.9MB60000个样本
训练集标签train-labels-idx1-ubyte.gz29KB60000个标签
测试集图像t10k-images-idx3-ubyte.gz1.6MB10000个样本
测试集标签t10k-labels-idx1-ubyte.gz5KB10000个标签

2 残差网络

2.1 深度学习中的退化问题

参考:残差网络ResNet网络原理及实现
在深度神经网络训练中,从经验来看,随着网络深度的增加,模型理论上可以取得更好的结果。但是实验却发现,深度神经网络中存在着退化问题(Degradation problem)。即深层网络的训练效果不如浅层网络。
残差网络解决了网络退化问题。

2.2 残差网络结构

ResNet中最重要的是残差学习单元:
在这里插入图片描述
对于一个堆积层结构(几层堆积而成)当输入为x时其学习到的特征记为H(x),现在我们希望其可以学习到残差F(x)=H(x)-x,这样其实原始的学习特征是F(x)+x 。当残差为0时,此时堆积层仅仅做了恒等映射,至少网络性能不会下降,实际上残差不会为0,这也会使得堆积层在输入特征基础上学习到新的特征,从而拥有更好的性能。一个残差单元的公式如下:

y = F ( x , W i ) + W s x y=F(x,{W_{i}})+W_{s}x y=F(x,Wi)+Wsx
程序中编写的网络结构:
在这里插入图片描述

残差网络ResNET的优点:

  1. 跳跃连接: ResNet引入了跳跃连接,允许网络直接传递输入到输出层,避免了梯度消失问题,使得可以训练更深的网络。在传统的CNN中没有。
  2. 残差学习:核心思想是学习残差(residual),即学习如何对输入进行微小的修改,从而得到正确的输出。这使得网络训练更加稳定,容易优化。
  3. 层数: ResNet通常比传统的CNN具有更深的层数,因为跳跃连接的引入有助于缓解深层网络的问题。
  4. 网络性能: 由于引入了残差学习和跳跃连接,ResNet在一些情况下表现更好,特别是当网络变得很深时。然而,对于一些较小的任务,传统的CNN可能仍然表现出色。

2.3 残差块的网络模型

  • 残差块Resnet_block:由多个残差模块组成的序列(Sequential),形成完整的残差块
  • 在这里插入图片描述
  • 残差模块Residuals:包含两个卷积层conv和一些批标准差层BatchNorm2d
    在这里插入图片描述
  • Conv3()的作用:使用1x1卷积来匹配通道数

3 实现步骤&代码详析

3.1 训练模型

3.1.1 数据集选择

手写数字识别经典数据集:本文数据集选择的MNIST 数据集(手写数字数据集),共含有六万张28*28的训练集手写图片和一万张28*28的测试集手写图片(二值图片)
在这里插入图片描述

3.1.2 网络构建

传统机器学习的问题与缺陷随着深度学习的发展被得到解决,深度学习也可以说是神经网络的重命名,他是建立在多层非线性的神经网络结构之上,对数据表示进行抽象的一系列机器学习。深度学习的出现使得图像,语言得到突破性的发展。
最基本的单元是神经元模型。每个神经元与其他神经元相连,当他“兴奋”时,就会向相连的神经元发送物质,改变神经元的电位。如果某个神经元的电位超过了一个阀值,那么它就会被激活。结果抽象可以得到沿用至今的M_P神经元模型。
在这里插入图片描述
线性部分,是简单的相乘相加,激活部分是利用激活函数处理得到输出。常见的激活函数有sigmoid,relu等,本次采用的激活函数是relu函数。
在这里插入图片描述
由神经元组成的多层神经网络,如图所示。有输入层,输出层以及中间隐含层。每一个输入线性求合,通过激活函数,传到下一个神经元,不必一个个的去算,使用向量化来使得程序更加简洁。
在这里插入图片描述
在这里插入图片描述
下面展示 ResNET 网络模型

	# import torch
    # from torch import nn, optim
    # import torch.nn.functional as F
    # ResNet模型
    net = nn.Sequential(
        # 顺序容器,网络结构
        nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),  # 二维卷积层
        nn.BatchNorm2d(64),  # 二维批标准化层
        nn.ReLU(),  # ReLUctant激活函数层
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1))  # 二维最大池化层

    # 三个残差块ResNet block,增强模型的深度和复杂性
    # 和sequential的层次不同,残差块的resnet_block是自定义的
    net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True))
    net.add_module("resnet_block2", resnet_block(64, 128, 2))
    net.add_module("resnet_block3", resnet_block(128, 256, 2))

    net.add_module("global_avg_pool", GlobalAvgPool2d())  # GlobalAvgPool2d的输出: (Batch, 512, 1, 1)
    # 全局平均池化层
    net.add_module("fc", nn.Sequential(FlattenLayer(), nn.Linear(256, 10)))
    # 展平层(FlattenLayer)将特征图转换为一维向量
    # 全连接层(nn.Linear)将特征映射到10个输出类别上

ResNet是一个非常流行的深度学习架构,用于解决图像分类和其他计算机视觉任务。

下面是对以上代码的解释:

  1. 首先,第一段代码
net = nn.Sequential(
        # 顺序容器,网络结构
        nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),  # 二维卷积层
        nn.BatchNorm2d(64),  # 二维批标准化层
        nn.ReLU(),  # ReLUctant激活函数层
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1))  # 二维最大池化层

a) nn.Sequential:这是一个顺序容器,用于将一系列的神经网络层按顺序组合在一起,构建一个整体的神经网络模型。
b) nn.Conv2d:这是一个二维卷积层。在这里,你使用一个卷积核大小为7x7,步长为2,填充为3,将单通道的输入图像转换为64通道的特征图。
c) nn.BatchNorm2d:这是一个二维批标准化层,用于规范化输入数据的均值和方差,以加速训练和增强模型的稳定性。
d) nn.ReLU:这是一个ReLU激活函数层,用于引入非线性性。
e) nn.MaxPool2d:这是一个二维最大池化层,用于降低特征图的空间分辨率。
2. 然后,通过net.add_module()添加了三个残差块(resnet_block),这些块会增加模型的深度和复杂性,以帮助提取更丰富的特征

def resnet_block(in_channels, out_channels, num_residuals, first_block=False)  
	·in_channels 	输入通道数
	·out_channels	输出通道数
	·num_residuals	残差数
	·first_block	是否为第一个残差块
    net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True))
    net.add_module("resnet_block2", resnet_block(64, 128, 2))
    net.add_module("resnet_block3", resnet_block(128, 256, 2))
  1. 接下来,通过添加一个全局平均池化层(GlobalAvgPool2d)将特征图压缩为大小为1x1的特征图。
    net.add_module("global_avg_pool", GlobalAvgPool2d())  # GlobalAvgPool2d的输出: (Batch, 512, 1, 1)
    # 全局平均池化层
  1. 然后,通过添加一个展平层(FlattenLayer)将特征图转换为一维向量,
  2. 最后,通过一个全连接层(nn.Linear)将特征映射到10个输出类别上。
   net.add_module("fc", nn.Sequential(FlattenLayer(), nn.Linear(256, 10)))
   # 展平层(FlattenLayer)将特征图转换为一维向量
   # 全连接层(nn.Linear)将特征映射到10个输出类别上

3.1.3 测试网络

构建完网络后,可以随机生成一个Tensor,对ResNet网络进行前向传播,查看每层的形状。程序如下:

# 测试网络
    # 对构建的ResNet模型进行前向传播,查看每个层的输出形状的示例
    X = torch.rand((1, 1, 28, 28))
    # 创建了一个随机初始化的输入Tensor,形状为(1, 1, 28, 28)。表示一个 单通道 的 28x28 像素的图像
    for name, layer in net.named_children():
        # 通过循环遍历ResNet模型的每一层。
        X = layer(X)
        # 通过循环遍历ResNet模型的每一层。
        print(name, ' output shape:\t', X.shape)
        # 打印当前层的名字和输出X的形状。

关于torch.size(x,y,z,q)各维度解释:
第一个维度:表示批量大小(batch size),即一次性输入的样本数量。
第二个维度:表示通道数(channels),在卷积层中,这通常指的是输入通道的数量。
第三个维度:表示高度(height),特征图的高度。
第四个维度:表示宽度(width),特征图的宽度。

使用一个循环遍历ResNet模型的每一层,并通过前向传播输入数据 X 通过模型的每一层。在每一次循环迭代中,打印出当前层的名称以及输出 X 的形状

3.1.4 训练网络

网络训练最好放在GPU上,根据反复测试当迭代达到40次就基本完成收敛,使用RTX3060的GPU只需要10多分钟就能完成训练。而使用CPU来训练,可能需要2小时以上。训练网络的程序如下:

 print('============Start to Train=======================')
    # 训练 迭代40次 达到99.3%
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # 检查是否可用,将神经网络放在CUDA设备(GPU)上进行训练,如果不可用则使用CPU
    lr, num_epochs = 0.001, 40
    # lr 学习率;num_epoch 训练轮数
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    # 优化器:将网络的参数传递给优化器,,以便在训练过程中更新权重
    batch_size = 4000   # 设置批次大小
    net = net.to(device)    # 将网络移动到合适的设备上(GPU或CPU)

    # 以下是初始化参数
    print("training on \t", device)
    loss = torch.nn.CrossEntropyLoss() # 计算交叉熵损失的损失函数
    loop_times = round(60000 / batch_size)  # 每个训练的迭代次数
    train_acc_plot = [] # 训练准确率列表
    test_acc_plot = []  # 测试准确率列表
    loss_plot = []      # 损失值列表
    for epoch in range(num_epochs): # 轮询训练循环
        train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()
        # 总损失、总训练准确率、总样总数、批次计数和开始时间
        for i in tqdm(range(1, loop_times),position=0):
            # 迭代进度条
            x = train_images[(i - 1) * batch_size:i * batch_size]
            y = train_labels[(i - 1) * batch_size:i * batch_size]
            x = torch.unsqueeze(x, 1)  # 对齐维度 以与网络输入的维度相匹配
            X = x.to(device)    # 数据移动到设备上GPU/CPU
            y = y.to(device)    # 标签移动到设备上GPU/CPU
            y_hat = net(X)      # 前向传播计算网络的预测值
            l = loss(y_hat, y)  # 计算交叉熵损失,预测值y_hat与真实标签y之间
            optimizer.zero_grad() # 清空优化器梯度
            l.backward()        # 反向传播
            optimizer.step()    # 权重更新
            train_l_sum += l.cpu().item()   # 累计每个批次的损失-将损失值从GPU移到CPU,并提取为Python标量
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item() # 累计每个批次的准确率
            n += y.shape[0]     # 更新样本总数
            batch_count += 1    # 更新批次计算
        test_acc = evaluate_accuracy(test_images, test_labels, net) # 计算测试集准确率,并打印
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
              % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
        # 保存模型到文件中
        torch.save(net.state_dict(), 'logs/Epoch%d-Loss%.4f-train_acc%.4f-test_acc%.4f.pth' % (
            (epoch + 1), train_l_sum / batch_count, train_acc_sum / n, test_acc))
        print("save successfully") # 完成一轮的训练
        # append 添加,增加,附增
        test_acc_plot.append(test_acc)              # 测试准确率列表
        train_acc_plot.append(train_acc_sum / n)    # 训练准确率列表
        loss_plot.append(train_l_sum / batch_count) # 损失列表

首先进行神经网络的训练设置:

  1. 检查GPU是否可用,若可用则将神经网络放在CUDA设备上进行训练
  2. 设置学习率(lr)和训练轮数(num epochs)。
  3. 创建Adam优化器,将网络参数传递给优化器,以便在训练过程中更新权重。
  4. 设置每个批次的大小(batch_size)。
  5. 将网络移动到适当的设备(GPU或CPU)上。
  6. 初始化用于记录训练过程中准确率损失列表(train_acc_plot、test_acc_plot、loss_plot)
  7. torch.nn.CrossEntropyLoss()是一个用于计算交叉熵损失的损失函数。它将模型的预测概率分布与真实标签之间的差异作为损失,通过最小化这个损失来优化模型。
  8. loop_times是用来确定在每个训练迭代中需要循环的次数。总训练样本数(60000)除以每个批次的大小(batch_size),然后四舍五入到最接近的整数。这个值说明了在每个训练轮次中,需要多少次循环来遍历所有的训练数据。

将这两个部分结合起来,可以在每次训练迭代中,计算交叉熵损失,并将其用于反向传播和权重更新。这样,在每个训练轮次中,都会使用多次循环来逐批次地处理训练数据,最终优化模型的参数以提高准确率。

后面是关于训练轮询的解析。在每个训练轮次(epoch)中,通过多次循环遍历训练数据的批次来训练网络:

  1. train_l_sum、train_acc_sum、n、batch_count、start等变量用于累计每个轮次中的总损失、总训练准确率、总样本数、批次计数和开始时间
  2. time.time() 是 Python 中的一个函数,用于获取当前的时间戳(以秒为单位)。它可以用于测量代码执行的时间,计算时间间隔,或者在代码的不同部分插入时间戳来进行时间统计。
for i in tqdm(range(1, loop_times),position=0):#内部循环
  1. 在每个循环中,从训练集中获取一个批次的输入数据 x 和标签 y。将 x 的维度扩展,以与网络输入的维度相匹配。

  2. 将数据和标签移动到设备(GPU或CPU)上。

  3. 通过前向传播计算网络的预测值 y_hat。

  4. 计算损失l,使用 loss 对象计算预测值与真实标签之间的交叉熵损失。

  5. 清空优化器的梯度,并进行反向传播和权重更新。optimizer.step() 是 PyTorch 中用于执行梯度下降步骤的方法。在神经网络训练中,它被用于更新模型的参数(权重和偏置),以减少损失函数的值,从而优化模型的性能。每次参数更新之前,一般会先使用 optimizer.zero_grad() 来清空之前的梯度,然后进行正向传播、反向传播和参数更新

  6. 累计每个批次的损失和准确率。
    l.cpu().item():将损失值从GPU移到CPU,并提取为Python标量
    (y_hat.argmax(dim=1) == y).sum().cpu().item():
    首先通过(y_hat.argmax(dim=1) == y)计算模型的预测结果是否与真是标签相匹配。argmax(dim=1) 操作返回每个样本中最大值所在的索引,表示模型预测的类别。与真实值比较后,得到一个bool张量(Tensor,即多维数组multidimensional array),返回True或False。
    .sum():这部分代码对布尔张量进行求和操作,统计出预测正确的样本数量。
    .cpu().item():这部分代码将计算得到的张量从 GPU 移动到 CPU,并将其提取为 Python 标量,以便在代码中使用和记录。

  7. 更新样本总数和批次计数。
    ==================================================================

  8. 在每个轮次结束后,计算并打印训练损失、训练准确率和测试准确率。同时,将模型的权重保存到文件中。

  9. 将测试准确率、训练准确率和损失分别添加到用于绘制图形的列表中。

这个循环会迭代完所有的训练数据,进行参数更新,并在每个轮次结束后进行一些记录和保存。整个循环的目标是通过多次迭代优化网络的权重,使其逐渐学习数据的模式和特征,从而提高在测试数据上的准确率。在每个轮次结束时,你可以查看模型的性能,保存模型权重,并记录训练和测试准确率随着训练轮次的变化情况。

3.2 手写数字识别

3.2.1 图像预处理

图像预处理的主要步骤:
在这里插入图片描述
在这里插入图片描述

3.2.2 图像识别

将输入的图片进行预处理,完成ROI提取后,把图像输入到之前训练好的模型中进行预测即可。
在这里插入图片描述

4 附录(完整程序,复制即可使用)

在这里插入图片描述

4.1 hand_write_train.py


# author:Hurricane
# date:  2020/11/4
# E-mail:hurri_cane@qq.com

import numpy as np
import struct
import matplotlib.pyplot as plt
import cv2 as cv
import random
import torch
from torch import nn, optim
import torch.nn.functional as F
import time
from tqdm import tqdm

## 训练集文件
train_images_idx3_ubyte_file = 'D:/CQUPT/软通团队/2023暑假/Task6:深度学习/Hand_wrtten-master/dataset/train-images.idx3-ubyte'
# 训练集标签文件
train_labels_idx1_ubyte_file = 'D:/CQUPT/软通团队/2023暑假/Task6:深度学习/Hand_wrtten-master/dataset/train-labels.idx1-ubyte'
# 测试集文件
test_images_idx3_ubyte_file = 'D:/CQUPT/软通团队/2023暑假/Task6:深度学习/Hand_wrtten-master/dataset/t10k-images.idx3-ubyte'
# 测试集标签文件
test_labels_idx1_ubyte_file = 'D:/CQUPT/软通团队/2023暑假/Task6:深度学习/Hand_wrtten-master/dataset/t10k-labels.idx1-ubyte'

# 读取数据部分
def decode_idx3_ubyte(idx3_ubyte_file, processingLabel):
# 更新偏移量,以便跳过已经处理的图像数据。
    bin_data = open(idx3_ubyte_file, 'rb').read()
    # rb:以二进制读取模式打开文件
    offset = 0
    fmt_header = '>iiii'  # 因为数据结构中前4行的数据类型都是32位整型,所以采用i格式,但我们需要读取前4行数据,所以需要4个i。我们后面会看到标签集中,只使用2个ii。
    # 计算每张图片的像素数。
    # 魔术数(Magic Number)是一个固定的值或标识,通常用于识别文件格式、协议、数据结构等。魔术数在文件头部或数据的起始位置,用于表明数据的类型或格式,帮助程序识别文件内容并进行正确的处理。
    magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, bin_data, offset)
    print('图片数量: %d张, 图片大小: %d*%d' % (num_images, num_rows, num_cols))

    # 解析数据集
    image_size = num_rows * num_cols # 计算每张图片的像素数。
    offset += struct.calcsize(fmt_header)  # 获得数据在缓存中的指针位置,从前面介绍的数据结构可以看出,读取了前4行之后,指针位置(即偏移位置offset)指向0016。
    print(offset)
    fmt_image = '>' + str(
        image_size) + 'B'  # 图像数据像素值的类型为unsigned char型,对应的format格式为B。这里还有加上图像大小784,是为了读取784个B格式数据,如果没有则只会读取一个值(即一副图像中的一个像素值)
    print(fmt_image, offset, struct.calcsize(fmt_image))
    images = np.empty((num_images, 28, 28))
    # 创建一个NumPy数组,用于存储解析后的图片数据。每张图片的大小是 28x28 像素。
    # numpy.empty(shape, dtype=float, order='C'),形状,数据类型,C-行优先,F-列优先
    # plt.figure()
    for i in tqdm(range(num_images),desc=str(processingLabel),position=0):
        image = np.array(struct.unpack_from(fmt_image, bin_data, offset)).reshape((num_rows, num_cols)).astype(np.uint8)
        # np.array(struct.unpack_from(fmt_image, bin_data, offset))
        #       从二进制数据中解包出一张图片的像素数据,存储在 image 变量中。
        # image = image.reshape((num_rows, num_cols)).astype(np.uint8)
        #       将解包的像素数据重新形状为 28x28 的图像,数据类型转换为 无符号8位 整数。
        # images[i] = cv.resize(image, (96, 96))
        images[i] = image # 将处理后的图像存储在 images 数组的相应位置
        # print(images[i])
        offset += struct.calcsize(fmt_image) # 更新偏移量,以便跳过已经处理的图像数据。

    return images


def decode_idx1_ubyte(idx1_ubyte_file, processingLabel):
    bin_data = open(idx1_ubyte_file, 'rb').read()
    offset = 0
    fmt_header = '>ii'
    magic_number, num_images = struct.unpack_from(fmt_header, bin_data, offset)
    print('图片数量: %d张' % num_images)

    # 解析数据集
    offset += struct.calcsize(fmt_header)
    fmt_image = '>B'
    labels = np.empty(num_images)
    for i in tqdm(range(num_images),position=0,desc=str(processingLabel)):
        labels[i] = struct.unpack_from(fmt_image, bin_data, offset)[0]
        offset += struct.calcsize(fmt_image)
    return labels


def load_train_images(idx_ubyte_file=train_images_idx3_ubyte_file):
    return decode_idx3_ubyte(idx_ubyte_file, 'Train_images')


def load_train_labels(idx_ubyte_file=train_labels_idx1_ubyte_file):
    return decode_idx1_ubyte(idx_ubyte_file, 'Train_labels')


def load_test_images(idx_ubyte_file=test_images_idx3_ubyte_file):
    return decode_idx3_ubyte(idx_ubyte_file, 'Test_images')


def load_test_labels(idx_ubyte_file=test_labels_idx1_ubyte_file):
    return decode_idx1_ubyte(idx_ubyte_file,'Test_labels')


# 构建网络部分
class Residual(nn.Module):  # 本类已保存在d2lzh_pytorch包中方便以后使用
    def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1):
        super(Residual, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(out_channels) # 批标准差
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        return F.relu(Y + X)


class GlobalAvgPool2d(nn.Module):
    # 全局平均池化层可通过将池化窗口形状设置成输入的高和宽实现
    def __init__(self):
        super(GlobalAvgPool2d, self).__init__()

    def forward(self, x):
        return F.avg_pool2d(x, kernel_size=x.size()[2:])


def resnet_block(in_channels, out_channels, num_residuals, first_block=False):
    # num_residuals:残差数
    if first_block:
        assert in_channels == out_channels  # 第一个模块的通道数同输入通道数一致
    blk = [] # 空序列
    for i in range(num_residuals):
        # i从0开始循环,知道num_residuals-1
        if i == 0 and not first_block:
            blk.append(Residual(in_channels, out_channels, use_1x1conv=True, stride=2))
            # append:增加,附加
        else:
            blk.append(Residual(out_channels, out_channels))
    return nn.Sequential(*blk) # 根据构建好的残差模块序列blk,通过sequential转化成一个完整的序列表示整个残差块


def evaluate_accuracy(img, label, net):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    acc_sum, n = 0.0, 0
    with torch.no_grad():
        X = torch.unsqueeze(img, 1)
        if isinstance(net, torch.nn.Module):
            net.eval()  # 评估模式, 这会关闭dropout
            acc_sum += (net(X.to(device)).argmax(dim=1) == label.to(device)).float().sum().cpu().item()
            net.train()  # 改回训练模式
        else:  # 自定义的模型, 3.13节之后不会用到, 不考虑GPU
            if ('is_training' in net.__code__.co_varnames):  # 如果有is_training这个参数
                # 将is_training设置成False
                acc_sum += (net(X, is_training=False).argmax(dim=1) == label).float().sum().item()
            else:
                acc_sum += (net(X).argmax(dim=1) == label).float().sum().item()
        n += label.shape[0]
    return acc_sum / n

class FlattenLayer(torch.nn.Module):
    def __init__(self):
        super(FlattenLayer, self).__init__()
    def forward(self, x): # x shape: (batch, *, *, ...)
        return x.view(x.shape[0], -1)

if __name__ == '__main__':
    print("train:")
    train_images_org = load_train_images().astype(np.float32)
    train_labels_org = load_train_labels().astype(np.int64)
    print("test")
    test_images = load_test_images().astype(np.float32)[0:1000]
    test_labels = load_test_labels().astype(np.int64)[0:1000]
    # 数据转换为Tensor
    train_images = torch.from_numpy(train_images_org)
    train_labels = torch.from_numpy(train_labels_org)
    test_images = torch.from_numpy(test_images)
    test_labels = torch.from_numpy(test_labels)
    # test_images = load_test_images()
    # test_labels = load_test_labels()

    # 查看前十个数据及其标签以读取是否正确
    for i in range(10):
        j = random.randint(0, 60000) # 0-60000之间产生一个随机数
        print("now, show the number of image[{}]:".format(j), int(train_labels_org[j]))
        # 使用了字符串的.format()方法来插入一个占位符 {}
        # 将train_labels_org中第j个位置的元素(标签)转换为整数
        # 如:now, show the number of image[39430]: 6
        # 第39430张照片是6
        img = train_images_org[j]
        img = cv.resize(img, (600, 600))
        cv.imshow("image", img)
        cv.waitKey(0) # 按空格
    cv.destroyAllWindows()
    print('all done!')
    print("*" * 50)

    # import torch
    # from torch import nn, optim
    # import torch.nn.functional as F
    # ResNet模型
    net = nn.Sequential(
        # 顺序容器,网络结构
        nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),  # 二维卷积层
        nn.BatchNorm2d(64),  # 二维批标准化层
        nn.ReLU(),  # ReLUctant激活函数层
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1))  # 二维最大池化层

    # 三个残差块ResNet block,增强模型的深度和复杂性
    # 和sequential的层次不同,残差块的resnet_block是自定义的
    net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True))
    net.add_module("resnet_block2", resnet_block(64, 128, 2))
    net.add_module("resnet_block3", resnet_block(128, 256, 2))

    net.add_module("global_avg_pool", GlobalAvgPool2d())  # GlobalAvgPool2d的输出: (Batch, 512, 1, 1)
    # 全局平均池化层
    net.add_module("fc", nn.Sequential(FlattenLayer(), nn.Linear(256, 10)))
    # 展平层(FlattenLayer)将特征图转换为一维向量
    # 全连接层(nn.Linear)将特征映射到10个输出类别上

    # 测试网络
    # 对构建的ResNet模型进行前向传播,查看每个层的输出形状的示例
    X = torch.rand((1, 1, 28, 28))
    # 创建了一个随机初始化的输入Tensor,形状为(1, 1, 28, 28)。表示一个 单通道 的 28x28 像素的图像
    for name, layer in net.named_children():
        # 通过循环遍历ResNet模型的每一层。
        X = layer(X)
        # 通过循环遍历ResNet模型的每一层。
        print(name, ' output shape:\t', X.shape)
        # 打印当前层的名字和输出X的形状。
    print('*' * 50)
    print('============Start to Train=======================')
    # 训练 迭代40次 达到99.3%
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # 检查是否可用,将神经网络放在CUDA设备(GPU)上进行训练,如果不可用则使用CPU
    lr, num_epochs = 0.001, 40
    # lr 学习率;num_epoch 训练轮数
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    # 优化器:将网络的参数传递给优化器,,以便在训练过程中更新权重
    batch_size = 4000   # 设置批次大小
    net = net.to(device)    # 将网络移动到合适的设备上(GPU或CPU)

    # 以下是初始化参数
    print("training on \t", device)
    loss = torch.nn.CrossEntropyLoss() # 计算交叉熵损失的损失函数
    loop_times = round(60000 / batch_size)  # 每个训练的迭代次数
    train_acc_plot = [] # 训练准确率列表
    test_acc_plot = []  # 测试准确率列表
    loss_plot = []      # 损失值列表
    for epoch in range(num_epochs): # 轮询训练循环
        train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()
        # 总损失、总训练准确率、总样总数、批次计数和开始时间
        for i in tqdm(range(1, loop_times),position=0):
            # 迭代进度条
            x = train_images[(i - 1) * batch_size:i * batch_size]
            y = train_labels[(i - 1) * batch_size:i * batch_size]
            x = torch.unsqueeze(x, 1)  # 对齐维度 以与网络输入的维度相匹配
            X = x.to(device)    # 数据移动到设备上GPU/CPU
            y = y.to(device)    # 标签移动到设备上GPU/CPU
            y_hat = net(X)      # 前向传播计算网络的预测值
            l = loss(y_hat, y)  # 计算交叉熵损失,预测值y_hat与真实标签y之间
            optimizer.zero_grad() # 清空优化器梯度
            l.backward()        # 反向传播
            optimizer.step()    # 权重更新
            train_l_sum += l.cpu().item()   # 累计每个批次的损失-将损失值从GPU移到CPU,并提取为Python标量
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item() # 累计每个批次的准确率
            n += y.shape[0]     # 更新样本总数
            batch_count += 1    # 更新批次计算
        test_acc = evaluate_accuracy(test_images, test_labels, net) # 计算测试集准确率,并打印
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
              % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
        # 保存模型到文件中
        torch.save(net.state_dict(), 'logs/Epoch%d-Loss%.4f-train_acc%.4f-test_acc%.4f.pth' % (
            (epoch + 1), train_l_sum / batch_count, train_acc_sum / n, test_acc))
        print("save successfully") # 完成一轮的训练
        # append 添加,增加,附增
        test_acc_plot.append(test_acc)              # 测试准确率列表
        train_acc_plot.append(train_acc_sum / n)    # 训练准确率列表
        loss_plot.append(train_l_sum / batch_count) # 损失列表

    x = range(0,40)
    # 绘制训练和测试准确率、损失随训练轮次变化的图像
    plt.plot(x, test_acc_plot, 'r', label='Test Accuracy')
    plt.plot(x, train_acc_plot, 'g', label='Train Accuracy')
    plt.plot(x, loss_plot, 'b', label='Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.title('Training and Test Metrics')
    plt.legend()
    plt.show()  # 显示绘制的图像
    print("*" * 50)

4.2 IsBlackGround.py

from PIL import Image

def is_black_background(image):
    # 获取图片的像素数据
    pixels = image.getdata()
    width, height = image.size

    # 统计黑色和白色像素的数量
    black_count = 0
    white_count = 0

    for pixel in pixels:
        # 判断是否为黑色像素
        if pixel[0] < 128:  # 假设RGB值中的R通道代表亮度,小于128认为是黑色
            black_count += 1
        # 判断是否为白色像素
        if pixel[0] > 192:  # 假设RGB值中的R通道代表亮度,大于192认为是白色
            white_count += 1

    # 判断是否为黑底白字
    return black_count > white_count

def invert_image(image):
    # 反相处理图片
    inverted_image = Image.eval(image, lambda x: 255 - x)
    return inverted_image

4.3 mainBoard.py

# 加载库
from MainWidget import MainWidget
from PyQt5.QtWidgets import QApplication
import sys

def main():
    app = QApplication(sys.argv)

    mainWidget = MainWidget()  # 新建一个主界面
    mainWidget.show()  # 显示主界面

    exit(app.exec_())  # 进入消息循环


if __name__ == '__main__':
    main()

4.4 MainWidget.py

'''
Created on 2018年8月8日

@author: Freedom
'''
from PyQt5.Qt import QWidget, QColor, QPixmap, QIcon, QSize, QCheckBox
from PyQt5.QtWidgets import QHBoxLayout, QVBoxLayout, QPushButton, QSplitter, \
    QComboBox, QLabel, QSpinBox, QFileDialog,QTextEdit
from PaintBoard import PaintBoard
import numpy as np
from PIL import Image

import cv2 as cv
from Pre_treatment import get_number as g_n # .py文件 图形预处理
import predict as pt
from time import time
from Pre_treatment import softmax # .py文件 二分类函数
net = pt.get_net()

class MainWidget(QWidget):

    def __init__(self, Parent=None):
        '''
        Constructor
        '''
        super().__init__(Parent)

        self.__InitData()  # 先初始化数据,再初始化界面
        self.__InitView()

    def __InitData(self):
        '''
                  初始化成员变量
        '''
        self.__paintBoard = PaintBoard(self)
        # 获取颜色列表(字符串类型)
        self.__colorList = QColor.colorNames()

    def __InitView(self):
        '''
                  初始化界面
        '''
        self.setFixedSize(640, 480)
        self.setWindowTitle("PaintBoard Example PyQt5")

        # 新建一个水平布局作为本窗体的主布局
        main_layout = QHBoxLayout(self)
        # 设置主布局内边距以及控件间距为10px
        main_layout.setSpacing(10)

        # 在主界面左侧放置画板
        main_layout.addWidget(self.__paintBoard)

        # 新建垂直子布局用于放置按键
        sub_layout = QVBoxLayout()

        # 设置此子布局和内部控件的间距为10px
        sub_layout.setContentsMargins(10, 10, 10, 10)

        self.__btn_Clear = QPushButton("清空画板")
        self.__btn_Clear.setParent(self)  # 设置父对象为本界面
        # 将按键按下信号与画板清空函数相关联
        self.__btn_Clear.clicked.connect(self.__paintBoard.Clear)
        sub_layout.addWidget(self.__btn_Clear)

        self.__btn_yuce = QPushButton("智能预测")
        self.__btn_yuce.setParent(self)  # 设置父对象为本界面
        self.__btn_yuce.clicked.connect(lambda:self.yuce())
        sub_layout.addWidget(self.__btn_yuce)


        self.__text_out = QTextEdit(self)
        self.__text_out.setParent(self)
        self.__text_out.setObjectName("预测结果为:")
        sub_layout.addWidget(self.__text_out)


        self.__btn_Quit = QPushButton("退出")
        self.__btn_Quit.setParent(self)  # 设置父对象为本界面
        self.__btn_Quit.clicked.connect(self.Quit)
        sub_layout.addWidget(self.__btn_Quit)

        self.__btn_Save = QPushButton("保存作品")
        self.__btn_Save.setParent(self)
        self.__btn_Save.clicked.connect(self.on_btn_Save_Clicked)
        sub_layout.addWidget(self.__btn_Save)

        self.__cbtn_Eraser = QCheckBox("  使用橡皮擦")
        self.__cbtn_Eraser.setParent(self)
        self.__cbtn_Eraser.clicked.connect(self.on_cbtn_Eraser_clicked)
        sub_layout.addWidget(self.__cbtn_Eraser)

        splitter = QSplitter(self)  # 占位符
        sub_layout.addWidget(splitter)

        self.__label_penThickness = QLabel(self)
        self.__label_penThickness.setText("画笔粗细")
        self.__label_penThickness.setFixedHeight(20)
        sub_layout.addWidget(self.__label_penThickness)

        self.__spinBox_penThickness = QSpinBox(self)
        self.__spinBox_penThickness.setMaximum(40)
        self.__spinBox_penThickness.setMinimum(2)
        self.__spinBox_penThickness.setValue(20)  # 默认粗细为10
        self.__spinBox_penThickness.setSingleStep(2)  # 最小变化值为2
        self.__spinBox_penThickness.valueChanged.connect(
            self.on_PenThicknessChange)  # 关联spinBox值变化信号和函数on_PenThicknessChange
        sub_layout.addWidget(self.__spinBox_penThickness)

        self.__label_penColor = QLabel(self)
        self.__label_penColor.setText("画笔颜色")
        self.__label_penColor.setFixedHeight(20)
        sub_layout.addWidget(self.__label_penColor)

        self.__comboBox_penColor = QComboBox(self)
        self.__fillColorList(self.__comboBox_penColor)  # 用各种颜色填充下拉列表
        self.__comboBox_penColor.currentIndexChanged.connect(
            self.on_PenColorChange)  # 关联下拉列表的当前索引变更信号与函数on_PenColorChange
        sub_layout.addWidget(self.__comboBox_penColor)

        main_layout.addLayout(sub_layout)  # 将子布局加入主布局

    def __fillColorList(self, comboBox):

        index_black = 0
        index = 0
        for color in self.__colorList:
            if color == "black":
                index_black = index
            index += 1
            pix = QPixmap(70, 20)
            pix.fill(QColor(color))
            comboBox.addItem(QIcon(pix), None)
            comboBox.setIconSize(QSize(70, 20))
            comboBox.setSizeAdjustPolicy(QComboBox.AdjustToContents)

        comboBox.setCurrentIndex(index_black)

    def on_PenColorChange(self):
        color_index = self.__comboBox_penColor.currentIndex()
        color_str = self.__colorList[color_index]
        self.__paintBoard.ChangePenColor(color_str)

    def on_PenThicknessChange(self):
        penThickness = self.__spinBox_penThickness.value()
        self.__paintBoard.ChangePenThickness(penThickness)

    def on_btn_Save_Clicked(self):
        savePath = QFileDialog.getSaveFileName(self, 'Save Your Paint', '.\\', '*.png')
        print(savePath)
        if savePath[0] == "":
            print("Save cancel")
            return
        image = self.__paintBoard.GetContentAsQImage()
        image.save(savePath[0])

    def on_cbtn_Eraser_clicked(self):
        if self.__cbtn_Eraser.isChecked():
            self.__paintBoard.EraserMode = True  # 进入橡皮擦模式
        else:
            self.__paintBoard.EraserMode = False  # 退出橡皮擦模式

    def Quit(self):
        self.close()

    def yuce(self):
        # #标准化图片 获取Y
        savePath = "./image_rgzn/test.png"
        image = self.__paintBoard.GetContentAsQImage()
        image.save(savePath)
        #img_path = Image.open(savePath)
        # start
        img = cv.imread(savePath)
        img_bw = g_n(img)
        # 下面这一段就是ROI提取
        img_bw_c = img_bw.sum(axis=1) / 255
        img_bw_r = img_bw.sum(axis=0) / 255
        r_ind, c_ind = [], []
        for k, r in enumerate(img_bw_r):
            if r >= 5:
                r_ind.append(k)
        for k, c in enumerate(img_bw_c):
            if c >= 5:
                c_ind.append(k)
        img_bw_sg = img_bw[c_ind[0]:c_ind[-1], r_ind[0]:r_ind[-1]]
        leng_c = len(c_ind)
        leng_r = len(r_ind)
        side_len = leng_c + 20
        add_r = int((side_len - leng_r) / 2)
        img_bw_sg_bord = cv.copyMakeBorder(img_bw_sg, 10, 10, add_r, add_r, cv.BORDER_CONSTANT, value=[0, 0, 0])
        # 展示图片
        #cv.imshow("img", img_bw)
        #print('output img')
        # cv.imshow("img_sg", img_bw_sg_bord)
        #print('output img_sg')
        # = cv.waitKey(1) & 0xff

        img_in = cv.resize(img_bw_sg_bord, (28, 28))  # 重构成28*28的大小
        cv.imshow("28*28",img_in)
        result_org = pt.predict(img_in, net)
        #print('over predict')
        result = softmax(result_org)
        #print('softmax')
        best_result = result.argmax(dim=1).item()
        #print('result.argmax:',best_result)
        best_result_num = max(max(result)).cpu().detach().numpy()
        #print('best_result_num')
        if best_result_num <= 0.5:
            best_result = None
        # end
        self.__text_out.setText(str(best_result))
        print("predict result is :",best_result)
        # print("hello")
        # res = QMessageBox.information(self,"人工智能判断为:",str(p),QMessageBox.Yes|QMessageBox.No)
        # res.exec_()
        # 读取数据权重

        # 预测并输出

4.5 PaintBoard.py

'''
Created on 2018年8月9日

@author: Freedom
'''
from PyQt5.QtWidgets import QWidget
from PyQt5.Qt import QPixmap, QPainter, QPoint, QPaintEvent, QMouseEvent, QPen, \
    QColor, QSize
from PyQt5.QtCore import Qt


class PaintBoard(QWidget):

    def __init__(self, Parent=None):
        '''
        Constructor
        '''
        super().__init__(Parent)

        self.__InitData()  # 先初始化数据,再初始化界面
        self.__InitView()

    def __InitData(self):

        self.__size = QSize(280, 280)

        # 新建QPixmap作为画板,尺寸为__size
        self.__board = QPixmap(self.__size)
        self.__board.fill(Qt.black)  # 用白色填充画板

        self.__IsEmpty = True  # 默认为空画板
        self.EraserMode = False  # 默认为禁用橡皮擦模式

        self.__lastPos = QPoint(0, 0)  # 上一次鼠标位置
        self.__currentPos = QPoint(0, 0)  # 当前的鼠标位置

        self.__painter = QPainter()  # 新建绘图工具

        self.__thickness = 10  # 默认画笔粗细为10px
        self.__penColor = QColor("white")  # 设置默认画笔颜色为黑色
        self.__colorList = QColor.colorNames()  # 获取颜色列表

    def __InitView(self):
        # 设置界面的尺寸为__size
        self.setFixedSize(self.__size)

    def Clear(self):
        # 清空画板
        self.__board.fill(Qt.black)
        self.update()
        self.__IsEmpty = True

    def ChangePenColor(self, color="black"):
        # 改变画笔颜色
        self.__penColor = QColor(color)

    def ChangePenThickness(self, thickness=10):
        # 改变画笔粗细
        self.__thickness = thickness

    def IsEmpty(self):
        # 返回画板是否为空
        return self.__IsEmpty

    def GetContentAsQImage(self):
        # 获取画板内容(返回QImage)
        image = self.__board.toImage()
        return image

    def paintEvent(self, paintEvent):
        # 绘图事件
        # 绘图时必须使用QPainter的实例,此处为__painter
        # 绘图在begin()函数与end()函数间进行
        # begin(param)的参数要指定绘图设备,即把图画在哪里
        # drawPixmap用于绘制QPixmap类型的对象
        self.__painter.begin(self)
        # 0,0为绘图的左上角起点的坐标,__board即要绘制的图
        self.__painter.drawPixmap(0, 0, self.__board)
        self.__painter.end()

    def mousePressEvent(self, mouseEvent):
        # 鼠标按下时,获取鼠标的当前位置保存为上一次位置
        self.__currentPos = mouseEvent.pos()
        self.__lastPos = self.__currentPos

    def mouseMoveEvent(self, mouseEvent):
        # 鼠标移动时,更新当前位置,并在上一个位置和当前位置间画线
        self.__currentPos = mouseEvent.pos()
        self.__painter.begin(self.__board)

        if self.EraserMode == False:
            # 非橡皮擦模式
            self.__painter.setPen(QPen(self.__penColor, self.__thickness))  # 设置画笔颜色,粗细
        else:
            # 橡皮擦模式下画笔为纯白色,粗细为10
            self.__painter.setPen(QPen(Qt.white, 10))

        # 画线
        self.__painter.drawLine(self.__lastPos, self.__currentPos)
        self.__painter.end()
        self.__lastPos = self.__currentPos

        self.update()  # 更新显示

    def mouseReleaseEvent(self, mouseEvent):
        self.__IsEmpty = False  # 画板不再为空

4.6 Pre_treatment.py

# author:Hurricane
# date:  2020/11/6
# E-mail:hurri_cane@qq.com


import cv2 as cv
import numpy as np
import os
from PIL import Image


def get_number(img):
    # 图像预处理
    # cv.imshow("beform precondition", img)
    img_gray = cv.cvtColor(img, cv.COLOR_RGB2GRAY) # 灰度化
    # cv.imshow("Gray",img_gray)
    img_gray_resize = cv.resize(img_gray, (600, 600)) #重定义成600*600的灰度化图形
    # cv.imshow("600*600",img_gray_resize)
    ret, img_bw = cv.threshold(img_gray_resize, 200, 255, cv.THRESH_BINARY) # 二值化
    # cv.imshow("Binary",img_bw)
    # 判断是否图片是白底黑字的照片
    white_count = np.sum(img_bw == 255)
    black_count = np.sum(img_bw == 0)

    if white_count > black_count:
        # 白底黑字,进行反相处理
        img_bw = cv.bitwise_not(img_bw)
    # cv.imshow('white background', img_bw)
    # cv.waitKey(1000)
    kernel = np.ones((3, 3), np.uint8) # 设置核大小
    # img_open = cv.morphologyEx(img_bw,cv.MORPH_CLOSE,kernel)
    img_open = cv.dilate(img_bw, kernel, iterations=2) # 膨胀函数
    # cv.imshow("dilate", img_open)
    num_labels, labels, stats, centroids = \
        cv.connectedComponentsWithStats(img_open, connectivity=8, ltype=None)
    # OpenCV库中的connectedComponentsWithStats函数
    # 该函数用于在二值图像中找到连通组件(连通区域),并返回每个连通组件的一些统计信息。
    # 传入参数:  img_open        二值化后的图像
    #           connectivity=8  表示使用8邻域连接
    #           ltype=None      表示返回的标签图像数据类型与输入图像相同
    # 返回值:  num_labels  表示图像中的连通组件数量,
    #          labels   是一个与输入图像大小相同的标签图像,其中每个像素值代表对应像素所属的连通组件标签
    #          stats    是一个NumPy数组,其中包含每个连通组件的统计信息(如面积、边界框等)
    #          centroids是一个NumPy数组,包含每个连通组件的中心坐标。
    for sta in stats:
        # 循环遍历,stats中的每一项sta表示一个连通组件的统计信息。
        if sta[4] < 1000:
            # 检查当前连通组件的面积(存储在sta的第5个元素中,由于Python的索引从0开始,所以使用sta[4]表示面积)是否小于1000。
            cv.rectangle(img_open, tuple(sta[0:2]), tuple(sta[0:2] + sta[2:4]), (0, 0, 255), thickness=-1)
            # 使用OpenCV的rectangle函数在img_open上绘制一个红色的矩形框,
            # 将该连通组件的区域填充为红色。
            # 这样就可以将面积较小的连通组件标记出来,方便后续处理或分析。
    return img_open

def get_roi(img_bw):
    # 图像ROI,感兴趣位置提取办法
    img_bw_c = img_bw.sum(axis=1) / 255 # 各列方向上的像素求和并除去255,
    img_bw_r = img_bw.sum(axis=0) / 255 # 各行方向上的像素求和并除去255,
    all_sum = img_bw_c.sum(axis=0) # 全部像素求和
    if all_sum != 0: # 不为全黑
        r_ind, c_ind = [], []
        for k, r in enumerate(img_bw_r): # 将满足像素和大于等于5的行的索引添加到r_ind列表中
            if r >= 5:
                r_ind.append(k)
        for k, c in enumerate(img_bw_c): # 将满足像素和大于等于5的列的索引添加到c_ind列表中
            if c >= 5:
                c_ind.append(k)
        if len(r_ind)==0 or len(c_ind)==0: #全黑的行或列则返回原图形
            return img_bw
        img_bw_sg = img_bw[c_ind[0]:c_ind[-1], r_ind[0]:r_ind[-1]]
        # 基于r_ind和c_ind的索引,裁剪图像,提取出需要保留的区域
        leng_c = len(c_ind) # 计算保留区域的边长,长
        leng_r = len(r_ind) # 计算保留区域的边长,宽
        side_len = max(leng_c, leng_r) + 20 # 计算边长的基础上进行拓展
        # 计算需要在保留区域周围添加的像素行数和列数。
        # 确保保留区域的最终尺寸为正方形,且大小为side_len x side_len。
        if leng_c == side_len:
            add_r = int((side_len - leng_r) / 2)
            add_c = 10
        else:
            add_r = 10
            add_c = int((side_len - leng_c) / 2)
        # 使用OpenCV的cv.copyMakeBorder函数,在保留区域周围添加一圈像素,从而扩展保留区域的尺寸。
        # 添加的像素值为[0, 0, 0],即黑色。
        img_bw_sg_bord = cv.copyMakeBorder(img_bw_sg, add_c, add_c, add_r, add_r, cv.BORDER_CONSTANT, value=[0, 0, 0])
        return img_bw_sg_bord
    else:
        return img_bw

def softmax(X):
    X_exp = X.exp()
    partition = X_exp.sum(dim=1, keepdim=True)
    return X_exp / partition

4.7 predict.py

# author:Hurricane
# date:  2020/11/5
# E-mail:hurri_cane@qq.com
# -------------------------------------#
#       对单张图片进行预测
# -------------------------------------#
import numpy as np
import struct
import matplotlib.pyplot as plt
import cv2 as cv
import random
import torch
from torch import nn, optim
import torch.nn.functional as F




class Residual(nn.Module):
    def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1):
        super(Residual, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        return F.relu(Y + X)


class GlobalAvgPool2d(nn.Module):
    # 全局平均池化层可通过将池化窗口形状设置成输入的高和宽实现
    def __init__(self):
        super(GlobalAvgPool2d, self).__init__()

    def forward(self, x):
        return F.avg_pool2d(x, kernel_size=x.size()[2:])


def resnet_block(in_channels, out_channels, num_residuals, first_block=False):
    # num_residuals:残差数
    if first_block:
        assert in_channels == out_channels  # 第一个模块的通道数同输入通道数一致
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.append(Residual(in_channels, out_channels, use_1x1conv=True, stride=2))
        else:
            blk.append(Residual(out_channels, out_channels))
    return nn.Sequential(*blk)

class FlattenLayer(torch.nn.Module):
    def __init__(self):
        super(FlattenLayer, self).__init__()
    def forward(self, x): # x shape: (batch, *, *, ...)
        return x.view(x.shape[0], -1)
def get_net():
    # 构建网络
    # ResNet模型
    model_path = r"D:\CQUPT\软通团队\2023暑假\Task6:深度学习\Hand_wrtten-master\logs\Epoch100-Loss0.0000-train_acc1.0000-test_acc0.9930.pth"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = nn.Sequential(
        nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

    net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True))
    net.add_module("resnet_block2", resnet_block(64, 128, 2))
    net.add_module("resnet_block3", resnet_block(128, 256, 2))

    net.add_module("global_avg_pool", GlobalAvgPool2d())  # GlobalAvgPool2d的输出: (Batch, 512, 1, 1)
    net.add_module("fc", nn.Sequential(FlattenLayer(), nn.Linear(256, 10)))

    # 测试网络
    # X = torch.rand((1, 1, 28, 28))
    # for name, layer in net.named_children():
    #     X = layer(X)
    #     print(name, ' output shape:\t', X.shape)

    # 加载网络模型
    print("Load weight into state dict...")
    stat_dict = torch.load(model_path, map_location=device)
    net.load_state_dict(stat_dict)
    net.to(device)
    net.eval()
    print("Load finish!")
    return net


def predict(img, net):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    img_in = torch.from_numpy(img)
    img_in = torch.unsqueeze(img_in, 0)
    img_in = torch.unsqueeze(img_in, 0).to(device)
    img_in = img_in.float()
    result_org = net(img_in)
    return result_org

4.8 predictPhoto.py

# author:Hurricane
# date:  2020/11/6
# E-mail:hurri_cane@qq.com

import cv2 as cv
import numpy as np
import os
from Pre_treatment import get_number as g_n # .py文件 图形预处理
import predict as pt
from time import time
from Pre_treatment import softmax # .py文件 二分类函数

net = pt.get_net()
# 原始图片路径
#orig_path = r"real_img_resize"
#img_list = os.listdir(orig_path)

img_path = r'caise7.png'
# img = cv.imread(img_path)

since = time()
#img_path = os.path.join(orig_path, img_name)
# 输出拼接后的路径
#print(img_path)
img = cv.imread(img_path)
img_bw = g_n(img)
# 下面这一段就是ROI提取
img_bw_c = img_bw.sum(axis=1) / 255
img_bw_r = img_bw.sum(axis=0) / 255
r_ind, c_ind = [], []
for k, r in enumerate(img_bw_r):
    if r >= 5:
        r_ind.append(k)
for k, c in enumerate(img_bw_c):
    if c >= 5:
        c_ind.append(k)
img_bw_sg = img_bw[ c_ind[0]:c_ind[-1] ,r_ind[0]:r_ind[-1]]
leng_c = len(c_ind)
leng_r = len(r_ind)
side_len = leng_c + 20
add_r = int((side_len-leng_r)/2)
img_bw_sg_bord = cv.copyMakeBorder(img_bw_sg,10,10,add_r,add_r,cv.BORDER_CONSTANT,value=[0,0,0])
# 展示图片
cv.imshow("img", img_bw)
cv.imshow("img_sg", img_bw_sg_bord)
# = cv.waitKey(1) & 0xff

img_in = cv.resize(img_bw_sg_bord, (28, 28)) # 重构成28*28的大小
result_org = pt.predict(img_in,  net)
result = softmax(result_org)
best_result = result.argmax(dim=1).item()
best_result_num = max(max(result)).cpu().detach().numpy()
if best_result_num <= 0.5:
    best_result = None

# 显示结果
img_show = cv.resize(img, (600, 600))
end_predict = time()
fps = np.ceil(1 / (end_predict - since))
font = cv.FONT_HERSHEY_SIMPLEX
cv.putText(img_show, "The number is:" + str(best_result), (1, 30), font, 1, (0, 0, 255), 2)
cv.putText(img_show, "Probability is:" + str(best_result_num), (1, 60), font, 1, (0, 255, 0), 2)
cv.putText(img_show, "FPS:" + str(fps), (1, 90), font, 1, (255, 0, 0), 2)
cv.imshow("result", img_show)
cv.waitKey(1)
print(result)
print("*" * 50)
print("The number is:", best_result)
cv.waitKey(0)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1010142.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

国产洗碗机打响超越战

“征服世界的将是这样一些人&#xff1a;开始的时候&#xff0c;他们试图找到梦想中的乐园。最终&#xff0c;当他们无法找到时&#xff0c;就亲自创造了它。”诺贝尔文学奖获得者萧伯纳的这句话&#xff0c;适用于许多中国行业和企业&#xff0c;洗碗机就是其中之一。 对热爱…

十进制小数转换为单双精度浮点数方法

1 将十进制小数转换为单精度浮点数的方法如下&#xff1a; 2. 将十进制小数转换为双精度浮点数的方法如下&#xff1a; 和单精度浮点值转换一样

前端Layui框架介绍

当涉及到前端UI框架时&#xff0c;Layui&#xff08;简称layui&#xff09;是一个备受欢迎的框架之一。在这篇博客中&#xff0c;我们将深入了解layui&#xff0c;包括其市场占有率、开发语言、使用场景、框架特点以及一些使用案例。 1. 市场占有率 Layui 是一款流行的前端UI框…

(纯干货建议收藏)大型字符串模拟-超强超全函数技巧总结

这篇文章将会总结一些处理字符串、进制转换等等的常见的、非常有用的技巧和函数。后续会随时更新本文章&#xff0c;希望大家收藏、留言&#xff0c;一起学习进步&#xff01; 对于特别简单的函数&#xff0c;就不写函数的详细原型啦&#xff01; 具体包含四部分&#xff0c;…

Xilinx FPGA未使用管脚上下拉状态配置(ISE和Vivado环境)

文章目录 ISE开发环境Vivado开发环境方式1&#xff1a;XDC文件约束方式2&#xff1a;生成选项配置 ISE开发环境 ISE开发环境&#xff0c;可在如下Bit流文件生成选项中配置。 右键点击Generate Programming File&#xff0c;选择Process Properties&#xff0c; 在弹出的窗口选…

《程序员职场工具库》如何优化你的工作 —— PDCA 循环

PDCA 循环简介 PDCA 循环是一种以持续改进为核心思想的管理方法&#xff0c;在全球各个领域得到广泛的应用。它还有好几个别称&#xff0c;叫“质量环”&#xff0c;也叫“戴明环”&#xff0c;也有叫“持续改进螺旋”。 PDCA 循环由四个步骤组成&#xff1a; 计划&#xff…

基于SSM+Vue的中国咖啡文化宣传网站

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;采用vUE技术开发 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#x…

vector模拟实现——关于模拟中的易错点

前言 vector 本质上类似数组&#xff0c;也可以理解为一种泛型的 string。string 只能存储 char 类型&#xff0c;但是 vector 支持各种内置类型和自定义类型。本次将围绕模拟实现 vector 中遇到的问题进行分析。 文章目录 前言一、确定思路二、实现过程2.1 查阅文档2.2 验证…

4-3 nn.functional和nn.Module

一&#xff0c;nn.functional 和 nn.Module 前面我们介绍了Pytorch的张量的结构操作和数学运算中的一些常用API。利用这些张量的API我们可以构建出神经网络相关的组件(如激活函数&#xff0c;模型层&#xff0c;损失函数)。 其实&#xff1a;Pytorch和神经网络相关的功能组件大…

中小企业数字化转型难?为什么不试试“企业级”无代码平台

首先&#xff0c;让我们思考一下&#xff0c;中小企业为什么要进行数字化转型&#xff1f;随着全球经济的数字化趋势日益明显&#xff0c;中小企业作为经济的重要组成部分&#xff0c;其数字化转型已成为推动经济高质量发展的关键。数字技术可以帮助中小企业提高生产效率、降低…

ctfshow-web-红包题 辟邪剑谱

0x00 前言 CTF 加解密合集CTF Web合集网络安全知识库溯源相关 文中工具皆可关注 皓月当空w 公众号 发送关键字 工具 获取 0x01 题目 0x02 Write Up 这道题主要是考察mysql查询绕过的问题。 首先访问后看到是一个登录页面&#xff0c;测试注册等无果 扫描目录&#xff0c;发…

Packet Tracer的使用介绍

直接访问 Packet Tracer 的帮助页面、教程视频和在线资源对于了解该软件会更加方便。 单击菜单工具栏右上角的问号图标。单击“帮助”菜单&#xff0c;然后选择“内容”。 b. 通过单击“帮助”>“教程”来访问 Packet Tracer 的教程视频。 菜单栏&#xff1a;提供文件、编辑…

SpringBoot运行原理

目录 SpringBootApplication ComponentScan SpringBootConfiguration EnableAutoConfiguration 结论 SpringbootApplication&#xff08;主入口&#xff09; SpringBootApplication public class SpringbootConfigApplication {public static void main(String[] args) {…

Android动态片段

之前创建的片段都是静态的。一旦显示片段&#xff0c;片段的内容就不能改变了。尽管可以用一个新实例完全取代所显示的片段&#xff0c;但是并不能更新片段本身的内容。 之前已经创建过一个基础秒表应用&#xff0c;具体代码https://github.com/MADMAX110/Stopwatch。我们将这个…

发生以下的报错怎么办?

报错问题&#xff1a; 解决办法&#xff1a; 根据你提供的代码和错误信息&#xff0c;问题出在使用了nullptr。这个错误是因为你的编译器不支持C11标准。 nullptr是C11引入的空指针常量。为了解决这个问题&#xff0c;你可以尝试以下两种方法之一&#xff1a; 1. 将nullptr…

想要精通算法和SQL的成长之路 - 可以攻击国王的皇后

想要精通算法和SQL的成长之路 - 可以攻击国王的皇后 前言一. 可以攻击国王的皇后 前言 想要精通算法和SQL的成长之路 - 系列导航 一. 可以攻击国王的皇后 原题链接 这个题目其实并没有涉及到什么很难的算法&#xff0c;其实就是一个简单的遍历题目。核心思想&#xff1a; 以…

CRM系统销售自动化功能如何提高销售效率

销售效率对企业的盈利能力有着至关重要的联系。提高销售效率&#xff0c;就是要提高销售人员的工作效率和销售转化率。那么&#xff0c;企业如何提高销售效率呢&#xff1f;CRM销售自动化功能可以帮助企业实现这一目标。 一、线索管理 线索是指有潜在购买意向的客户&#xff…

kali必杀器之三剑客

Kali常见攻击手段 注意:仅用于教程和科普&#xff0c;切勿做违法之事&#xff0c;否则后果自负 1 网络攻击手段 请正确使用DDos和CC攻击&#xff0c;不要用来做违反当地法律法规的事情&#xff0c;否则后果自负 使用之前kali需要能够上网 参考:kali安装 1.1 DDos攻击…

新加坡打车软件平台Ryde Group申请1700万美元纳斯达克IPO上市

来源&#xff1a;猛兽财经 作者&#xff1a;猛兽财经 猛兽财经获悉&#xff0c;新加坡打车软件平台Ryde Group近期已向美国证券交易委员会&#xff08;SEC&#xff09;提交招股书&#xff0c;申请在纳斯达克IPO上市&#xff0c;股票代码为&#xff08;RYDE&#xff09;&#x…

学习javaEE初阶的第一堂课

学习金字塔 java发展简史 Java最初诞生的时候是用来写前端的!! 199x年 199x年,互联网还处在比较早期的阶段,当时主流的编程语言是 C/C, 有个大佬要搞个"智能面包机",觉得用C来做太难了 于是就基于C搞了个简单点的语言,Java 就诞生了~~ 遗憾的是项目流产了,没做成…