深度学习神经网络实战:多层感知机,手写数字识别

news2024/9/22 22:37:03

目的

利用tensorflow.js训练模型,搭建神经网络模型,完成手写数字识别

设计

简单三层神经网络

  • 输入层
    28*28个神经原,代表每一张手写数字图片的灰度
  • 隐藏层
    100个神经原
  • 输出层
    -10个神经原,分别代表10个数字

代码

// 导入 TensorFlow.js 库
import tf from "@tensorflow/tfjs";
import * as tfjsnode from "@tensorflow/tfjs-node";
import * as tfvis from "@tensorflow/tfjs-vis";
import fs from "fs";
import plot from "nodeplotlib";
// 定义模型
const model = tf.sequential();

// 添加输入层
model.add(
  tf.layers.dense({ units: 64, inputShape: [784], activation: "relu" })
);

// 添加隐藏层
model.add(tf.layers.dense({ units: 100, activation: "relu" }));

// 添加输出层
model.add(tf.layers.dense({ units: 10, activation: "softmax" }));

// 编译模型
model.compile({
  optimizer: "sgd",
  loss: "categoricalCrossentropy",
  metrics: ["accuracy"],
});
const trainDataLen = 3000;
const testDataLen = 2000;

// 加载 MNIST 数据集
import pkg from "mnist";
const { set: Dataset } = pkg;
const set = Dataset(trainDataLen, testDataLen);
const trainingSet = set.training;
const testSet = set.test;

const trainXs = [];
const testXs = [];

const trainLabels = [];
const testLabels = [];

for (let i = 0; i < trainingSet.length; i++) {
  trainXs.push(trainingSet[i].input);
  trainLabels.push(trainingSet[i].output.indexOf(1));
}

for (let i = 0; i < testSet.length; i++) {
  testXs.push(testSet[i].input);
  testLabels.push(testSet[i].output.indexOf(1));
}

// 准备数据
const trainXsTensor = tf.tensor(trainXs, [trainDataLen, 784]);
const trainYsOneHot = tf.oneHot(trainLabels, 10);

//记录每轮模型训练中的损失和精度,为了绘制曲线图
var accPlot = [];
var lossPlot = [];

// 模型训练
model
  .fit(trainXsTensor, trainYsOneHot, {
    batchSize: 64,
    epochs: 100,
    validationSplit: 0.2,
    callbacks: {
      onEpochBegin: (epoch) => console.log(`Epoch ${epoch + 1} started...`),
      onEpochEnd: async (epoch, logs) => {
        console.log(
          `Epoch ${epoch + 1} completed. Loss: ${logs.loss.toFixed(
            3
          )}, Accuracy: ${logs.acc.toFixed(3)}`
        );
        //记录loss和acc,绘制曲线图
        accPlot.push(logs.acc.toFixed(3));
        lossPlot.push(logs.loss.toFixed(3));

        await tf.nextFrame(); // 防止阻塞
      },
      onBatchEnd: async (batch, logs) => {
        console.log(
          `Batch ${batch} completed. Loss: ${logs.loss.toFixed(
            3
          )}, Accuracy: ${logs.acc.toFixed(3)}`
        );
        await tf.nextFrame(); // 防止阻塞
      },
    },
  })
  .then((history) => {
    console.log("Training completed!", history);
    //绘制模型训练过程中的损失函数和模型精度曲线变化
    const epochs = Array.from({ length: lossPlot.length }, (_, i) => i + 1);
    plot.plot(
      [
        { x: epochs, y: lossPlot, name: "Loss" },
        { x: epochs, y: accPlot, name: "Accuracy" },
      ],
      {
        filename: "loss_acc.png",
      }
    );

    //模型评估
    const testXsTensor = tf.tensor(testXs, [testDataLen, 784]);
    const testYsOneHot = tf.oneHot(testLabels, 10);

    const result = model.evaluate(testXsTensor, testYsOneHot);
    const testLoss = result[0].dataSync()[0];
    const testAccuracy = result[1].dataSync()[0];

    console.log(`Test loss: ${testLoss.toFixed(3)}`);
    console.log(`Test accuracy: ${testAccuracy.toFixed(3)}`);
    //保存模型
    model.save("file://./my-model").then(() => {
      console.log("Model saved!");
    });
  });

package.json

