MNIST手写数字识别

news2025/1/11 0:43:25

MNIST是一个手写体数字的图片数据集,该数据集由美国国家标准与技术研究所(National Institute of Standards and Technology (NIST))发起整理,其包含 60,000 张训练图像和 10,000 张测试图像,每张图片的尺寸为 28 x 28
在这里插入图片描述

线性回归

我们尝试通过 线性回归模型 识别手写数字,输入的图片是 28 x 28像素,我们可以将其看为 784 个变量,即:
y = a 1 x 1 + a 2 x 2 + . . . + a n x n y = a_1x_1+a_2x_2+...+a_nx_n y=a1x1+a2x2+...+anxn

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf


# 计算梯度,并更新 [a1,a2,...,an],b 值
def gradient(a, b, x_array, y_array, learning_rate):
    a_gradient = tf.zeros([784, 1])
    b_gradient = 0
    length = len(x_array)
    # 计算梯度
    for i in range(0, length):
        x = x_array[i]
        y = y_array[i]
        base_gradient = (2 / length) * (np.dot(x, a) + b - y)[0]
        # print("base_gradient", base_gradient)
        a_gradient += base_gradient * tf.reshape(x, [784, 1])
        b_gradient += base_gradient

    # 更新 a、b 值
    new_a = a - learning_rate * a_gradient
    new_b = b - learning_rate * b_gradient
    return [new_a, new_b]


# 计算损失
def computer_loss(a, b, x_array, y_array):
    length = len(x_array)
    loss = 0
    # 计算梯度
    for i in range(0, length):
        x = x_array[i]
        y = y_array[i]
        loss += (np.dot(x, a) + b - y) ** 2
    loss /= length
    return loss


# 计算准确率
def computer_accuracy(a, b, x_array, y_array):
    accuracy = 0
    length = len(x_array)
    for i in range(0, length):
        x = x_array[i]
        y = np.dot(x, a) + b
        y = round(y[0])
        if y == y_array[i]:
            accuracy += 1
    return accuracy / length


mnist = tf.keras.datasets.mnist
(train_data, train_label), (test_data, test_label) = mnist.load_data()
train_size = len(train_data)
test_size = len(test_data)

# 手写数字图片是 28x28 的 Tensor,需要将其转换为 1x784
train_data_reshape = tf.reshape(train_data, [train_size, 784])
# 将 int 转换为 float,否则会有计算问题
train_data_reshape = tf.cast(train_data_reshape, dtype=tf.dtypes.float32)
print("train_data_reshape shape", np.shape(train_data_reshape))
# 对数据进行归一化处理
train_data_reshape = train_data_reshape / tf.constant(255.0, shape=[train_size, 784])
print("train_data_reshape", train_data_reshape)

test_data_reshape = tf.reshape(test_data, [test_size, 784])
test_data_reshape = tf.cast(test_data_reshape, dtype=tf.dtypes.float32)
train_data_reshape = test_data_reshape / tf.constant(255.0, shape=[test_size, 784])

# 假设 y = a1x1 + a2x2 +...+ anxn +b 且 x shape [1,784],则 a shape 为 [784,1]
a = tf.random.normal([784, 1])
b = 0
loss_list = list()
accuracy_list = list()
for i in range(0, 1000):
    [a, b] = gradient(a, b, train_data_reshape, train_label, 0.01)
    if i % 10 == 0:
        loss = computer_loss(a, b, train_data_reshape, train_label)
        accuracy = computer_accuracy(a, b, test_data_reshape, test_label)
        print("loss = {} accuracy = {}".format(loss, accuracy))
        loss_list.append(loss)
        accuracy_list.append(accuracy)

print("a = {} b = {}".format(a, b))
l1 = plt.plot(loss_list, label="loss")
l2 = plt.plot(accuracy_list, label="accuracy")
plt.legend()
plt.show()

在这里插入图片描述
可以看出损失收敛在10左右,准确率只有15%左右,这是因为该模型存在两个问题:

  • 如果预测的数据是 2.5,那实际值应该是2还是3呢?所以应该通过概率来解决该问题,它需要输出多个结果,例如:1的概率为0.999,2的概率为0.0001,3的概率为0.0001等,最终所有结果的概率综合为1。我们称这样的问题为分类问题
  • 图片像素与数字并非线性关系,而是复杂的非线性关系

非线性分类

多输出问题

