深度学习:使用全连接神经网络FCN实现MNIST手写数字识别

news2025/1/19 2:31:09

1 引言

本项目构建了一个全连接神经网络(FCN),实现对MINST数据集手写数字的识别,没有借助任何深度学习算法库,从原理上理解手写数字识别的全过程,包括反向传播,梯度下降等。

2 全连接神经网络介绍

2.1 什么是全连接神经网络

全连接网络(Fully-Connected Network,简称FCN),即在多层神经网络中,第nN层的每个神经元都分别与第N-1层的神经元相互连接。如下图便是一个简单的全连接网络:

2.2 损失函数

损失函数(loss function)在深度学习领域是用来计算搭建模型预测的输出值和真实值之间的误差,是一种衡量模型与数据吻合程度的算法。损失函数的值越高预测就越错误,损失函数值越低则预测越接近真实值。对每个单独的观测(数据点)计算损失函数。将所有损失函数(loss function)的值取平均值的函数称为代价函数(cost function),更简单的理解就是损失函数是针对单个样本的,而代价函数是针对所有样本的。

  • 损失函数越小越好
  • 计算实际输出与目标之间的差距
  • 为更新输出提供依据(反向传播)

常见的损失函数

(1)均方误差损失(Mean Squared Error,MSE)

均方误差损失MSE,又称L2 Loss,用于计算模型输出y_hat 和目标值y 之差的均方差。一般用在线性回归中,可以理解为最小二乘法。均方差损失是机器学习、深度学习回归任务中最常用的一种损失函数

(2)平均绝对误差(Mean Absolute Error,MAE)

平均绝对误差MAE,又称L1 Loss,是另一种用于回归模型的损失函数。和 MSE 一样,这种度量方法也是在不考虑方向(如果考虑方向,那将被称为平均偏差(Mean Bias Error, MBE),它是残差或误差之和)的情况下衡量误差大小。但和 MSE 的不同之处在于,MAE 需要像线性规划这样更复杂的工具来计算梯度。此外,MAE 对异常值更加稳健,因为它不使用平方。损失范围也是 0 到 ∞。

(3)交叉熵损失函数(Cross Entropy Loss)

交叉熵(Cross Entropy)是Shannon信息论中一个重要概念,主要用于度量两个概率分布间的差异性信息。语言模型的性能通常用交叉熵和复杂度(perplexity)来衡量。交叉熵的意义是用该模型对文本识别的难度,或者从压缩的角度来看,每个词平均要用几个位来编码。Cross Entropy损失函数是分类问题中最常见的损失函数。

2.3 反向传播

误差反向传播(Back-propagation, BP)算法的出现是神经网络发展的重大突破,也是现在众多深度学习训练方法的基础。该方法会计算神经网络中损失函数对各参数的梯度,配合优化方法更新参数,降低损失函数。BP本来只指损失函数对参数的梯度通过网络反向流动的过程,但现在也常被理解成神经网络整个的训练方法由误差传播、参数更新两个环节循环迭代组成。

神经网络的训练过程中,前向传播和反向传播交替进行,前向传播通过训练数据和权重参数计算输出结果;反向传播通过导数链式法则计算损失函数对各参数的梯度,并根据梯度进行参数的更新

 

3 使用FCN实现MNIST手写数字识别

3.1 MINIST数据集介绍

MNIST数据集是美国国家标准与技术研究院收集整理的大型手写数字数据库,包含60,000个示例的训练集以及10,000个示例的测试集。其中的图像的尺寸为28*28。采样数据显示如下:

3.2 FCN识别MINIST数据集代码实现

import torch
from torch import nn
from torchvision import datasets
from torchvision.transforms import transforms
import matplotlib.pyplot as plt
import numpy as np

class MnistNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Sequential(
            # 图片的原尺寸为28*28,转化为784,输入层为784,输出层为256
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Linear(16, 10),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        x = x.view(-1, 28*28*1)
        return self.layer(x)


batchsize = 32
lr = 0.01

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307, ), (0.3081,))])