{
  "name": "neural_network",
  "version": "1.0.0",
  "description": "",
  "type": "module",
  "main": "mlpTest.js",
  "scripts": {
    "test": "echo \"Error: no test specified\" && exit 1",
  },
  "author": "",
  "license": "ISC",
  "dependencies": {
    "@tensorflow/tfjs": "^4.17.0",
    "@tensorflow/tfjs-node": "^4.17.0",
    "@tensorflow/tfjs-vis": "^1.0.0",
    "mnist": "^1.1.0",
    "nodeplotlib": "^0.7.7"
  },
  "devDependencies": {
    "@babel/core": "^7.0.0",
    "@babel/preset-env": "^7.0.0",
    "babel-loader": "^8.0.0",
    "webpack": "^5.0.0",
    "webpack-cli": "^4.0.0"
  }
}

模型结果

模型精度

损失函数与模型精度变化

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1464080.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SQL Server —— While语句循环

一&#xff1a;简介 while 循环是有条件的循环控制语句。满足条件后&#xff0c;再执行循环体中的SQL语句。 while: break, 如果有多条语句可以在while后面添加begin-end。关于while的语法 while(条件) -- begin -- 语句1 -- 语句2 -- break 根据情况是否添加break -- end 二…

LeetCode LCR 055.二叉搜索树迭代器

实现一个二叉搜索树迭代器类BSTIterator &#xff0c;表示一个按中序遍历二叉搜索树&#xff08;BST&#xff09;的迭代器&#xff1a; BSTIterator(TreeNode root) 初始化 BSTIterator 类的一个对象。BST 的根节点 root 会作为构造函数的一部分给出。指针应初始化为一个不存在…

音视频技术-声反馈啸叫的产生与消除

目录 1.均衡调节: 2.移频法: 3.移相法: 4.比较法: 在扩音系统中,产生啸叫危害很大,一方面影响会议、演出等活动的正常进行,另一方面严重的啸叫会导致音响设备的损坏。 “啸叫”是“声反馈”的俗称,形成的机制复杂,消除的手段多样,专业调音师也对

RSA之前端加密后端解密

RSA之前端加密后端解密 RSA加密解密方式有&#xff1a; &#xff08;1&#xff09;公钥加密&#xff0c;私钥解密&#xff1b; &#xff08;2&#xff09;私钥加密&#xff0c;公钥解密&#xff1b; 此文章中以下我使用的是前端公钥加密&#xff0c;后端私钥解密&#xff1b; …

Springboot集成Druid实现监控功能

Druid是阿里巴巴开发的号称为监控而生的数据库连接池&#xff0c;在功能、性能、扩展性方面&#xff0c;都超过其他数据库连接池&#xff0c;包括DBCP、C3P0、BoneCP、Proxool、JBoss DataSource等等等&#xff0c;秒杀一切。Druid可以很好的监控DB池连接和SQL的执行情况&#…

信息安全工程师 软考回顾(一)

&#x1f433;概述 图源&#xff1a;文心一言 信息安全证书已经考了一年有余&#xff0c;尽管我目前没有从业安全的打算&#xff0c;况且自己的实践能力与从业标准依然有所差距&#xff0c;但其中的内容也值得再温习一遍~&#x1f95d;&#x1f95d; 另外&#xff0c;如果你对…

妨碍有效沟通的5种认知谬误

在沟通过程中&#xff0c;某些认知谬误会让我们看不到真实的信息&#xff0c;做出错误的沟通行为&#xff0c;既伤害到自己和别人&#xff0c;又无法达成真正的沟通意图。本文介绍了5种认知谬误&#xff0c;如果在沟通过程中能够有效识别并克服这些谬误&#xff0c;就可以达成更…

【MySQL系列 03】事务隔离:为什么你改了我还看不见?

事务 今天的文章里&#xff0c;我会以 InnoDB 为例&#xff0c;剖析 MySQL 在事务支持方面的特定实现&#xff0c;并基于原理给出相应的实践建议&#xff0c;希望这些案例能加深你对 MySQL 事务原理的理解。 说到事务&#xff0c;银行应用是解释事务必要性的一个经典例子&…

STL - hash

1、unordered系列关联式容器 在C98中&#xff0c;STL提供了底层为红黑树结构的一系列关联式容器&#xff0c;在查询时效率可达到O()&#xff0c;即最差情况下需要比较红黑树的高度次&#xff0c;当树中的节点非常多时&#xff0c;查询效率也不理想。最好 的查询是&#xff0c;进…

Java面试题之分布式/微服务篇

