基于卷积神经网络实现手写数字识别

news2025/1/13 10:20:57
基于卷积神经网络实现手写数字识别

基于卷积神经网络实现手写数字识别。具体过程如下:

(1) 定义ConvNet结构类及其前向传播方式

(2) 设置超参数以及导入相关的包。

(3) 定义训练网络函数和绘图函数,并在main函数中完成调用过程

程序
import os 
import numpy as np 
#from sklearn.datasets import fetch_openml # 引入openml数据源
from matplotlib import pyplot as plt # 引入绘图工具
import torch
from torchvision.datasets import mnist
#from mnist_models import AlexNet, ConvNet
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable


BASE_PATH = os.path.dirname(__file__)

# 设置模型超参数
EPOCHS = 50
SAVE_PATH = './models'

'''
# 载入MNIST数据集并显示部分样本
def load_mnist():
    # 从openml源载入MNIST数据集
    mnist = fetch_openml('mnist_784', version=1, data_home=os.path.join(BASE_PATH, './dataset'))
    X, y = mnist['data'], mnist['target']
    #X = mnist['data']#.astype(np.float32)
    #y = mnist['target']#.astype(np.int32)

    print('MNIST数据集大小:{}'.format(X.shape))

    # 显示其中25张样本图片
    for i in range(25):
        #print(i)
        digit = X.iloc[i * 2500]
        # 将图片恢复到28*28大小
        digit_image = digit.values.reshape(28, 28)
        
        # 绘制图片
        plt.subplot(5, 5, i + 1)
        # 隐藏坐标轴
        plt.axis('off')
        # 按灰度图绘制图片
        plt.imshow(digit_image, cmap='gray')
    # 显示图片
    plt.show()
    return X, y
'''

# 定义卷积网络结构
class ConvNet(torch.nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 10, 5, 1, 1),
            torch.nn.MaxPool2d(2),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(10)
        )
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv2d(10, 20, 5, 1, 1),
            torch.nn.MaxPool2d(2),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(20)
        )
        self.fc1 = torch.nn.Sequential(
            torch.nn.Linear(500, 60),
            torch.nn.Dropout(0.5),
            torch.nn.ReLU()
        )
        self.fc2 = torch.nn.Sequential(
            torch.nn.Linear(60, 20),
            torch.nn.Dropout(0.5),
            torch.nn.ReLU()
        )
        self.fc3 = torch.nn.Linear(20, 10)

    # 定义网络前向传播方式
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(-1, 500)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

