【深度学习入门项目】基于支持向量机的手写数字识别

news2024/9/21 16:34:34

目录

  • 导入必要的包
  • 1. 数据集
  • 2. 数据处理
  • 3. 训练过程
  • 4. 输出结果
  • 完整代码

本项目使用SVM训练模型,用于预测手写数字图片。

导入必要的包

numpy: 这个库是Python中常用的数学计算库。在这个项目中,我使用numpy来处理图像数据,将图像数据转换为一维向量,以便进行模型训练和测试。

matplotlib: 这个库是Python中常用的绘图库。在这个项目中,我使用matplotlib来显示一些手写数字图像样本以及测试样本和它们的预测结果。

sklearn: 这个库是Python中常用的机器学习库,提供了许多机器学习算法和工具。在这个项目中,使用sklearn来加载手写数字数据集、将数据集分为训练集和测试集、创建SVM分类器、进行模型训练和测试,并评估模型性能。

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import confusion_matrix

1. 数据集

使用Scikit-Learn库自带的手写数字数据集(digits dataset)。该数据集包含8x8像素的手写数字图像,共有10个类别(数字0到9),每个类别有约180个样本,总共大约有1797个样本。

并且通过以下代码展示数据集中的基本信息

# 加载手写数字数据集
digits = datasets.load_digits()

# 显示数据集基本信息
print("数据集基本信息:")
print("样本数量: {}".format(len(digits.images)))
print("图像大小: {}".format(digits.images[0].shape))

数据集基本信息:
样本数量: 1797
图像大小: (8, 8)\

# 显示一些样本图像
fig, axes = plt.subplots(4, 4, figsize=(8, 8),
                         subplot_kw={'xticks':[], 'yticks':[]},
                         gridspec_kw=dict(hspace=0.1, wspace=0.1))

for i, ax in enumerate(axes.flat):
    ax.imshow(digits.images[i], cmap='binary', interpolation='nearest')
    ax.text(0.05, 0.05, str(digits.target[i]),
            transform=ax.transAxes, color='green' if (digits.target[i]==digits.target[0]) else 'black')

在这里插入图片描述

2. 数据处理

由于数据为8×8的图像数据,因此每个图像包含64个像素值。如果不将图像数据转换为一维向量,算法无法处理直接矩阵的数据。我们需要将数据转化为一维向量,来让SVM能够处理数据。这个过程通过numpy库中的reshape()函数实现。

同时由于是机器学习项目,我们需要划分训练集和测试集,通过使用sklearn库中的train_test_split函数来随机划分数据集为训练集和测试集。具体地,我将手写数字数据集中的样本随机划分为训练集和测试集,其中训练集占总样本数的70%,测试集占30%。这个划分比例可以根据实际情况进行调整。

# 将图像数据转换为一维向量
X = digits.images.reshape((len(digits.images), -1))
y = digits.target

# 将数据集分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)

3. 训练过程

使用sklearn库中的SVC()函数创建SVM分类器,并指定超参数。在这个项目中,参数是gamma=0.001。在实际使用中,我们可以多次调整参数,结合损失函数的变化,以寻得最优的参数。或者也可以直接使用默认的参数设置,即clf = SVC。

训练模型:使用SVM分类器对训练集进行训练,通过调用fit()方法实现。

预测结果:使用训练好的SVM分类器对测试集进行预测,通过调用predict()方法实现。

# 创建SVM分类器
clf = SVC(gamma=0.001)

# 训练分类器
clf.fit(X_train, y_train)

# 测试分类器
y_pred = clf.predict(X_test)

4. 输出结果

首先是评估模型的性能,使用sklearn库中的confusion_matrix()函数计算模型的混淆矩阵,混淆矩阵的行表示实际标签,列表示预测标签,每个元素表示实际标签为该行所对应的数字,而分类器预测为该列所对应的数字的样本数。混淆矩阵可以比较直观的展示该项目中分类错误的个数,并根据混淆矩阵计算出模型的准确率、精确率、召回率和F1分数等指标。

# 输出混淆矩阵和准确率
cm = confusion_matrix(y_test, y_pred)
print("混淆矩阵:")
print(cm)

accuracy = clf.score(X_test, y_test)
print("准确率: {:.2f}%".format(accuracy * 100))

# 显示一些测试样本和其预测结果
fig, axes = plt.subplots(4, 4, figsize=(8, 8),
                         subplot_kw={'xticks':[], 'yticks':[]},
                         gridspec_kw=dict(hspace=0.1, wspace=0.1))

for i, ax in enumerate(axes.flat):
    ax.imshow(X_test[i].reshape(8,8), cmap='binary', interpolation='nearest')
    ax.text(0.05, 0.05, str(y_pred[i]),
            transform=ax.transAxes,
            color='green' if (y_pred[i]==y_test[i]) else 'red')

plt.show()

