深度学习手写字符识别:训练模型

news2025/1/2 3:39:59

说明

本篇博客主要是跟着B站中国计量大学杨老师的视频实战深度学习手写字符识别。
第一个深度学习实例手写字符识别

深度学习环境配置

可以参考下篇博客,网上也有很多教程,很容易搭建好深度学习的环境。
Windows11搭建GPU版本PyTorch环境详细过程

数据集

手写字符识别用到的数据集是MNIST数据集(Mixed National Institute of Standards and Technology database);MNIST是一个用来训练各种图像处理系统二进制图像数据集,广泛应用到机器学习中的训练和测试。
作为一个入门级的计算机视觉数据集,发布20多年来,它已经被无数机器学习入门者应用无数遍,是最受欢迎的深度学习数据集之一。

序号说明
发布方National Institute of Standards and Technology(美国国家标准技术研究所,简称NIST)
发布时间1998
背景该数据集的论文想要证明在模式识别问题上,基于CNN的方法可以取代之前的基于手工特征的方法,所以作者创建了一个手写数字的数据集,以手写数字识别作为例子证明CNN在模式识别问题上的优越性。
简介MNIST数据集是从NIST的两个手写数字数据集:Special Database 3 和Special Database 1中分别取出部分图像,并经过一些图像处理后得到的。MNIST数据集共有70000张图像,其中训练集60000张,测试集10000张。所有图像都是28×28的灰度图像,每张图像包含一个手写数字。

跟着视频跑源码

  1. 下载源码:mivlab/AI_course (github.com)
  2. 下载数据集:https://opendatalab.com/MNIST;网上下载的地址比较多,也可以直接下载B站中国计量大学杨老师的百度网盘位置里的MNIST。

运行源码

  1. 在Pycharm中打开AI_course项目,运行classify_pytorch文件目录里train_mnist.py的Python文件。
    在这里插入图片描述
    train_mnist.py具体的源码如下:
import torch
import math
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms, models
import argparse
import os
from torch.utils.data import DataLoader

from dataloader import mnist_loader as ml
from models.cnn import Net
from toonnx import to_onnx


parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--datapath', required=True, help='data path')
parser.add_argument('--batch_size', type=int, default=256, help='training batch size')
parser.add_argument('--epochs', type=int, default=300, help='number of epochs to train')
parser.add_argument('--use_cuda', default=False, help='using CUDA for training')

args = parser.parse_args()
args.cuda = args.use_cuda and torch.cuda.is_available()
if args.cuda:
    torch.backends.cudnn.benchmark = True


def train():
    os.makedirs('./output', exist_ok=True)
    if True: #not os.path.exists('output/total.txt'):
        ml.image_list(args.datapath, 'output/total.txt')
        ml.shuffle_split('output/total.txt', 'output/train.txt', 'output/val.txt')

    train_data = ml.MyDataset(txt='output/train.txt', transform=transforms.ToTensor())
    val_data = ml.MyDataset(txt='output/val.txt', transform=transforms.ToTensor())
    train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(dataset=val_data, batch_size=args.batch_size)

    model = Net(10)
    #model = models.vgg16(num_classes=10)
    #model = models.resnet18(num_classes=10)  # 调用内置模型
    #model.load_state_dict(torch.load('./output/params_10.pth'))
    #from torchsummary import summary
    #summary(model, (3, 28, 28))

    if args.cuda:
        print('training with cuda')
        model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-3)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [20, 30], 0.1)
    loss_func = nn.CrossEntropyLoss()

    for epoch in range(args.epochs):
        # training-----------------------------------
        model.train()
        train_loss = 0
        train_acc = 0
        for batch, (batch_x, batch_y) in enumerate(train_loader):
            if args.cuda:
                batch_x, batch_y = Variable(batch_x.cuda()), Variable(batch_y.cuda())
            else:
                batch_x, batch_y = Variable(batch_x), Variable(batch_y)
            out = model(batch_x)  # 256x3x28x28  out 256x10
            loss = loss_func(out, batch_y)
            train_loss += loss.item()
            pred = torch.max(out, 1)[1]
            train_correct = (pred == batch_y).sum()
            train_acc += train_correct.item()
            print('epoch: %2d/%d batch %3d/%d  Train Loss: %.3f, Acc: %.3f'
                  % (epoch + 1, args.epochs, batch, math.ceil(len(train_data) / args.batch_size),
                     loss.item(), train_correct.item() / len(batch_x)))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()  # 更新learning rate
        print('Train Loss: %.6f, Acc: %.3f' % (train_loss / (math.ceil(len(train_data)/args.batch_size)),
                                               train_acc / (len(train_data))))

        # evaluation--------------------------------
        model.eval()
        eval_loss = 0
        eval_acc = 0
        for batch_x, batch_y in val_loader:
            if args.cuda:
                batch_x, batch_y = Variable(batch_x.cuda()), Variable(batch_y.cuda())
            else:
                batch_x, batch_y = Variable(batch_x), Variable(batch_y)

            out = model(batch_x)
            loss = loss_func(out, batch_y)
            eval_loss += loss.item()
            pred = torch.max(out, 1)[1]
            num_correct = (pred == batch_y).sum()
            eval_acc += num_correct.item()
        print('Val Loss: %.6f, Acc: %.3f' % (eval_loss / (math.ceil(len(val_data)/args.batch_size)),
                                             eval_acc / (len(val_data))))
        # 保存模型。每隔多少帧存模型,此处可修改------------
        if (epoch + 1) % 1 == 0:
            # torch.save(model, 'output/model_' + str(epoch+1) + '.pth')
            torch.save(model.state_dict(), 'output/params_' + str(epoch + 1) + '.pth')
            #to_onnx(model, 3, 28, 28, 'params.onnx')