对于多个结果我们可以考虑使用矩阵的形式,例如 1x4 阶矩阵,需要输出 2 个结果,则可以进行如下运算:
[ a b c d ] ∗ [ 1 5 2 6 3 7 4 8 ] = [ 1 a + 2 b + 3 c + 4 d 5 a + 6 b + 7 c + 8 d ] {\begin{bmatrix} a&b&c&d\\ \end{bmatrix}} * {\begin{bmatrix} 1&5\\ 2&6\\ 3&7\\ 4&8\\ \end{bmatrix}} = {\begin{bmatrix} 1a+2b+3c+4d&5a+6b+7c+8d\\ \end{bmatrix}} [abcd] 12345678 =[1a+2b+3c+4d5a+6b+7c+8d]
手写数字需要10个结果,即 [10] 矩阵,每列的值代表数字 n 的概率,例如表示1的概率为0.999:
[ 0.999 0.0001 0.0001 0.0001 0.0001 0.0001 0.0001 0.0001 0.0001 0.0001 ] {\begin{bmatrix} 0.999 & 0.0001 & 0.0001& 0.0001& 0.0001& 0.0001& 0.0001& 0.0001& 0.0001& 0.0001 \end{bmatrix}} [0.9990.00010.00010.00010.00010.00010.00010.00010.00010.0001]
因为 x x x [ n , 784 ] [n,784] [n,784] 矩阵,所以应该给 x x x 点乘一个 [ 784 , 10 ] [784,10] [784,10] 矩阵,由此多个输出问题得以解决。

非线性问题

我们需要针对线性模型中增加非线性因子,使其变为非线性,这里采用ReLU函数:
在这里插入图片描述
y = r e l u ( a x + b ) y = relu(a x + b) y=relu(ax+b),其中 a = [ 784 , 10 ] a = [784,10] a=[784,10] y = [ 10 ] y = [10] y=[10]

# 线性回归模型识别手写数字

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

np.set_printoptions(edgeitems=10, linewidth=200)

mnist = tf.keras.datasets.mnist
(train_data, train_label), (test_data, test_label) = mnist.load_data()
train_size = len(train_data)
test_size = len(test_data)

# 将 numpy.ndarray 类型数据转换为 Tensor
train_data = tf.convert_to_tensor(train_data, dtype=tf.dtypes.float32)
# 手写数字图片是 28x28 的 Tensor,需要将其转换为 1x784
train_data = tf.reshape(train_data, [train_size, 784])
# 对数据进行归一化处理
train_data = train_data / 255.
print("train_data", train_data)

# 对 label 进行 one_hot 编码
train_label = tf.convert_to_tensor(train_label, dtype=tf.dtypes.int8)
train_label = tf.one_hot(train_label, 10)

# 沿着第一个维度切片,将 train_data、train_label 转换为 tf.data.Dataset 对象,并按60个合为一个数据集
train_batch = tf.data.Dataset.from_tensor_slices((train_data, train_label)).batch(60)
print("train_batch", train_batch)

# 准备测试集数据
test_data = tf.convert_to_tensor(test_data, dtype=tf.dtypes.float32)
test_data = tf.reshape(test_data, [test_size, 784])
test_data = test_data / 255
test_label = tf.one_hot(test_label, 10)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation='relu')
])
optimizer = tf.optimizers.SGD(learning_rate=0.01)


def computer_acc():
    # 预测测试集结果
    test_out = model.predict(test_data)
    # 将概率最大置1,其他置0
    max_val = tf.reduce_max(test_out, axis=1)
    max_val = tf.reshape(max_val, [-1, 1])
    test_out = tf.where(tf.equal(test_out, max_val), tf.ones_like(test_out), tf.zeros_like(test_out))
    # 降维,判断整行数据是否相等
    acc = tf.reduce_all(tf.equal(test_out, test_label), axis=1)
    return tf.reduce_mean(tf.cast(acc, tf.float32))


loss_list = list()
acc_list = list()
for i in range(0, 1000):
    for (x, y) in train_batch:
        # -1 表示自动推断
        x = tf.reshape(x, (-1, 784))
        with tf.GradientTape() as tape:
            out = model(x)
            loss = tf.reduce_sum(tf.square(out - y) / x.shape[0])
        # 计算梯度
        gradient = tape.gradient(loss, model.trainable_variables)
        # 反向传递
        optimizer.apply_gradients(zip(gradient, model.trainable_variables))
    if i % 10 == 0:
        loss_list.append(loss)
        acc = computer_acc()
        acc_list.append(acc)
        print("i = {} loss = {} acc = {}".format(i, loss, acc))