混淆矩阵:
[[56 0 0 0 0 0 0 0 0 0]
[ 0 57 0 0 0 0 0 0 0 0]
[ 0 0 44 0 0 0 0 0 0 0]
[ 0 0 0 60 0 0 0 0 0 0]
[ 0 0 0 0 70 0 0 0 0 0]
[ 0 0 0 0 0 56 1 0 0 0]
[ 0 0 0 0 0 0 48 0 0 0]
[ 0 0 0 0 0 0 0 55 0 1]
[ 0 1 0 0 0 0 0 0 49 0]
[ 0 0 0 0 0 0 0 1 0 41]]
准确率: 99.26%
在这里插入图片描述

完整代码

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import confusion_matrix

# 加载手写数字数据集
digits = datasets.load_digits()

# 显示数据集基本信息
print("数据集基本信息:")
print("样本数量: {}".format(len(digits.images)))
print("图像大小: {}".format(digits.images[0].shape))

# 显示一些样本图像
fig, axes = plt.subplots(4, 4, figsize=(8, 8),
                         subplot_kw={'xticks':[], 'yticks':[]},
                         gridspec_kw=dict(hspace=0.1, wspace=0.1))

for i, ax in enumerate(axes.flat):
    ax.imshow(digits.images[i], cmap='binary', interpolation='nearest')
    ax.text(0.05, 0.05, str(digits.target[i]),
            transform=ax.transAxes, color='green' if (digits.target[i]==digits.target[0]) else 'black')

# 将图像数据转换为一维向量
X = digits.images.reshape((len(digits.images), -1))
y = digits.target

# 将数据集分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)

# 创建SVM分类器
clf = SVC(gamma=0.001)

# 训练分类器
clf.fit(X_train, y_train)

# 测试分类器
y_pred = clf.predict(X_test)

# 输出混淆矩阵和准确率
cm = confusion_matrix(y_test, y_pred)
print("混淆矩阵:")
print(cm)

accuracy = clf.score(X_test, y_test)
print("准确率: {:.2f}%".format(accuracy * 100))

# 显示一些测试样本和其预测结果
fig, axes = plt.subplots(4, 4, figsize=(8, 8),
                         subplot_kw={'xticks':[], 'yticks':[]},
                         gridspec_kw=dict(hspace=0.1, wspace=0.1))

for i, ax in enumerate(axes.flat):
    ax.imshow(X_test[i].reshape(8,8), cmap='binary', interpolation='nearest')
    ax.text(0.05, 0.05, str(y_pred[i]),
            transform=ax.transAxes,
            color='green' if (y_pred[i]==y_test[i]) else 'red')

plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2055761.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

FPGA开发——DS18B20读取温度并且在数码管上显示

一、简介 在上一篇文章中我们对于DS18B20的相关理论进行了详细的解释,同时也对怎样使用DS18B20进行了一个简单的叙述。在这篇文章我们通过工程来实现DS18B20的温度读取并且实现在数码管伤显示。 1、基本实现思路 根据不同时刻的操作,我们可以使用一个状…

基于vue框架的班级网站的设计与实现vg66m(程序+源码+数据库+调试部署+开发环境)系统界面在最后面。

系统程序文件列表 项目功能:班级,学生,班级活动,班级相册,班级开支,活动记录 开题报告内容 基于Vue框架的班级网站设计与实现 开题报告 一、引言 随着互联网技术的飞速发展,网络已经成为人们日常生活中不可或缺的一部分。在教育领域,班级…

大白话解析:深入浅出大模型RAG模块全解析

文章目录 什么是 RAG? 技术交流&资料通俗易懂讲解大模型系列 RAG模块化 什么是模块化RAG? 索引模块 块优化 滑动窗口从小到大元数据附加 结构化组织 层次化索引知识图谱文档组织 预检索模块 查询扩展 多查询子查询CoVe 查询转换 重写HyDE 查询路由…

TON链上游戏项目开发基本要求及模式创建与海外宣发策略

TON(The Open Network)是由Telegram开发的区块链平台,以其高速、低延迟、和高扩展性吸引了大量开发者和项目方。TON链上游戏项目作为一个新兴领域,结合了区块链技术和游戏产业,为用户提供了全新的游戏体验和经济激励。…

精益生产咨询:为企业量身定制的高效能蜕变计划!——张驰咨询

在当今这个快速变化、竞争激烈的市场环境中,企业如何保持持续的竞争优势,提高生产效率,降低成本,同时又能快速响应市场需求,成为了每一个企业家必须面对的重大课题。精益生产(Lean Production)作…

第5节:Elasticsearch核心概念

我的后端学习笔记大纲 我的ElasticSearch学习大纲 1.Lucene和Elasticsearch的关系: 1.Lucene:最先进、功能最强大的搜索库,直接基于lucene开发,非常复杂,api复杂2.Elasticsearch:基于lucene,封装了许多luc…

跳槽?面试软件测试需要掌握的知识你Get了吗