# 定义AlexNet结构
class AlexNet(torch.nn.Module):
    def __init__(self, num_classes=10):
        super(AlexNet, self).__init__()
        self.features = torch.nn.Sequential(
            torch.nn.Conv2d(1, 64, kernel_size=5, stride=1, padding=2),
            torch.nn.ReLU(inplace=True),
            torch.nn.MaxPool2d(kernel_size=3, stride=1),
            torch.nn.Conv2d(64, 192, kernel_size=3, padding=2),
            torch.nn.ReLU(inplace=True),
            torch.nn.MaxPool2d(kernel_size=3, stride=2),
            torch.nn.Conv2d(192, 384, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(384, 256, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.MaxPool2d(kernel_size=3, stride=2)
        )
        self.classifier = torch.nn.Sequential(
            torch.nn.Dropout(),
            torch.nn.Linear(256 * 6 * 6, 4096),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(),
            torch.nn.Linear(4096, 4096),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(4096, num_classes)
        )

    # 定义AlexNet前向传播过程
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        return x    

# 训练网络函数
def train_net(net, train_data, test_data):
    losses = []
    acces = []

    # 测试集上Loss变化情况
    eval_losses = []
    eval_acces = []
    # 损失函数设置为交叉熵函数
    criterion = torch.nn.CrossEntropyLoss()
    # 优化方法选用SGD,初始学习率为1e-2
    optimizer = torch.optim.SGD(net.parameters(), 1e-2)

    for e in range(EPOCHS):
        train_loss = 0
        train_acc = 0
        # 将网络设置为训练模型
        net.train()
        for image, label in train_data:
            image = Variable(image)
            label = Variable(label)
            # 前向传播
            out = net(image)
            loss = criterion(out, label)
            # 反向传播
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # 记录误差
            train_loss += loss.data
            # 计算分类的准确率
            _, pred = out.max(1)
            num_correct = (np.array(pred, dtype=np.int32) == np.array(label, dtype=np.int32)).sum()
            acc = num_correct / image.shape[0]
            train_acc += acc

        train_loss_rate = train_loss / len(train_data)
        train_acc_rate = train_acc / len(train_data)
        losses.append(train_loss_rate)
        acces.append(train_acc_rate)

        # 在测试集上检验效果
        eval_loss = 0
        eval_acc = 0
        net.eval() # 将模型改为预测模式
        for image, label in test_data:
            image = Variable(image)
            label = Variable(label)
            out = net(image)
            loss = criterion(out, label)
            # 记录误差
            eval_loss += loss.data
            # 记录准确率
            _, pred = out.max(1)
            num_correct = (np.array(pred, dtype=np.int32) == np.array(label, dtype=np.int32)).sum()
            acc = num_correct / image.shape[0]
            eval_acc += acc

        eval_loss_rate = eval_loss / len(test_data)
        eval_acc_rate = eval_acc / len(test_data)
        eval_losses.append(eval_loss_rate)
        eval_acces.append(eval_acc_rate)
        print('epoch:{}, Train Loss: {:.6f}, Train Acc:{:.6f}, Eval Loss:{:.6f}, Eval Acc:{:.6f}'.format(e, train_loss_rate, train_acc_rate, eval_loss_rate, eval_acc_rate))

        torch.save(net.state_dict(), os.path.join(BASE_PATH, SAVE_PATH, 'Alex_model_epoch' + str(e) + '.pkl'))

    return eval_losses, eval_acces
             
def draw_result(eval_losses, eval_acces):
    x = range(1, EPOCHS + 1)
    fig, left_axis = plt.subplots()
    p1, = left_axis.plot(x, eval_losses, 'ro-')
    right_axis = left_axis.twinx()
    p2, = right_axis.plot(x, eval_acces, 'bo-')
    plt.xticks(x, rotation=0)

    # 设置左坐标轴以及右坐标轴的范围、精度
    left_axis.set_ylim(0, 0.5)
    left_axis.set_yticks(np.arange(0, 0.5, 0.1))
    right_axis.set_ylim(0.9, 1.01)
    right_axis.set_yticks(np.arange(0.9, 1.01, 0.02))

    # 设置坐标及标题的大小、颜色
    left_axis.set_xlabel('Labels')
    left_axis.set_ylabel('Loss', color='r')
    left_axis.tick_params(axis='y', colors='r')
    right_axis.set_ylabel('Accuracy', color='b')
    right_axis.tick_params(axis='y', colors='b')
    plt.show()



if __name__ == '__main__':
    #x, y = load_mnist()

    print("基于卷积神经网络实现手写数字识别")

    train_set = mnist.MNIST('./data', train=True, download=True, transform=transforms.ToTensor())//需要转化成tensor数据格式
    test_set = mnist.MNIST('./data', train=False, download=True, transform=transforms.ToTensor())

    train_data = DataLoader(train_set, batch_size=64, shuffle=True)
    test_data = DataLoader(test_set, batch_size=64, shuffle=False)

    a, a_label = next(iter(train_data))
    #net = AlexNet()
    net = ConvNet()
    eval_losses, eval_acces = train_net(net, train_data, test_data)
    draw_result(eval_losses, eval_acces)

结果:

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1531027.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

R语言实现多要素偏相关分析

偏相关分析是指当两个变量同时与第三个变量相关时,将第三个变量的影响剔除,只分析另外两个变量之间相关程度的过程,判定指标是相关系数的R值。 在GIS中,偏相关分析也十分常见,我们经常需要分析某一个指数与相关环境参…

基于Java中的SSM框架实现快餐店线上点餐系统项目【项目源码+论文说明】

基于Java中的SSM框架实现快餐店线上点餐系统演示 摘要 随着计算机互联网的高速发展。餐饮业的发展也加入了电子商务团队。各种网上点餐系统纷纷涌现,不仅增加了商户的销售量和营业额,而且为买家提供了极大的方便,足不出户,就能订…

Docker进阶教程 - 4 Docker网络

更好的阅读体验:点这里 ( www.doubibiji.com ) 4 Docker网络 先说我们现在遇到的问题: 我们现在有一个 Redis 容器,一个 SpringBoot 项目容器,在 SpringBoot 项目的代码中如何访问 Redis 容器中的服务呢…

Harbor镜像仓库的安装和使用

1 Harbor安装 参考文章: 银河麒麟v10离线安装harbor 由于配置了本地私有yum源,因此,直接使用yum命令安装docker和docker-compose 1.1 安装docker yum install docker-ce1.2 安装docker-compose yum install docker-compose1.3 安装harbo…

服务器被挖矿后修改密码报错Authentication token manipulation error

服务器被挖矿,需要修改密码,结果执行的时候发现报错 passwd: Authentication token manipulation error 尝试执行下列命令后再进行密码修改,修改成功 chattr -i /etc/passwd /etc/shadowchattr的主要用法 参考文章: https://c.biancheng.ne…

GEE遥感云大数据林业应用典型案例及GPT模型应用

近年来遥感技术得到了突飞猛进的发展,航天、航空、临近空间等多遥感平台不断增加,数据的空间、时间、光谱分辨率不断提高,数据量猛增,遥感数据已经越来越具有大数据特征。遥感大数据的出现为相关研究提供了前所未有的机遇&#xf…

威纶通触摸屏在编辑画面时如何更改窗口画面大小?

威纶通触摸屏在编辑画面时如何更改窗口画面大小? 如下图所示,Windows11系统下,打开威纶通触摸屏编程软件easy builder pro,此时可以看到画面窗口非常小,不方便编辑和操作, 如下图所示,点击上方工…

swagger3快速使用

目录 &#x1f37f;1.导入依赖 &#x1f32d;2.添加配置文件 &#x1f9c2;3.添加注解 &#x1f96f;4.访问客户端 1.导入依赖 引入swagger3的依赖包 <dependency><groupId>io.springfox</groupId><artifactId>springfox-boot-starter</artif…

B3870 [GESP202309 四级] 变长编码(膜拜版)

本题包括&#xff1a; 1.进制的超强使用 2.进制的截位使用 本题参考洛谷题解&#xff1a;https://www.luogu.com.cn/article/daqzhu5m &#xff08;在线膜拜作者的代码中&#xff09; 难度&#xff1a;普及- 对于笔者而言&#xff1a; 这道题在洛谷上通过率很高&#xff0c;…

“JavaScript: void(0)的替代方案有哪些?”

学习目标&#xff1a; 理解javascript:void(0)的工作原理&#xff0c;以及它在前端开发中的作用和用途。掌握javascript:void(0)的正确用法&#xff0c;包括在HTML中使用和在事件处理程序中使用。能够识别javascript:void(0)可能引起的常见问题&#xff0c;并学会相应的解决方…

理财第一课:炒股词典

文章目录 基础代码规则委比委差量比换手率市盈率市净率 短线操作散户亏钱的原因庄家分析炒股战法波浪理论其它 钱者&#xff0c;人生之大事&#xff0c;死生存亡之地&#xff0c;不可不察也。耕田之利&#xff0c;十倍&#xff1b;珠玉之赢&#xff0c;百倍&#xff1b;闹革命&…

安科瑞消防产品监控系统解决方案【电气火灾 消防设备 】

一、电气火灾监控系统 系统概述 l针对低压用电环节各回路中的剩余电流、温度和故障电弧等进行实时监测&#xff1b; l侧重点为低压用电环节的安全性&#xff0c;当剩余电流越限时报警输出&#xff0c;以提醒维护人员进行安全检查&#xff0c;防止因漏电引起的火灾发生&#…

【GameFramework框架内置模块】9、有限状态机(FSM)

推荐阅读 CSDN主页GitHub开源地址Unity3D插件分享简书地址 大家好&#xff0c;我是佛系工程师☆恬静的小魔龙☆&#xff0c;不定时更新Unity开发技巧&#xff0c;觉得有用记得一键三连哦。 一、前言 【GameFramework框架】系列教程目录&#xff1a; https://blog.csdn.net/q7…

从零开始写 Docker(七)---实现 mydocker commit 打包容器成镜像

本文为从零开始写 Docker 系列第七篇&#xff0c;实现类似 docker commit 的功能&#xff0c;把运行状态的容器存储成镜像保存下来。 完整代码见&#xff1a;https://github.com/lixd/mydocker 欢迎 Star 推荐阅读以下文章对 docker 基本实现有一个大致认识&#xff1a; 核心原…

解决jenkins运行磁盘满的问题

参考&#xff1a;https://blog.csdn.net/ouyang_peng/article/details/79225993 分配磁盘空间相关操作&#xff1a; https://cloud.tencent.com/developer/article/2230624 登录jenkins相对应的服务或容器中查看磁盘情况&#xff1a; df -h在102挂载服务器上看到是这两个文件…

OSPF特殊区域(stub\nssa)

stub区域——只有1类、2类、3类&#xff1b;完全stub区域——只有1类、2类 NSSA区域&#xff1a;本区域将自己引入的外部路由发布给其他区域&#xff0c;但不需要接收其他区域的路由 在NSSA区域的路由器上&#xff0c;引入外部路由时&#xff0c;不会转换成5类LSA&#xff0c…

Ethsign银河活动开启,简单参与领6个NFT

简介&#xff1a;EthSign是一个基于区块链技术的去中心化电子签名平台&#xff0c;目的是解决传统中心化电子签名服务的各种问题。用户可以使用钱包或社交媒体帐户生成的私钥签署文件和协议&#xff0c;数字签名记录在链上&#xff0c;文件经过加密存储在去中心化存储网络中&am…

CSS学习(3)-浮动和定位

一、浮动 1. 元素浮动后的特点 脱离文档流。不管浮动前是什么元素&#xff0c;浮动后&#xff1a;默认宽与高都是被内容撑开&#xff08;尽可能小&#xff09;&#xff0c;而且可以设置宽 高。不会独占一行&#xff0c;可以与其他元素共用一行。不会 margin 合并&#xff0c;…

DETR算法简介

DETR方法是一种使用了Transformer的端到端的目标检测方法&#xff0c;也是经典目标检测算法之一&#xff0c;本文将用最少的话&#xff0c;介绍DETR算法的大致思想。之前的方法或多或少的都不要添加一下额外的步骤&#xff0c;进行人为干预&#xff0c;即使是号称端到端的YOLO系…

Linux卸载Zabbix6 Agent v1 v2 简易操作手册

一、Zabbix6 卸载Zabbix Agent v1 要在Linux系统上卸载Zabbix Agent v1(zabbix_agent)&#xff0c;您可以使用包管理器执行此操作。以下是针对不同Linux发行版的卸载命令&#xff1a; # 对于基于Debian的系统&#xff08;如Ubuntu&#xff09;: sudo apt-get remove zabbix-ag…