l1 = plt.plot(loss_list, label="loss")
l2 = plt.plot(acc_list, label="acc")
plt.legend()
plt.show()

在这里插入图片描述
最终准确率收敛在了 84% 左右,原因是增加一个非线性因素可能不够,所以我们需要增加多个,使其可以拟合更复杂的非线性函数:
o u t 1 = r e l u ( a 1 x + b ) out_1 = relu(a_1 x + b) out1=relu(a1x+b),其中 a 1 = [ 784 , 512 ] a_1 = [784,512] a1=[784,512] o u t 1 = [ 512 ] out_1 = [512] out1=[512]
o u t 2 = r e l u ( a 2 o u t 1 + b ) out_2 = relu(a_2 out_1 + b) out2=relu(a2out1+b),其中 a 1 = [ 512 , 256 ] a_1 = [512,256] a1=[512,256] o u t 1 = [ 256 ] out_1 = [256] out1=[256]
o u t 3 = r e l u ( a 3 o u t 2 + b ) out_3 = relu(a_3 out_2 + b) out3=relu(a3out2+b),其中 a 1 = [ 256 , 10 ] a_1 = [256,10] a1=[256,10] o u t 3 = [ 10 ] out_3 = [10] out3=[10]

model = tf.keras.Sequential([
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(10)
])

增加两层网络后,最终准确率收敛在了 98% 左右
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1053628.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

点亮一个LED+LED闪烁+LED流水灯——“51单片机”

各位CSDN的uu们好呀,这是小雅兰的最新专栏噢,最近小雅兰学习了51单片机的知识,所以就想迫不及待地分享出来呢!!!下面,让我们进入51单片机的世界吧!!! 点亮一个…

在线小说阅读系统

