模型训练识别手写数字(三)

news2024/10/27 16:47:26

 1. 使用卷积神经网络(CNN)来构建模型训练

import numpy as np
from keras import Sequential
from keras.api.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization
from keras.src.legacy.preprocessing.image import ImageDataGenerator
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder

# 加载数据
X = np.load("Data/dataset.npy", allow_pickle=True)
y = np.load("Data/class.npy", allow_pickle=True)

# 数据预处理:归一化并重新调整形状为 (样本数, 28, 28, 1) 用于 CNN
X = X.astype('float32') / 255.0  # 归一化
X = X.reshape(-1, 28, 28, 1)  # 调整形状

# One-hot 编码
onehot = OneHotEncoder(sparse_output=False)
y = onehot.fit_transform(y.reshape(-1, 1))

# 划分训练集和测试集
x_train, x_test, y_train, y_test = train_test_split(X, y, random_state=14)

# 创建数据增强生成器
datagen = ImageDataGenerator(
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.1,
    zoom_range=0.1,
    horizontal_flip=False,
    fill_mode='nearest'
)

# 构建 CNN 模型
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))  # 添加 Dropout 层以防止过拟合
model.add(Dense(10, activation='softmax'))  # 输出层

# 编译模型
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# 训练模型时使用数据增强
model.fit(datagen.flow(x_train, y_train, batch_size=32),
          epochs=100,
          validation_data=(x_test, y_test),
          verbose=1)

# 评估模型
predictions = model.predict(x_test)
predicted_classes = np.argmax(predictions, axis=1)
y_test_classes = np.argmax(y_test, axis=1)

# 计算 F-score
print("F-score: {0:.2f}".format(f1_score(y_test_classes, predicted_classes, average='micro')))

# 保存模型
model.save("my_model02.keras")

 2. 调用训练的模型进行测试

import cv2
import matplotlib.pyplot as plt
import numpy as np
from keras.api.models import load_model

# 加载模型
model = load_model("my_model02.keras")

# 加载手写数字图像
original_img = cv2.imread("Data/handwritten_digit.png", cv2.IMREAD_GRAYSCALE)

# 处理图像用于预测
img = cv2.resize(original_img, (28, 28))  # 调整为28x28大小
img = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY_INV)[1]  # 二值化
img = img.astype('float32') / 255  # 归一化
img = img.reshape(1, 28, 28, 1)  # 调整形状为 (1, 28, 28, 1)

# 进行预测
predictions = model.predict(img)
predicted_class = np.argmax(predictions, axis=1)

# 可视化预测结果
plt.figure(figsize=(6, 6))

# 显示原图
plt.imshow(original_img, cmap='gray', aspect='equal')  # 使用原始图像
plt.title(f'Predicted: {predicted_class[0]}', fontsize=14)
plt.axis('off')

plt.tight_layout()
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2224800.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Date工具类详细汇总-Date日期相关方法

# 1024程序员节 | 征文 # 目录 简介 Date工具类单元测试 Date工具类 简介 本文章是个人总结实际工作中常用到的Date工具类,主要包含Java-jdk8以下版本的Date相关使用方法,可以方便的在工作中灵活的应用,在个人工作期间频繁使用这些时间的格…

力扣 中等 740.删除并获得点数

文章目录 题目介绍题解 题目介绍 题解 由题意可知,在选择了数组中元素 a 后,该元素以及所有等于 a−1 和 a1 的元素都会从数组中删去,并获得 a 的点数。若还有多个值为 a的元素,由于所有等于 a−1 或 a1 的元素已经被删除&#x…

Linux相关概念和易错知识点(16)(Shell原理、进程属性和环境变量表的联系)

Shell原理及其模拟实现 在认识进程exec系列函数、命令行参数列表、环境变量之后,我们可以尝试理解一下Shell的原理,将各方知识串联起来,让Shell跑起来才能真正理解这些概念。我会以模拟Shell执行的原理模拟一个Shell。途中配上相关讲解。 1…

信息安全工程师(72)网络安全风险评估概述

前言 网络安全风险评估是一项重要的技术任务,它涉及对网络系统、信息系统和网络基础设施的全面评估,以确定存在的安全风险和威胁,并量化其潜在影响以及可能的发生频率。 一、定义与目的 网络安全风险评估是指对网络系统中存在的潜在威胁和风险…

《Python游戏编程入门》注-第3章3

《Python游戏编程入门》的“3.2.4 Mad Lib”中介绍了一个名为“Mad Lib”游戏的编写方法。 1 游戏玩法 “Mad Lib”游戏由玩家根据提示输入一些信息,例如男人姓名、女人姓名、喜欢的食物以及太空船的名字等。游戏根据玩家输入的信息编写出一个故事,如图…

基于SSM的汽车客运站管理系统【附源码】

基于SSM的汽车客运站管理系统(源码L文说明文档) 目录 4 系统设计 4.1 设计原则 4.2 功能结构设计 4.3 数据库设计 4.3.1 数据库概念设计 4.3.2 数据库物理设计 5 系统实现 5.1 管理员功能实现 5.1.1 管理员信息 5.1.2 车…

详细解读Movie Gen(2):个性化视频训练