经济依旧不景气啊&#xff0c;如此大环境下Java还是这么卷&#xff0c;又是一年一次的金三银四。 兄弟们&#xff0c;你准备好了吗&#xff1f;冲冲冲&#xff01;欧里给&#xff01; 分布式/微服务相关面试题解 题一&#xff1a;CAP理论&#xff0c;BASE理论题二&#xff1a;…

智慧餐饮系统架构的设计与实现

随着科技的不断发展&#xff0c;智慧餐饮系统在餐饮行业中扮演着越来越重要的角色。智慧餐饮系统整合了信息技术&#xff0c;以提高餐饮企业的管理效率、客户服务质量和市场竞争力。本文将探讨智慧餐饮系统架构的设计与实现&#xff0c;并探讨其在餐饮行业中的应用前景。 架构…

Spring框架@Autowired注解进行字段时,使用父类类型接收子类变量,可以注入成功吗?(@Autowired源码跟踪)

一、 前言 平常我们在使用spring框架开发项目过程中&#xff0c;会使用Autowired注解进行属性依赖注入&#xff0c;一般我们都是声明接口类型来接收接口实现变量&#xff0c;那么使用父类类型接收子类变量&#xff0c;可以注入成功吗&#xff1f;答案是肯定可以的&#xff01;…

springboot访问webapp下的jsp页面

一&#xff0c;项目结构。 这是我的项目结构&#xff0c;jsp页面放在WEB-INF下的page目录下面。 二&#xff0c;file--->Project Structure,确保这两个地方都是正确的&#xff0c;确保Source Roots下面有webapp这个目录&#xff08;正常来说&#xff0c;应该本来就有&#…

职场数据汇总,数据汇报,动态大屏可视化制作与数据分析

大屏可视化模板。 动态大屏可视化的制作方法与详细操作步骤 如下操作&#xff1a; AIGC ChatGPT 职场案例 AI 绘画 与 短视频制作 PowerBI 商业智能 68集 Mysql 8.0 54集 Oracle 21C 142集 Office 2021实战应用 Python 数据分析实战&#xff0c; ETL Informatica 数据仓库案…

优化设备维修流程:易点易动设备管理系统的设备维修管理策略

在现代企业中&#xff0c;设备是生产运营的核心要素之一。然而&#xff0c;设备故障和维修常常给企业带来诸多困扰和成本。为了提高设备维修的效率和质量&#xff0c;许多企业开始采用先进的设备管理系统。在这方面&#xff0c;易点易动设备管理系统以其卓越的设备维修管理策略…

SpringBoot线上打包

背景&#xff1a; 1.我们打包时其实需要很多资源打到jar包之外&#xff0c;这样子修改了配置后&#xff0c;就可以生效了。 2.包的命名: 以mj为例子&#xff1a; 业务层&#xff1a; com.jn.mj // 这个是这个工程的总包名 com.jn.mj.gateway // web服集群 c…

【目标航迹管理(1)】基于d-s证据理论信息融合的多核目标跟踪方法

1 引言&#xff1a;从航机起始方法开始 我们为什么会有这个议题&#xff1f;因为航机起始方法。 处理目标航迹起始的方法主要分为两大类&#xff1a;批处理和序贯。 在杂波密度比较高的环境下&#xff0c;比如有红外卫星或地面雷达监视区域&#xff0c;则选用批处理方法&…

记录解决uniapp使用uview-plus在vue3+vite+ts项目中打包后样式不能显示问题

一、背景 从 vue2uview1 升级到 vue3vitetsuview-plus ,uview组件样式打包后不显示&#xff0c;升级前uview 组件是可以正常显示&#xff0c;升级后本地运行是可以正常显示&#xff0c;但是打包发布成H5后uview的组件无法正常显示&#xff0c;其他uniapp自己的组件可以正常显示…

jetson是什么?

jetson是NVIDIA推出的一系列嵌入式计算板&#xff0c;用于边缘计算。这些板卡专门设计用于加速机器学习和计算机视觉等人工智能应用&#xff0c;它们是低功耗设备&#xff0c;非常适合在无人机、机器人、智能相机和其他自动化设备中使用。 Jetson板卡包括多种型号&#xff0c;…

消息中间件之RocketMQ源码分析(十二)

Namesrv启动流程 Broker启动流程 BrokerStartup.java类主要负责为真正的启动过程做准备&#xff0c;解析脚本传过来的参数&#xff0c;初始化Broker配置&#xff0c;创建BrokerController实例等工作。BrokerController.java类是Broker的掌控者&#xff0c;它管理和控制Broker的…