if __name__ == '__main__':
    train()

  1. 报错:没有cv2,即没有安装OpenCV库。
    在这里插入图片描述
  2. 安装OpenCV库,可以命令行安装,也可以Pycharm中安装。
  • 命令行激活虚拟环境:conda activate deeplearning
  • 命令行安装: pip install opencv-python(也可以Pycharm中下载,可能上梯子安装更快)
    在这里插入图片描述
  1. 再次运行,出现如下图提示,表明需要将下载好的数据集配置到configure中。
    在这里插入图片描述
  2. 加载下载好的数据集,即--datapath=数据集的路径
    在这里插入图片描述
  3. 点击“Run”,开始训练,损失和准确率在一直更新,持续训练,直到模型完成,未改动源码的情况下,训练时间可能需要较长。
    在这里插入图片描述
  4. 在小编的拯救者笔记本电脑上持续训练了10小时才完成最终的模型训练,可以看到训练损失已经很低了,准确度很高水平。
    在这里插入图片描述
  5. 在项目中output文件夹中可以看到已经训练好了很多模型;后面可以利用模型进行推理了。
    在这里插入图片描述

参考

https://zhuanlan.zhihu.com/p/681236488

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1430261.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

华为机考入门python3--(6)牛客6-质数因子

分类:质数、素数 知识点: 取余符号% 5%3 2 取整符号// 5//3 1 list中int元素转str map(str, list) 题目来自【牛客】 def prime_factors(n): """ 输入一个正整数n,输出它的所有质因子(重复的也…

python pygame实现倒计时

实现思路 获取开始时间、当前时间,通过当前时间-开始时间时间差,再通过倒计时的总时长-时间差即可实现! 随着时间的流逝,当前时间会变大,也就导致时间差会变大,当使用总时长-时间差的时候,得到…

基于控制台的购书系统(Java 语言实现)

📚博客主页:爱敲代码的小杨. ✨专栏:《Java SE语法》|《数据结构与算法》 ❤️感谢大家点赞👍🏻收藏⭐评论✍🏻,您的三连就是我持续更新的动力❤️ 🙏小杨水平有限,欢…

【Nginx】nginx入门

文章目录 一、Web服务器二、Nginx三、Nginx的作用Web服务器正向代理反向代理 四、CentOS上安装Nginx(以CentOS 7.9为例) 一、Web服务器 Web 服务器,一般是指“网站服务器”,是指驻留于互联网上某种类型计算机的程序。Web 服务器可以向 Web 浏览器等客户…

[开源]基于野火指南者的MQTT框架+FreeRTOS移植(使用板载esp8266模块)

MQTT移植 项目地址 实际使用 基于野火指南者开发板 移植大佬的MQTT框架, 参考韦东山的课程 实际移植的项目: mqttclient 主要实现的是使用开发板上面的ESP8266连接MQTT服务器, 目前使用的是ip地址进行连接(FreeRTOS版本) 测试程序在mqtt/at/at_comment.c文件里面, 需要改…

Vivado MIG IP使用配置

目录 1 MIG 基本配置 1 MIG 基本配置 配置如下图所示 图1 图2 图3 图4 图5 图6 图8 图9 在设立只讲解共同配置,这是所有DDR3中配置通用部分。

2024年【低压电工】复审考试及低压电工作业考试题库

题库来源:安全生产模拟考试一点通公众号小程序 低压电工复审考试参考答案及低压电工考试试题解析是安全生产模拟考试一点通题库老师及低压电工操作证已考过的学员汇总,相对有效帮助低压电工作业考试题库学员顺利通过考试。 1、【单选题】()是保证电气作…