data_train = datasets.MNIST(root="./data/", transform=transform, train=True, download=True)
data_test = datasets.MNIST(root="./data/", transform=transform, train=False)

train_loader = torch.utils.data.DataLoader(data_train, batch_size=batchsize, shuffle=True)
test_loader = torch.utils.data.DataLoader(data_test, batch_size=batchsize, shuffle=False)

if __name__ == '__main__':
    model = MnistNet()

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.5)

    for i in range(5):
        plt.subplot(1, 5, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.imshow(data_train.data[i], cmap=plt.cm.binary)
    plt.show()

    lepoch = []
    llost = []
    lacc = []
    epochs = 30
    for epoch in range(epochs):
        lost = 0
        count = 0
        for num, (x, y) in enumerate(train_loader, 1):
            y_h = model(x)
            loss = criterion(y_h, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lost += loss.item()
            count += batchsize
        print('epoch:', epoch + 1, 'loss:', lost / count, end=' ')
        lepoch.append(epoch + 1)
        llost.append(lost / count)

        with torch.no_grad():
            acc = 0
            count = 0
            for num, (x, y) in enumerate(test_loader, 1):
                y_h = model(x)
                _, y_h = torch.max(y_h.data, dim=1)
                acc += (y_h == y).sum().item()
                count += x.size(0)
            test_acc = acc / count * 100

        lacc.append(test_acc)
        print('acc:', test_acc)

    plt.plot(lepoch, llost, label='loss')
    plt.plot(lepoch, lacc, label='acc')
    plt.legend()
    plt.show()

3.3 结果输出

经过30个epoch后,在测试集上的准确率达到了97.3%

epoch: 1 loss: 0.0697015597740809 acc: 56.120000000000005
epoch: 2 loss: 0.0542279725531737 acc: 81.2
epoch: 3 loss: 0.051337766939401626 acc: 83.53
epoch: 4 loss: 0.05083678769866626 acc: 84.49
epoch: 5 loss: 0.05052243163983027 acc: 85.09
epoch: 6 loss: 0.05029139596422513 acc: 85.65
epoch: 7 loss: 0.050102355525890985 acc: 86.14
epoch: 8 loss: 0.04994755889574687 acc: 86.02
epoch: 9 loss: 0.0498184863169988 acc: 86.71
epoch: 10 loss: 0.04970114469528198 acc: 86.81
epoch: 11 loss: 0.04792855019172033 acc: 94.86
epoch: 12 loss: 0.047099880089362466 acc: 95.64
epoch: 13 loss: 0.04690476657748222 acc: 96.04
epoch: 14 loss: 0.04677621142864227 acc: 96.32
epoch: 15 loss: 0.046683601369460426 acc: 96.52
epoch: 16 loss: 0.04659009942809741 acc: 96.69
epoch: 17 loss: 0.04652327968676885 acc: 96.72
epoch: 18 loss: 0.04646410925189654 acc: 96.81
epoch: 19 loss: 0.0464125766257445 acc: 96.75
epoch: 20 loss: 0.04636456128358841 acc: 97.07000000000001
epoch: 21 loss: 0.046326734560728076 acc: 96.85000000000001
epoch: 22 loss: 0.04628034559885661 acc: 96.91
epoch: 23 loss: 0.04625135076443354 acc: 97.0
epoch: 24 loss: 0.046217381453514096 acc: 97.14
epoch: 25 loss: 0.046193461724122364 acc: 97.03
epoch: 26 loss: 0.046168098962306975 acc: 97.16
epoch: 27 loss: 0.0461397964378198 acc: 97.27
epoch: 28 loss: 0.0461252645790577 acc: 97.22
epoch: 29 loss: 0.04609716224273046 acc: 97.19
epoch: 30 loss: 0.04608173056840897 acc: 97.3

准确率变化曲线如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/821642.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

android studio 找不到符号类 Canvas 或者 错误: 程序包java.awt不存在

android studio开发提示 解决办法是: import android.graphics.Canvas; import android.graphics.Color; 而不是 //import java.awt.Canvas; //import java.awt.Color;

uniapp android底部弹框

uniapp android底部弹框&#xff0c;带有动画效果 <view class"popup_box"><view class"bottom_more" click"handleClickCancel"><image src"/static/images/rescue/icon_more.png"></image></view><…

华为OD机试真题 JavaScript 实现【小朋友排队】【2023 B卷 100分】,附详细解题思路

目录 一、题目描述二、输入描述三、输出描述四、解题思路五、JavaScript算法源码六、效果展示1、输入2、输出 华为OD机试 2023B卷题库疯狂收录中&#xff0c;刷题点这里 刷的越多&#xff0c;抽中的概率越大&#xff0c;每一题都有详细的答题思路、详细的代码注释、样例测试&am…

触发器实现海豚调度失败企业微信自动告警

原理 触发器监控工作流实例表&#xff0c;当工作流实例表中的状态更新后&#xff0c;针对状态为失败的任务进行企业微信告警。 发送企业微信消息函数 # 必须在pg的主机上线安装requests模块 pip install requests # 以postgres用户登陆psql客户端到etl数据库 psql etl -U po…

Modbus TCP转Profinet网关modbus tcp转以太网

大家好&#xff0c;今天我们来聊一聊如何使用捷米特的Profinet转modbusTCP协议转换网关在博图上进行非透传型配置。 1, 首先&#xff0c;我们需要安装捷米特JM-TCP-PN的GSD文件&#xff0c;并根据现场设备情况配置modbusTCP地址。然后&#xff0c;在博图中添加该GSD文件&#x…

Zoho CRM数据存储在哪里?如何保障数据安全?

随着互联网的发展&#xff0c;在线CRM逐渐成为企业管理客户关系&#xff0c;提高销售效率的首选。然而&#xff0c;很多企业对于在线CRM数据的存储方式并不了解&#xff0c;担心会有数据丢失和泄露的风险。那么&#xff0c;CRM数据存储在哪里&#xff1f;安全是否有保障&#x…

【Jquery大事件时间线】jquery实现大事件时间线(时间轴)的滚动切换效果『附完整源码』

文章目录 写在前面涉及知识点页面效果1、搭建框架1.1 模块搭建1.2 内容填充1.3 时间线的切换 2、完整代码2.1 html源码2.2 CSS源码2.3 js源码 3、完整源码包下载3.1百度网盘3.2 123云盘3.3邮箱留言 总结 写在前面 其实这种大事件记录的web页面也是我们常见的&#xff0c;尤其是…

检查 CPU 的上下文切换

一.什么是cpu上下文切换 CPU 上下文切换是操作系统在多任务环境下管理进程的一项关键任务。在现代计算机系统中&#xff0c;有多个进程同时运行&#xff0c;每个进程都需要一定的 CPU 时间来执行其任务。由于 CPU 在某一时刻只能执行一个进程的指令&#xff0c;因此操作系统需…

【【STM32学习-3】】

STM32学习-3 下面是对c语言的稍微复习 这个是我们设置好的文件 以后拖出去用就可以了 这里加入关于指针的感想 关于指针数组和数组指针的想法 常规的东西是int a10; int * p&a; &#xff08;p指向了a元素&#xff0c;意思是p等于a的地址 类型是int*&#xff09;就是 整型指…

“ \r “导致print打印被覆盖

这里写自定义目录标题 写在最前面1." \r " 回车符一些有趣的小功能倒计时加载中&#xff08;转圈&#xff09;进度条删除功能 强行不换行(1) python2中可以在print语句的末尾加上逗号&#xff08;2&#xff09;在python3里print是一个独立函数&#xff0c;可以通过修…

笔记02:CUDA编程模型

CUDA是一种通用的并行计算平台和编程模型&#xff0c;是在C语言基础上扩展的。 一、CUDA编程模型概述 1. CUDA编程结构 在一个异构环境中包含多个CPU和GPU&#xff0c;每个GPU和CPU的内存都由一条PCI-e总线分隔开&#xff0c;需要注意区分 &#xff08;1&#xff09;主机&a…

湖仓一体概念快问快答

概念篇 问题一 “湖仓一体”是什么&#xff1f; “湖仓一体”是一种新的架构模式&#xff0c;湖仓一体是将数据湖的灵活性和数仓的易用性、规范性、高性能结合起来的融合架构&#xff0c;无数据孤岛。湖仓一体数据存储在数据湖低成本的存储架构之上&am…

蓝桥云课ROS机器人旧版实验报告-07外设

项目名称 实验七 ROS[Kinetic/Melodic/Noetic]外设 成绩 内容&#xff1a;使用游戏手柄、使用RGBD传感器&#xff0c;ROS[Kinetic/Melodic/Noetic]摄像头驱动、ROS[Kinetic/Melodic/Noetic]与OpenCV库、标定摄像头、视觉里程计&#xff0c;点云库、可视化点云、滤波和缩…

嵌入式系统工程师怎样才能不落伍

不断增加的复杂性和异质化正在衍生出一些新的方法&#xff0c;能够避免在设计周期结束时出现意外。 在一个系统中&#xff0c;硬件的表现是否优秀取决于运行在其上的软件。随着系统复杂性的增加&#xff0c;总是软件在拖后腿。 缩小硬件和软件差距的方法是不断改进软件开发的方…

【Java】多医院、多诊所、多机构SaaS模式云HIS信息管理系统源码

云HIS&#xff0c;一款基于云计算和大数据技术的智慧医院云平台&#xff0c;为医疗机构提供了一种全新的信息化解决方案&#xff0c;旨在实现数据安全、用户满意度和成本效益的最佳平衡。 基于云计算技术的B/S架构的HIS系统&#xff0c;为基层医疗机构提供标准化的、信息化的、…

攻击数亿个账户,黑客利用OAuth2.0疯狂作恶

一、OAuth协议介绍 OAuth是一种标准授权协议&#xff0c;它允许用户在不需要向第三方网站或应用提供密码的情况下向第三方网站或应用授予对存储于其他网站或应用上的信息的 委托访问 权限。OAuth通过访问令牌来实现这一功能。 1.发展历史 OAuth协议始于2006年Twitter公司Ope…

Python爬虫遇到URL错误解决办法大全

在进行Python爬虫任务时&#xff0c;遇到URL错误是常见的问题之一。一个错误的URL链接可能导致爬虫无法访问所需的网页或资源。为了帮助您解决这个问题&#xff0c;本文将提供一些实用的解决方法&#xff0c;并给出相关代码示例&#xff0c;希望对您的爬虫任务有所帮助。 一、…

mysql进阶-触发器

在实际开发中&#xff0c;我们经常会遇到这样的情况&#xff1a;有 2 个或者多个相互关联的表&#xff0c;如 商品信息 和 库存信息 分别存放在 2 个不同的数据表中&#xff0c;我们在添加一条新商品记录的时候&#xff0c;为了保证数据的完整性&#xff0c;必须同时 在库存表中…

牛客网Verilog刷题——VL41

牛客网Verilog刷题——VL41 题目答案 题目 请设计一个可以实现任意小数分频的时钟分频器&#xff0c;比如说8.7分频的时钟信号&#xff0c;注意rst为低电平复位。提示&#xff1a;其实本质上是一个简单的数学问题&#xff0c;即如何使用最小公倍数得到时钟周期的分别频比。设小…

23种设计模式详解与示例代码(详解附DEMO)

设计模式在Java中的应用与实现 &#x1f680;&#x1f680;&#x1f680;1.创建型模式1. 工厂方法模式&#xff08;Factory Pattern&#xff09;2.抽象工厂模式&#xff08;Abstract Factory Pattern&#xff09;3. 单例模式&#xff08;Singleton Pattern&#xff09;4.原型模…