002-基于Pytorch的手写汉字数字分类

news2024/9/20 15:07:17

本节将介绍一种

2.1 准备

2.1.1 数据集

(1)MNIST

只要学习过深度学习相关理论的人,都一定听说过名字叫做LeNet-5模型,它是深度学习三巨头只有Yann Lecun在1998年提出的一个CNN模型(很多人认为这是第一个具有实际应用价值的CNN模型)。在当年使用该模型可以很好地完成手写体数字的识别,而该模型所处理的手写体数字数据库称为MNIST。

MNIST全称是:Mixed National Institute of Standards and Technology databas,它包含70000张手写数字的灰度图片,每一张图片包含 28 X 28 个像素点。数据集被分为两部分,其中训练(mnist.train)集包括60000样本,测试集(mnist.test)包含10000样本。训练集又进一步封你为 55000 个样本用于训练,5000样本用于验证。下图是MNIST样本实例图。

MNIST数据集虽然经典,但也有问题。最主要的问题是,它太简单了!相对于现在动辄上百万个参数的“大”模型,MNIST数据集要小很多,且只是简单的十类问题,因此导致现有的模型在MNIST上的分类精度都超过了95%。为了更直观地观察不同算法间的性能差异,需要用一个更复杂一点的数据集,这时Fashion-MNIST出现了。

(2)Fashion-MNIST

FashionMNIST是一个替代MNIST的图像数据集。 它是由一家德国科技公司(Zalando)整理提供。FashionMNIST 的大小、格式和训练集/测试集划分与原始的 MNIST 完全一致。60000/10000 的训练测试数据划分,28x28 的灰度图片。因此,能跑MNIST数据集的代码,只需稍加改动,就可以跑新的数据集。两个数据集的不同之处主要有两点,一是虽然两者都是以灰度图像呈现的,但MNIST呈现的是数字,背景设为0,前景设为1,FashionMNIST则是真正意义的灰度数据集。二是两者内容不同,前者被分类的是手写体数字,后者则是十类衣物服饰(分别是:T恤、裤子、套头衫、连衣裙、大衣、凉鞋、衬衫、运动鞋、包、短靴),其内容的复杂程度远高于手写体数字。下图是FashionMNIST的一个图示。

网上有很多基于FashionMNIST数据集的实例,在此就不再重复介绍。

本节实例选用的是中国版的MNIST,由英国纽卡斯尔大学整理并提供,我们不妨将其称为CHN-MNIST数据集。

(3)CHN-MNIST

该数据集共由100人书写,每人重复书写10遍,因此数据集样本数为1000组,每组包括15个汉字的数字,即“零、一、二、三、四、五、六、七、八、九、十、百、千、万、亿”,总样本数为15000。图像的分辨率为300*300。

2.1.2  模型

对于这样一个简单的分类任务,不需要使用太复杂的网络,前面提到的LeNet-5足能胜任。

对于LeNet-5网络模型的介绍,网上一搜一大把,在此不再赘述,只贴出该模型的示意图,供大家参考。

2.2 代码解析

下面将结合代码,一部分一部分的介绍具体的过程。

(1)载入必要的扩展库

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

由于是第一个例程,我们对所使用的扩展库详细加以介绍:

  • matplotlib库:用于绘图
  • numpy库:用于数值计算
  • pandas库:用于数据分析
  • torch库:提供Pytorch支持
  • PIL库:用于图像绘制
  • tqdm库:Python提供的进度条空间库

(2)设置参数

这一部分完成的是设置一些与模型训练有关的超参数。如下面代码所示:

batch_size = 32  # 批次大小
lr = 0.003  # 学习率
epochs = 10  # 迭代轮数
save_path = './best_model.pkl'  # 模型保存路径
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')  # 设备

各个参数的功能见注释,至于各个参数数值大小对最终结果的影响,将放在后续的章节介绍。其中最后一行是自动检测是否安装了cuda,如果是,则启动gpu加速。

(3)加载数据集

这一部分完成的是设置一些与模型训练有关的超参数。如下面代码所示:

class CustomDataset(Dataset):
    def __init__(self, k, l, csv_file='./chinese_mnist.csv'):
        self.df = pd.read_csv(csv_file)

        self.k = {'九': int(9), '十': int(10), '百': int(11), '千': int(12), '万': int(13), '亿': int(14), '零': int(0),
                  '一': int(1), '二': int(2), '三': int(3), '四': int(4), '五': int(5), '六': int(6), '七': int(7),
                  '八': int(8)}

        self.target = 'character'
        self.features = ['suite_id', 'sample_id', 'code', ]

        self.labels = np.asarray(self.df.iloc[:, 4])

        self.y = df[self.target]
        self.X = df.drop(self.target, axis=1)

    def __getitem__(self, idx):
        single_image_label = self.labels[idx]

        class_id = self.k[single_image_label]

        img = Image.open(f"./data/data/input_{self.X.iloc[idx, 0]}_{self.X.iloc[idx, 1]}_{self.X.iloc[idx, 2]}.jpg")
        img = np.array(img)

        return img, class_id

    def __len__(self):
        return len(self.X)

还需要对数据集进行一下预处理,便于后面的训练过程g

# 1.构建索引到汉字的映射字典
num2char = {int(9): '九', int(10): '十', int(11): '百',
            int(12): '千', int(13): '万', int(14): '亿',
            int(0): '零', int(1): '一', int(2): '二',
            int(3): '三', int(4): '四', int(5): '五',
            int(6): '六', int(7): '七', int(8): '八'}

# 2.读取csv处理文件
df = pd.read_csv('./chinese_mnist.csv', sep=',')

# 3.处理数据
train_df = df.groupby('value').apply(lambda x: x.sample(700, random_state=42)).reset_index(drop=True)
x_train, y_train = train_df.iloc[:, :-2], train_df.iloc[:, -2]

test_df = df.groupby('value').apply(lambda x: x.sample(300, random_state=42)).reset_index(drop=True)
x_test, y_test = test_df.iloc[:, :-2], test_df.iloc[:, -2]

(未完,待续)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1553383.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

npm install 报错ERESOLVE unable to resolve dependency tree

描述:npm install 报错ERESOLVE unable to resolve dependency tree 解决方案: npm install --legacy-peer-deps

【Linux】网络编程套接字二

网络编程套接字二 1.TCP网络编程1.1TCP Server服务端1.2 TCP Client客户端 2.Server 多进程版本2.1普通版2.2 信号版 3.Server 多线程版4.Server 线程池版5.日志函数重新设计6.守护进程7.TCP协议通讯流程8.TCP和UDP 对比 喜欢的点赞,收藏,关注一下把&…

[Java基础揉碎]抽象类

目录 通过问题引出 介绍 关键点 细节 ​编辑 抽象类的最佳设计模式--模版设计模式 1.先用最容易想到的方法 2.分析问题,提出使用模板设计模式 通过问题引出 假如我们有个动物类, 动物都有eat吃的方法, 但是具体吃什么, 我们不知道, 因为是什么动物我们不知道…

绘制特征曲线-ROC(Machine Learning 研习十七)