想从事软件测试相关的工作,立志成为一名优秀的软件测试工程师。 一名优秀的软件测试工程师,需要扎实的专业基础,包括测试相关技术、编程技能、数据库知识、计算机网络、以及操作系统等等。对于没有测试经验的应届生求职者来说,面…

SpringBoot项目部署时application.yml文件的加载优先级和启动脚本

文章目录 application.yml文件的加载优先级(由高到低)第一级命令行参数第二级Jar包同级目录 /config第三级Jar包同级目录第四级classpath 下的/config第五级classpath 根路径/总结: logback.xml 文件加载顺序当application.yml 和 bootstrap.yml 同时存在时java jar…

淘宝天猫详情接口API:实现轻松购物,探索最具性价比的商品

随着电子商务的蓬勃发展,网络购物已经成为现代人日常生活中的重要部分。在这个浩瀚的电商海洋中,淘宝和天猫无疑是最为耀眼的两大平台。然而,如何在众多的商品中挑选出性价比最高的产品?淘宝天猫详情接口API为您提供了解决方案。 …

基于vue框架的班级管理系统3pdep(程序+源码+数据库+调试部署+开发环境)系统界面在最后面。

系统程序文件列表 项目功能:学生,班级事务,班级,成绩信息,请假,销假,班级信息,教师 开题报告内容 基于Vue框架的班级管理系统 开题报告 一、引言 随着教育信息化进程的加快,学校管理工作逐渐从传统的纸质化、人工化向数字化、智能化转变。班级作为学…

Python与自动化测试:提高软件质量和稳定性

在软件开发过程中,自动化测试是提高软件质量和稳定性的重要手段之一。Python作为一种简洁而强大的编程语言,为自动化测试提供了丰富的工具和库。本文将介绍几个常见的自动化测试案例,并提供详细的Python代码示例,帮助您更好地理解…

前端面试——js作用域

说一说JS的作用域吧 作用域的分类 作用域分为:全局作用域,函数作用域,块级作用域 作用域的特性 全局作用域: 能够让变量和函数在全局位置访问,其挂载在浏览器的window对象下面 其中var定义的变量和function函数存…

怀旧风吹到体育圈,刘翔、郭晶晶等再翻红?明星与体育冠军代言的区别!

今年奥运,怀旧风吹到了体育圈,曾经的奥运冠军如刘翔、郭晶晶等再度成为公众焦点。这段时间,刘翔频频出现在伊利、霸王茶姬等品牌的广告中,还和法国球星姆巴佩合作拍摄了小红书广告。同样备受品牌关注的还有郭晶晶,巴黎…

【Python实现全屏播放视频】

效果如下: 虽然视频比较抽象,但是确实是用python(cv2)实现的 代码: import cv2 from playsound import playsound from threading import Threaddef func1():cap cv2.VideoCapture("mp4/out.mp4") #替换为视频路径ret, frame ca…

记一次长事务方法带来的坑

文章目录 1. 沟通需求2.分析需求3. 波折起4.初版完成5.锁等待超时6.消费者超时7.总结 1. 沟通需求 产品找到我说,咱要将一波数据给更新了,因为涉及业务,就不说具体的内容了,需要支持分页滚动,校对数据后进行推送&…

无人机系统的关键技术

一、飞控系统:是无人机完成整个飞行过程的关键,决定了无人机的飞行性能和稳定性。 二、导航系统:提供无人机所需的位置、速度和飞行姿态等信息,引导无人机按照指定航线飞行。 三、动力系统:提供飞行动力,…

报表工具是开源还是商用的好?如何选择适合自己的报表工具?

在当今数字化转型的浪潮中,制作既精确又直观的报表已成为个人高效工作与企业精准沟通的核心工具。然而,面对市场上纷繁复杂的报表工具选项,选择最适合自身或企业需求的那一款,宛如漫步于迷雾笼罩的森林,挑战重重&#…

React 学习——useMemo

useMemo使用场景&#xff1a;消耗非常大的计算&#xff0c;例如递归 import { useMemo, useState } from react; // 缓存&#xff1a;消耗非常大的计算&#xff0c;例如递归 function fib(n){console.log(fib);if(n < 3)return 1;return fib(n-2) fib(n-1); }const App (…

Python开发工具PyCharm v2024.2全新发布——新增Databricks集成

JetBrains PyCharm是一种Python IDE&#xff0c;其带有一整套可以帮助用户在使用Python语言开发时提高其效率的工具。此外&#xff0c;该IDE提供了一些高级功能&#xff0c;以用于Django框架下的专业Web开发。 立即获取PyCharm v2024.2正式版(Q技术交流&#xff1a;786598704&…

Spark2.x 入门:DStream 转换操作

DStream转换操作包括无状态转换和有状态转换。 无状态转换&#xff1a;每个批次的处理不依赖于之前批次的数据。 有状态转换&#xff1a;当前批次的处理需要使用之前批次的数据或者中间结果。有状态转换包括基于滑动窗口的转换和追踪状态变化的转换(updateStateByKey)。 DStre…