在线小说阅读系统: 功能实现 1.一级菜单:登录 注册 退出系统 2.二级菜单:查看小说列表 上传小说 下载小说 在线阅读 返回上级菜单 技术要点 1.面向对象思想 oop思想 2.TCP通信 :Socket通信(这里用TCP,…

【C++】:类和对象(1)

朋友们、伙计们,我们又见面了,本期来给大家解读一下有关C中类和对象的知识点,如果看完之后对你有一定的启发,那么请留下你的三连,祝大家心想事成! C 语 言 专 栏:C语言:从入门到精通…

Docker安装MS SQL Server并使用Navicat远程连接

思维导航 MS SQL Server简介 Microsoft SQL Server(简称SQL Server)是由微软公司开发的关系数据库管理系统,它是一个功能强大、性能卓越的企业级数据库平台,用于存储和处理大型数据集、支持高效查询和分析等操作。SQL Server 支持广泛的应用程序开发接口(API),包括 T-S…

vue ant 两个页面 调用同一个接口 想在 前端的一个 接口传 一个固定的值 ,另外一个不变 ,查询条件默认值加上自己要的就好啦

vue ant 两个页面 调用同一个接口 想在 前端的一个 接口传 一个固定的值 ,另外一个不变 查询条件默认值加上自己要的就好啦

【中秋国庆不断更】OpenHarmony多态样式stateStyles使用场景

Styles和Extend仅仅应用于静态页面的样式复用,stateStyles可以依据组件的内部状态的不同,快速设置不同样式。这就是我们本章要介绍的内容stateStyles(又称为:多态样式)。 概述 stateStyles是属性方法,可以根…

BUUCTF reverse wp 76 - 80

[CISCN2018]2ex 四处游走寻找关键代码 int __fastcall sub_400430(int a1, unsigned int a2, int a3) {unsigned int v3; // $v0int v4; // $v0int v5; // $v0int v6; // $v0unsigned int i; // [sp8h] [8h]unsigned int v9; // [sp8h] [8h]int v10; // [spCh] [Ch]v10 0;for…

【中秋国庆不断更】HarmonyOS对通知类消息的管理与发布通知(上)

一、通知概述 通知简介 应用可以通过通知接口发送通知消息,终端用户可以通过通知栏查看通知内容,也可以点击通知来打开应用。 通知常见的使用场景: 显示接收到的短消息、即时消息等。显示应用的推送消息,如广告、版本更新等。显示…

Godot Identifier “File“ not declared in the current scope.

解决方案: f FileAccess.open(savedir, FileAccess.READ)

牛客网_HJ1_字符串最后一个单词的长度

HJ1_字符串最后一个单词的长度 原题思路代码运行截图收获 原题 字符串最后一个单词的长度 思路 从最后一个字符开始遍历&#xff0c;遇到第一个空格时的长度即为最后一个单词的长度 代码 #include <iostream> #include <string> using namespace std;int main…

Purism 推出注重隐私的 Linux 平板电脑

导读一款昂贵的 Linux 平板电脑&#xff0c;注重安全和隐私。让我们拭目以待。 Purism 是一家日益流行的计算机硬件产品制造商&#xff0c;专门提供配备注重隐私的开源 Linux 发行版的笔记本电脑、台式机和移动设备。 最近&#xff0c;他们发布了一款新产品 Librem 11 平板电…

ssm+vue的图书馆书库管理系统(有报告)。Javaee项目,ssm vue前后端分离项目。

演示视频&#xff1a; ssmvue的图书馆书库管理系统&#xff08;有报告&#xff09;。Javaee项目&#xff0c;ssm vue前后端分离项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#xff09;三层体系结构…

BUUCTF reverse wp 51 - 55

findKey shift f12 找到一个flag{}字符串, 定位到关键函数, F5无效, 大概率是有花指令, 读一下汇编 这里连续push两个byte_428C54很奇怪, nop掉下面那个, 再往上找到函数入口, p设置函数入口, 再F5 LRESULT __stdcall sub_401640(HWND hWndParent, UINT Msg, WPARAM wPara…

顺序表(7.24)

1.线性表 线性表 &#xff08; linear list &#xff09; 是 n 个具有相同特性的数据元素的有限序列。 线性表是一种在实际中广泛使用的数据结构&#xff0c;常见的线性表&#xff1a;顺序表、链表、栈、队列、字符串... 线性表在逻辑上是线性结构&#xff0c;也就说是连续的一…

Docker+K3S搭建集群

本次环境使用的是阿里云资源服务器&#xff0c;Linux版本为Centos&#xff0c;集群需要安装Docker和k3s。 Docker下载&#xff1a;Docker(一) 安装Docker_CV猿码人的博客-CSDN博客 K3S 下载&#xff1a;k3s在线快速安装部署-CSDN博客 一、定制镜像 制作Tomcat镜像&#xff0c…

GeoServer扩展功能之发布矢量瓦片

目录 前言 一、矢量瓦片 VS 栅格瓦片 1、基本对比 2、适量瓦片处理步骤 二、GeoServer矢量瓦片插件配置 1、确定GeoServer版本 2、查找瓦片切片插件 3、下载 并安装插件 三、GeoServer发布矢量瓦片 1、矢量瓦片处理 2、如何进行数据预览 总结 前言 今天是10月1日国庆…

常用数学分布

正态分布&#xff08;高斯分布&#xff09; 若随机变数 X X X 服从一个期望 μ \mu μ&#xff0c;标准差 的正态分布 σ \sigma σ&#xff0c;则记为 X ≈ N ( μ , σ 2 ) X \approx N(\mu,\sigma^2) X≈N(μ,σ2)&#xff0c;其密度函数为&#xff1a; f ( x ) 1 σ …

JAVA 获得特定格式时间

0 背景 我们有时要获取时间&#xff0c;年月日时分秒周几&#xff0c;有时要以特定的格式出现。这时就要借助 SimpleDateFormat 或者 DateTimeFormatter。有时要某个月份有多少天需要借助 Calendar。所以有必要了解一些知识。 1 SimpleDateFormat simpledateFormat 线程不安全…

侯捷 C++ STL标准库和泛型编程 —— 4 分配器 + 5 迭代器

4 分配器 4.1 测试 分配器都是与容器共同使用的&#xff0c;一般分配器参数用默认值即可 list<string, allocator<string>> c1;不建议直接用分配器分配空间&#xff0c;因为其需要在释放内存时也要指明大小 int* p; p allocator<int>().allocate(512,…

图像处理: ImageKit.NET 3.0.10704 Crack

关于 ImageKit.NET3 100% 原生 .NET 图像处理组件。 ImageKit.NET 可让您快速轻松地向 .NET 应用程序添加图像处理功能。从 TWAIN 扫描仪和数码相机检索图像&#xff1b;加载和保存多种格式的图像文件&#xff1b;对图像应用图像滤镜和变换&#xff1b;在显示屏、平移窗口或缩略…