接收者操作特征曲线(ROC)是二元分类器的另一个常用工具。它与精确度/召回率曲线非常相似,但 ROC 曲线不是绘制精确度与召回率的关系曲线,而是绘制真阳性率(召回率的另一个名称)与假阳性率(FPR&a…

【爬虫框架pyspider】01-pyspider入门与基本使用

前言 前面我们把爬虫的流程实现一遍,将不同的功能定义成不同的方法,甚至抽象出模块的概念。如微信公众号爬虫,我们已经有了爬虫框架的雏形,如调度器、队列、请求对象等,但是它的架构和模块还是太简单,远远…

|行业洞察·碳纤维|《中国碳纤维行业现状与发展趋势-39页》

报告内容的详细解读: 1. 战略性新材料的重要性 碳纤维是一种轻质高强的高性能纤维材料,在航空航天、国防军工、高端装备制造等领域具有不可替代的作用。碳纤维的应用有助于减少能源消耗和降低碳排放,符合全球可持续发展的要求。 |趋势洞察…

Java增强for循环和foreach循环的误区

网上很多文章都在说增强for循环和foreach循环遍历时不能修改值&#xff0c;只能查看&#xff0c;其实是有区分条件的&#xff0c;不能修改值的是包装类&#xff0c;例如List<String>,引用类型是可以修改值的&#xff0c;例如对象集合。 使用增强for循环或者foreach循环遍…

李宏毅【生成式AI导论 2024】第6讲 大型语言模型修炼_第一阶段_ 自我学习累积实力

背景知识:机器怎么学会做文字接龙 详见:https://blog.csdn.net/qq_26557761/article/details/136986922?spm=1001.2014.3001.5501 在语言模型的修炼中,我们需要训练资料来找出数十亿个未知参数,这个过程叫做训练或学习。找到参数后,我们可以使用函数来进行文字接龙,拿…

解决“Pycharm中Matplotlib图像不弹出独立的显示窗口”问题

matplotlib的绘图的结果默认显示在SciView窗口中, 而不是弹出独立的窗口, 这样看起来就不是很舒服&#xff0c;不习惯。 通过修改设置&#xff0c;改成独立弹出的窗口。 File—>Settings—>Tools—>Python Scientific—>Show plots in toolwindow 将√去掉即可

一台日本原生ip站群服务器多少钱?

一台日本原生ip站群服务器多少钱&#xff1f;日本原生ip站群服务器的价格受到多个因素的影响。以下是一些主要的因素&#xff1a; 服务器配置&#xff1a;硬件配置越高&#xff0c;自然价格也越高。对于站群服务器来说&#xff0c;由于需要同时运行多个网站&#xff0c;因此配置…

Vue挂载全局方法

简介&#xff1a;有时候&#xff0c;频繁调用的函数&#xff0c;我们需要把它挂载在全局的vue原型上&#xff0c;方便调用&#xff0c;具体怎么操作&#xff0c;这里来记录一下。 一、这里以本地存储的方法为例 var localStorage window.localStorage; const db {/** * 更新…

学习JavaEE的日子 Day32 线程池

Day32 线程池 1.引入 一个线程完成一项任务所需时间为&#xff1a; 创建线程时间 - Time1线程中执行任务的时间 - Time2销毁线程时间 - Time3 2.为什么需要线程池(重要) 线程池技术正是关注如何缩短或调整Time1和Time3的时间&#xff0c;从而提高程序的性能。项目中可以把Time…

【tensorflow框架神经网络实现鸢尾花分类】

文章目录 1、数据获取2、数据集构建3、模型的训练验证可视化训练过程 1、数据获取 从sklearn中获取鸢尾花数据&#xff0c;并合并处理 from sklearn.datasets import load_iris import pandas as pdx_data load_iris().data y_data load_iris().targetx_data pd.DataFrame…

Flask学习(六):蓝图(Blueprint)

蓝图&#xff08;Blueprint&#xff09;&#xff1a;将各个业务进行区分&#xff0c;然后每一个业务单元可以独立维护&#xff0c;Blueprint可以单独具有自己的模板、静态文件或者其它的通用操作方法&#xff0c;它并不是必须要实现应用的视图和函数的。 Demo目录结构&#xf…

八大技术趋势案例(人工智能物联网)

科技巨变,未来已来,八大技术趋势引领数字化时代。信息技术的迅猛发展,深刻改变了我们的生活、工作和生产方式。人工智能、物联网、云计算、大数据、虚拟现实、增强现实、区块链、量子计算等新兴技术在各行各业得到广泛应用,为各个领域带来了新的活力和变革。 为了更好地了解…

利用Java代码混淆技术提升应用程序抗逆向工程能力

摘要 本文探讨了代码混淆在保护Java代码安全性和知识产权方面的重要意义。通过混淆技术&#xff0c;可以有效防止代码被反编译、逆向工程或恶意篡改&#xff0c;提高代码的安全性。常见的Java代码混淆工具如IPAGuard、Allatori、DashO、Zelix KlassMaster和yGuard等&#xff0…

Python人工智能:气象数据可视化的新工具

Python是功能强大、免费、开源&#xff0c;实现面向对象的编程语言&#xff0c;在数据处理、科学计算、数学建模、数据挖掘和数据可视化方面具备优异的性能&#xff0c;这些优势使得Python在气象、海洋、地理、气候、水文和生态等地学领域的科研和工程项目中得到广泛应用。可以…

物联网实战--入门篇之(一)物联网概述

目录 一、前言 二、知识梳理 三、项目体验 四、项目分解 一、前言 近几年很多学校开设了物联网专业&#xff0c;但是确却地讲&#xff0c;物联网属于一个领域&#xff0c;包含了很多的专业或者说技能树&#xff0c;例如计算机、电子设计、传感器、单片机、网…

葵花卫星影像应用场景及数据获取

一、卫星参数 葵花卫星是由中国航天科技集团公司研制的一颗光学遥感卫星&#xff0c;代号CAS-03。该卫星于2016年11月9日成功发射&#xff0c;位于地球同步轨道&#xff0c;轨道高度约为35786公里&#xff0c;倾角为0。卫星设计寿命为5年&#xff0c;搭载了高分辨率光学相机和多…

Oracle存数字精度问题number、binary_double、binary_float类型

--表1 score是number(10,5)类型 create table TEST1 (score number(10,5) ); --表2 score是binary_double类型 create table TEST2 (score binary_double ); --表3 score是binary_float类型 create table TEST3 (score binary_float );实验一&#xff1a;分别往三张表插入 小数…