Diffusion Models专栏文章汇总:入门与实战 前言:Meta最近重磅发布了视频生成30B的基础模型Movie Gen,长达93页的技术报告中干货满满,博主将详细解读Movie Gen的核心网络结构、个性化视频微调方法、视频编辑等方面。虽然大部分人没有直接预训练30B模型的机会,但是可以从中获…

C++游戏开发中的多线程处理是否真的能够显著提高游戏性能?如果多个线程同时访问同一资源,会发生什么?如何避免数据竞争?|多线程|游戏开发|性能优化

目录 1. 多线程处理的基本概念 1.1 多线程的定义 1.2 线程的创建与管理 2. 多线程在游戏开发中的应用 2.1 渲染与物理计算 3. 多线程处理的性能提升 3.1 性能评估 3.2 任务分配策略 4. 多线程中的数据竞争 4.1 数据竞争的定义 4.2 多线程访问同一资源的后果 4.3 避…

视频剪辑新手必备:四款热门电脑视频剪辑软件评测

现在真的是一个视频流量的时代,不得不说,我都已经开始刷视频小说了!如果你和我一样,是个对电脑视频剪辑充满好奇的新手,那么你一定想知道哪款软件最适合我们这些初学者。今天,我就来和大家分享一下我使用过…

gin入门教程(10):实现jwt认证

使用 github.com/golang-jwt/jwt 实现 JWT(JSON Web Token)可以有效地进行用户身份验证,这个功能往往在接口前后端分离的应用中经常用到。以下是一个基本的示例,演示如何在 Gin 框架中实现 JWT 认证。 目录结构 /hello-gin │ ├── cmd/ …

医院信息化与智能化系统(10)

医院信息化与智能化系统(10) 这里只描述对应过程,和可能遇到的问题及解决办法以及对应的参考链接,并不会直接每一步详细配置 如果你想通过文字描述或代码画流程图,可以试试PlantUML,告诉GPT你的文件结构,让他给你对应…

详解Pectra升级:如何影响以太坊价值及利益相关者

Pectra很可能是最后几个会直接影响用户和ETH持有者的升级之一。 原文:Galaxy Research;编译:Golem;编辑:郝方舟 出品 | Odaily星球日报(ID:o-daily) 编者按:以太坊 Pectr…

【SpringCloud】 K8s的滚动更新中明明已经下掉旧Pod,还是会把流量分到了不存活的节点

系列文章目录 文章目录 系列文章目录前言一、初步定位问题二、源码解释1.引入库核心问题代码进一步往下看【这块儿算是只是拓展了,问题其实处在上面的代码】Nacos是如何实现的? 如何解决 总结 前言 背景: 使用了SpringCloudGateWay 和 Sprin…

C++学习路线(二十五)

常见错误总结 错误1&#xff1a;对象const问题 #include <iostream>class Man { public:void walk() {std::cout << "I am walking." << std::endl;} };int main() {const Man man;man.walk();return 0; } 原因是Man man是const对象 但是调用了…

大语言模型的Scaling Law【Power Low】

NLP-大语言模型学习系列目录 一、注意力机制基础——RNN,Seq2Seq等基础知识 二、注意力机制【Self-Attention,自注意力模型】 三、Transformer图文详解【Attention is all you need】 四、大语言模型的Scaling Law【Power Low】 文章目录 NLP-大语言模型学习系列目录一、什么是…

Stable Diffusion视频插件Ebsynth Utility安装方法

一、Ebsynth Utility制作视频的优势&#xff1a; 相比其他视频制作插件&#xff0c;Ebsynth Utility生成的视频&#xff0c;画面顺滑无闪烁&#xff0c;对显存要求相对不高。渲染速度也还可以接受。其基本过程为&#xff1a; 1、将参考视频分解为单个帧&#xff0c;并同时生成…

模型训练识别手写数字(二)

模型训练识别手写数字&#xff08;一&#xff09;使用手写数字图像进行模型测试 一、生成手写数字图像 1. 导入所需库 import cv2 import numpy as np import oscv2用于计算机视觉操作。 numpy用于处理数组和图像数据。 os用于文件和目录操作。 2. 初始化画布 canvas np.z…

GitHub下载参考

1.Git下载 Git下载https://blog.csdn.net/mengxiang_/article/details/128193219 注意&#xff1a;根据电脑的系统配置选择合适的版本&#xff0c;我安装的是64.exe的版本 2.Git右键不出现问题&#xff1a; Git右键不出现https://blog.csdn.net/ling1998/article/details/1…

Java项目实战II基于微信小程序的马拉松报名系统(开发文档+数据库+源码)

目录 一、前言 二、技术介绍 三、系统实现 四、文档参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发&#xff0c;CSDN平台Java领域新星创作者&#xff0c;专注于大学生项目实战开发、讲解和毕业答疑辅导。获取源码联系方式请查看文末 一、前言 马拉松运动…

[SWPUCTF 2022 新生赛]py1的write up

开启靶场&#xff0c;下载附件&#xff0c;解压后得到&#xff1a; 双击exe文件&#xff0c;出现弹窗&#xff1a; 问的是异或&#xff0c;写个python文件来计算结果&#xff1a; # 获取用户输入的两个整数 num1 int(input("Enter the first number: ")) num2 int…