CSS常用动画网站(纯css echarts等 建议经常阅读 积累素材)

CSS动画代码集合 https://www.webhek.com/post/css3-animation-sniplet-collection/#/ 这个网站中将常见的css动画都进行了集合,并且有详细的代码,可以直接使用 echarts图表 https://www.isqqw.com/ echarts也是前端常用的,虽然官方文档已经给出了很多的案例,但是有时候产品还…

OpenHarmony—Gradle工程适配为Hvigor工程

该适配场景适用于开发者希望将原OpenHarmony API 7的工程升级到OpenHarmony API 9的工程。 原OpenHarmony API 7的工程采用的是Gradle自动化构建工具,而OpenHarmony API 9的工程则采用Hvigor自动化构建工具,由于Gradle构建工具和Hvigor构建工具的配置文…

【PyQt】03-信号与槽

文章目录 前言一、信号与槽二、Demo接收信号代码运行结果 自定义信号【重点】代码运行结果 总结 前言 我认为,信号就是触发,槽就是触发的事件。 一、信号与槽 信号(signal) 其实就是事件(按钮点击 、内容发生改变 、窗口的关闭事件&#xf…

无人机在消防领域的应用及其优势

无人机在消防领域的应用及其优势 随着科技的不断发展,无人机正以其独特的优势,逐渐在各个领域得到广泛应用。在消防领域中,无人机的应用也越来越受到关注。无人机在消防工作中具有独特的优势,可以帮助消防人员更好地开展任务&…

背景点击监督的时序动作定位 Background-Click Supervision for Temporal Action Localization

该论文介绍了 BackTAL,这是一种利用背景点击监督进行弱监督时序动作定位的新方法。 它将焦点从动作帧转移到背景帧,通过强调背景错误来改进定位。 BackTAL 包含分数分离模块和亲和力模块,增强了位置和特征建模。 Background-Click的说明 Click 点击级别监督的说明…

shell中脚本参数传递的两种方式

一、接在脚本后面输入参数值,多个参数间用空格隔开 采用$0,$1,$2..等方式获取脚本命令行传入的参数,值得注意的是,$0获取到的是脚本路径以及脚本名,后面按顺序获取参数,当参数超过10个时(包括10个),需要使用…

探索网络定位与连接:域名和端口的关键角色

目录 域名 域名的作用 域名的结构 域名的解析配置 父域名、子域名​编辑 https的作用 端口 图解端口 端口怎么用 判断网站是否存活 端口的作用 域名 域名是互联网上用于标识网站的一种易于记忆的地址。 域名是互联网基础架构的一个重要组成部分,它为网…

029 命令行传递参数

1.循环输出args字符串数组 public class D001 {public static void main(String[] args) {for (String arg : args) {System.out.println(arg);}} } 2.找打这个类的路径,打开cmd cmd C:\Users\Admin\IdeaProjects\JavaSE学习之路\scanner\src\com\yxm\demo 3. 编译…

JProfiler for Mac:提升性能和诊断问题的终极工具

在当今的高性能计算和多线程应用中,性能优化和问题诊断是至关重要的。JProfiler for Mac 是一个强大的性能分析工具,旨在帮助开发者更好地理解其应用程序的运行情况,提升性能并快速诊断问题。 JProfiler for Mac 的主要特点包括:…

Android Display显示框架整体流程

一.Android Display显示框架整体流程图

一篇文章了解系统眼中的键盘--以一个简单的系统分析从按键的输入到字符的显示

键盘输入 实现使用的设备 intel架构32位CPU, 思路为嵌入式系统工程师,使用的操作系统是《30天自制操作系统》里面的系统进行讲解 硬件实现 按键 使用单片机等的引脚可以获取电平状态从而获得按键的状态(单片机是一种集成到一块硅片上构成的一个小而完善的微型计算机系统, 用…

Linux系统卸载重装JDK

CentOS 系统是开发者常用的 Linux 操作系统,安装它时会默认安装自带的旧版本的 OpenJDK,但在开发者平时开发 Java 项目时还是需要完整的 JDK,所以我们部署 CentOS 开发环境时,需要先卸载系统自带的 OpenJDK,再重新安装…

《国色芳华》爆红网络,杨紫的“唐妆”惊艳四座。

♥ 为方便您进行讨论和分享,同时也为能带给您不一样的参与感。请您在阅读本文之前,点击一下“关注”,非常感谢您的支持! 文 |猴哥聊娱乐 编 辑|徐 婷 校 对|侯欢庭 在中国的电视剧市场近几年的趋势中,仙侠剧的热度逐…