在浏览器中运行 TensorFlow.js 来训练模型并给出预测结果(Iris 数据集)

news2024/9/22 19:24:22

文章目录

  • 开发环境
  • 构建第一个 TensorFlow.js 模型
  • 构建鸢尾花数据集分类器
  • References


在 《TensorFlow Lite 是什么?用 TensorFlow Lite 来转换模型(附代码)》中我们已经介绍了可以帮助 TensorFlow 模型在移动设备以及嵌入式设备中运行的 TensorFlow Lite,TensorFlow 生态系统中还包括 TensorFlow.js,它可以帮助我们使用现成的 JavaScript 模型或转换 Python TensorFlow 模型以在浏览器中或 Node.js 下运行。

下面这张图总结了整个 TensorFlow 的生态系统:

在这里插入图片描述

和 TensorFlow Lite 不同的是,TensorFlow.js 还可以用来训练模型。它可以让我们在 JavaScript 中使用类似 keras 的代码语法,非常友好。

开发环境

有能力的朋友可以在任何 web/JavaScript 开发环境下来进行尝试,我们这里直接使用 brackets 官网给出的线上代码编辑器 Phoenix (进入官网后就会自动弹出提示)来进行演示。

进入 Phoenix 后会显示如下画面:

在这里插入图片描述

我们直接将 index.html 文件的内容清空,并先加入以下的大框架:

<html>
<head></head>
<body>
	<h1>First HTML Page</h1>
</body>
</html>

构建第一个 TensorFlow.js 模型

<head> 以及 <body> 标签之间,我们添加下面的 script 标签来指定 TensorFlow.js 库的位置:

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>

之后,我们在第一个 script 标签后添加第二个 script 标签,里面要定义我们的模型,语法和 python 非常相似,但要记得在结尾添加分号:

    <script lang="js">
        const model = tf.sequential();
        model.add(tf.layers.dense({units: 1, inputShape: [1]}));
        model.compile({loss: 'meanSquaredError', optimizer: 'sgd'}); 

接下来我们在 script 中添加我们的输入输出数据。JavaScript 里当然是没有 Numpy 数组的,所以我们使用 tf.tensor2d 来替代,但要注意在数据数组后还有第二个数组来指明数据的形状:

        const xs = tf.tensor2d([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], [6, 1]);
        const ys = tf.tensor2d([-3.0, -1.0, 2.0, 3.0, 5.0, 7.0], [6, 1]);

下一步,我们加入训练函数:

        async function doTraining(model){
            const history = 
                  await model.fit(xs, ys,
                                  {epochs: 500,
                                   callbacks: {
                                       onEpochEnd: async(epoch, logs) =>{
                                           console.log("Epoch:"
                                                       + epoch
                                                       + " Loss:"
                                                       + logs.loss);
                                       }
                                   }});
        }

训练将会花费一段时间,所以我们最好将它设定为一个异步函数。然后我们等待 model.fit 这个异步方法执行完成。我们传入了 epochs 参数,并指定了一个回调以在每个 epoch 结束后报告训练损失。

最后我们需要做的就是调用 doTraining 方法,将模型传入其中:

        doTraining(model).then(() => {
            alert(model.predict(tf.tensor2d([10], [1, 1])));
        });

完整的代码下面给出:

<html>
<head></head>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
    <script lang="js">
        async function doTraining(model){
            const history = 
                  await model.fit(xs, ys,
                                  {epochs: 500,
                                   callbacks: {
                                       onEpochEnd: async(epoch, logs) =>{
                                           console.log("Epoch:"
                                                       + epoch
                                                       + " Loss:"
                                                       + logs.loss);
                                       }
                                   }});
        }
        
        const model = tf.sequential();
        model.add(tf.layers.dense({units: 1, inputShape: [1]}));
        model.compile({loss: 'meanSquaredError', optimizer: 'sgd'}); 
        model.summary();
        
        const xs = tf.tensor2d([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], [6, 1]);
        const ys = tf.tensor2d([-3.0, -1.0, 2.0, 3.0, 5.0, 7.0], [6, 1]);
        
        
        doTraining(model).then(() => {
            alert(model.predict(tf.tensor2d([10], [1, 1])));
        });
    </script>>
<body>
    <h1>First HTML Page</h1>
</body>
</html>

点击 File -> Save File,代码会自动运行,如果已经保存,可以直接点击预览页面左上方的刷新按钮,稍等几秒(模型训练),会弹出以下对话框:

在这里插入图片描述

这就是输入数据为 [10] 时模型给出的预测结果!如果我们想查看模型每个 epoch 之后打印的训练损失,直接按下快捷键 Ctrl-Shift-I,并在弹出的面板上方选择 Console,就会有如下结果:

在这里插入图片描述


下面我们训练一个稍微复杂点的模型。


构建鸢尾花数据集分类器

鸢尾花数据集(.csv)共有 150 条数据,每条数据有 4 个特征(sepal length、sepal width、petal length、petal width),对应三种鸢尾花(setosa、versicolor、virginical)。鸢尾花数据集很容易找到,也可以从我这里下载:Iris 鸢尾花数据集(.csv 格式)。

通过常规的机器学习方法,我们可以对数据集做一些可视化,加深认识:

import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt
df = pd.read_csv('../input/iris-dataset/iris.csv')
df.head(5)
"""
   sepal_length  sepal_width  petal_length  petal_width species
0           5.1          3.5           1.4          0.2  setosa
1           4.9          3.0           1.4          0.2  setosa
2           4.7          3.2           1.3          0.2  setosa
3           4.6          3.1           1.5          0.2  setosa
4           5.0          3.6           1.4          0.2  setosa
"""

我们可以通过 sns.pairplot() 画出两两特征之间的关系,且用种类进行划分:

sns.pairplot(df, kind = 'scatter', hue = 'species')
plt.show()

在这里插入图片描述
对角线上为每个种类在某个特征上的分布图,非对角线上则是两个特征选取不同值时对应的鸢尾花种类。

下面我们就开始在 Phoenix 中进行训练吧!

我们点击左上角的新建项目,在本地选择路径,创建新项目,然后将我们下载的 iris 数据集拖入我们项目保存的路径。
在这里插入图片描述

和之前一样,我们先添加以下大框架:

<html>
<head></head>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
<body>
    <h1>Iris Classifier</h1>
</body>
</html>

我们可以使用 TensorFlow.js 的 tf.data.csv 来加载 CSV 文件,且可以通过它来指定标签对应的列:

    <script lang="js">
    
        async function run(){
            const csvUrl = 'iris.csv';
            const trainingData = tf.data.csv(csvUrl, {
                columnConfigs: {
                    species: {
                        isLabel: true
                    }
                }
            });
        }
    
    </script>

species 对应的是种类名称的字符串,我们需要先将它转换为数值。我们这里使用独热编码来转换标签:

        const convertedData = trainingData.map(({xs, ys}) => {
            const labels = [
                ys.species == 'setosa' ? 1 : 0,
                ys.species == 'virginica' ? 1: 0,
                ys.species == 'versicolor' ? 1 : 0
            ]
            return {xs: Object.values(xs), ys: Object.values(labels)};
        }).batch(10);

上述代码会将 ‘setosa’ 编码为 [1, 0, 0],将 ‘virginica’ 编码为 [0, 1, 0],而将 ‘versicolor’ 编码为 [0, 0, 1],并返回和之前一样的数据集,除了 species 列的字符串已经被编码为独热向量。

下面我们定义并编译模型,输入层形状为输入特征数(列数减 1),输出层有 3 个神经元:

        const numOfFeatures = (await trainingData.columnNames()).length - 1;
        
        const model = tf.sequential();
        model.add(tf.layers.dense({inputShape: [numOfFeatures],
                                   activation: "sigmoid", units: 5}));
        
        model.add(tf.layers.dense({activation: "softmax", units: 3}));
        
        model.compile({loss: "categoricalCrossentropy",
                       optimizer: tf.train.adam(0.06)});

和之前不同,我们的数据是以数据集的形式组织的,所以训练时我们要使用 fitDataset 方法:

        await model.fitDataset(convertedData,
                               {epochs:100,
                                callbacks:{
                                    onEpochEnd: async(epoch, logs) =>{
                                        console.log("Epoch: " + epoch + " Loss: " + logs.loss);
                                    }
                                }});

如果要测试模型,我们可以使用之前用到的 tensor2d 来创建一个输入数据:

const testVal = tf.tensor2d([4.4, 2.9, 1.4, 0.2], [1, 4]);
alert(model.predict(testVal));

我们将完整代码给出:

<html>
<head></head>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
    
    <script lang="js">
    
        async function run(){
            const csvUrl = 'iris.csv';
            const trainingData = tf.data.csv(csvUrl, {
                columnConfigs: {
                    species: {
                        isLabel: true
                    }
                }
            });
            
            const convertedData = trainingData.map(({xs, ys}) => {
                const labels = [
                    ys.species == 'setosa' ? 1 : 0,
                    ys.species == 'virginica' ? 1: 0,
                    ys.species == 'versicolor' ? 1 : 0
                ]
                return {xs: Object.values(xs), ys: Object.values(labels)};
            }).batch(10);
        
            const numOfFeatures = (await trainingData.columnNames()).length - 1;
        
            const model = tf.sequential();
            model.add(tf.layers.dense({inputShape: [numOfFeatures],
                                       activation: "sigmoid", units: 5}));
        
            model.add(tf.layers.dense({activation: "softmax", units: 3}));
        
            model.compile({loss: "categoricalCrossentropy",
                           optimizer: tf.train.adam(0.06)});
        
            await model.fitDataset(convertedData,
                                   {epochs:100,
                                    callbacks:{
                                        onEpochEnd: async(epoch, logs) =>{
                                            console.log("Epoch: " + epoch + " Loss: " + logs.loss);
                                    }
                                }});
            const testVal = tf.tensor2d([4.4, 2.9, 1.4, 0.2], [1, 4]);
            alert(model.predict(testVal));
        
        }
        
        run();
        
    </script>
    
<body>
    <h1>Iris Classifier</h1>
</body>
</html>

运行之后,会弹出如下结果:

在这里插入图片描述
我们可以对结果进一步优化,让其显示预测的具体种类:

const testVal = tf.tensor2d([4.4, 2.9, 1.4, 0.2], [1, 4]);
const prediction = model.predict(testVal);
const pIndex = tf.argMax(prediction, axis=1).dataSync();

const classNames = ["Setosa", "Virginica", "Versicolor"];
alert(classNames[pIndex]);

在这里插入图片描述

References

AI and Machine Learning for Coders by Laurence Moroney.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/57677.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

YMTC X3 NAND 232L 终露真容,全球领先|国产芯之光

上一篇文章&#xff08;芯片级解密YMTC NAND Xtacking 3.0技术&#xff09;&#xff0c;我们结合TechInsights获取芯片级信息梳理了国产NAND芯片厂商YMTC的技术演进之路&#xff0c;从2016公司成立&#xff0c;2018年发布Xtacking 1.0 NAND架构&#xff0c;2019年发布Xtacking …

Kotlin高仿微信-第58篇-开通VIP

Kotlin高仿微信-项目实践58篇详细讲解了各个功能点&#xff0c;包括&#xff1a;注册、登录、主页、单聊(文本、表情、语音、图片、小视频、视频通话、语音通话、红包、转账)、群聊、个人信息、朋友圈、支付服务、扫一扫、搜索好友、添加好友、开通VIP等众多功能。 Kotlin高仿…

[附源码]计算机毕业设计springboot疫情网课管理系统

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

微信小程序| 做一款多人实时线上的五指棋联机游戏

&#x1f4cc;个人主页&#xff1a;个人主页 ​&#x1f9c0; 推荐专栏&#xff1a;小程序开发成神之路 --【这是一个为想要入门和进阶小程序开发专门开启的精品专栏&#xff01;从个人到商业的全套开发教程&#xff0c;实打实的干货分享&#xff0c;确定不来看看&#xff1f; …

[附源码]计算机毕业设计新能源汽车租赁Springboot程序

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

[附源码]计算机毕业设计疫情物资管理系统Springboot程序

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

微信支付商户平台-配置密钥/API安全教程

我们在做小程序获取微信开发时&#xff0c;难免会用到微信支付&#xff0c;我们做微信支付时&#xff0c;商户id和密匙是必不可少的。商户id很容易就能获取到。但是这个密匙的配置就相对而言麻烦了一点。今天就来教大家如何配置位置支付的密匙。 先我们要去注册微信支付的账号…

Lattice库联合ModelSim仿真FIFO

Lattice联合ModelSim仿真FIFO前言一、添加IP二、库文件添加&#xff08;一&#xff09;方式一&#xff1a;添加器件库到ModelSim&#xff08;二&#xff09;方法二&#xff1a;直接添加器件库到Libray,和tb.v在同一个目录下仿真三、仿真&#xff08;一&#xff09;仿真文件&…

JAVA社区疫情防控系统毕业设计,社区疫情防控管理系统设计与实现,毕设作品参考

功能清单 【后台管理员功能】 关于我们设置&#xff1a;设置学校简介、联系我们、加入我们、法律声明、学校详情 广告管理&#xff1a;设置小程序首页轮播图广告和链接 留言列表&#xff1a;所有用户留言信息列表&#xff0c;支持删除 会员列表&#xff1a;查看所有注册会员信…

b站黑马JavaScript的Ajax案例代码——新闻列表案例

目录 目标效果&#xff1a; 重点原理&#xff1a; 1.js中art-template标准语法的循环输出 2.js中split方法——转换字符串为数组 3.js中art-template标准语法的过滤器 4.js中Date内置对象——getFullYear() 5.js中Date内置对象——getMonth() 6.js中Date内置对象——ge…

简单认识一下HotSpot 垃圾收集器

前言 HotSpot 虚拟机提供了多种垃圾收集器&#xff0c;每种收集器都有各自的特点&#xff0c;虽然我们要对各个收集器进行比较&#xff0c;但并非为了挑选出一个最好的收集器。我们选择的只是对具体应用最合适的收集器。 新生代垃圾收集器 Serial 垃圾收集器&#xff08;单线…

java.lang.ClassNotFoundException: com.mysql.cj.jdbc.Driver解决方案

&#x1f31f;问题解析 首先&#xff0c;此报错会出现在两种情况&#xff0c;并且有各自的解决方法。 如果在Java程序中报错&#xff0c;那么我们就参考方法1&#xff08;单Java程序&#xff09;&#xff1a; 如果你是在Tomcat中报错&#xff0c;那么我们可以参考方法2&#…

[附源码]JAVA毕业设计交通事故档案管理系统(系统+LW)

[附源码]JAVA毕业设计交通事故档案管理系统&#xff08;系统LW&#xff09; 目运行 环境项配置&#xff1a; Jdk1.8 Tomcat8.5 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目…

微信小程序实现微信支付的相关操作设置

本文不涉及相关API的实现&#xff0c;旨在记录实现微信支付需要在微信公众平台和微信支付的商户平台需要进行的操作。 1.首先需要用户申请了微信小程序和入驻微信商户平台 2.获取小程序的appid 设置AppSecre小程序密钥 3.微信支付获取商户号&#xff0c;在认证的时候设置操…

基于JavaSwing的员工工资管理系统

开发环境 eclipsejdk1.8mysql5.7 系统简介 本项目是主要功能有员工信息管理&#xff0c;部门信息管理&#xff0c;员工工资设定&#xff0c;系统设置等&#xff0c;员工不需要登录系统&#xff0c;可以直接查询自己的工资&#xff0c;具体项目操作及项目结构请看演示视频&am…

架构解析:Dubbo3 应用级服务发现如何应对双 11 百万集群实例

继业务全面上云后&#xff0c;今年双 11&#xff0c;阿里微服务技术栈全面迁移到以 Dubbo3 为代表的云上开源标准中间件体系。在业务上&#xff0c;基于 Dubbo3 首次实现了关键业务不停推、不降级的全面用户体验提升&#xff0c;从技术上&#xff0c;大幅提高研发与运维效率的同…

【POJ No. 1019】数字序列 Number Sequence

【POJ No. 1019】数字序列 Number Sequence 北大OJ 题目地址 【题意】 给出单个正整数i &#xff0c;编写程序以找到位于数字组S 1 , S 2 , …, Sk 序列中第i 位上的数字。每个组Sk 都由一系列正整数组成&#xff0c;范围为1&#xff5e;k &#xff0c;一个接一个地写入。 序…

Sass扫码点餐源码 单门店多门店餐饮连锁扫码点餐外卖自提系统源码

智慧餐厅扫码点餐小程序系统源码 1. 开发语言&#xff1a;JAVA 2. 数据库&#xff1a;MySQL 3. 原生小程序 4. Sass 模式 5. 带调试视频 本套扫码点餐小程序系统支持多店铺&#xff0c;支持外卖&#xff0c;堂食&#xff0c;扫码点餐、预约桌号、订单语音提醒、会员营销、…

viewport视口的概念

viewport视口的概念 概念详见 MDN&#xff0c;我摘出来对比了下&#xff0c;如下图&#xff1a; 总结&#xff1a; viewport就是当前窗口的可视部分Visual Viewport 视觉视口 就是视口viewport中的可见部分 比如在mobile浏览器中&#xff0c;输入时&#xff0c;弹出的键盘&am…

屏幕开发学习 -- 迪文串口屏

一 前言 最近学习了一款基于图形化开发的屏幕&#xff0c;在摸索一周后&#xff0c;基本熟悉了这款产品的一个开发过程&#xff0c;今天给大家分享一下迪文串口屏和STM32如何建立通讯&#xff0c;有不足之处&#xff0c;还请见谅&#x1f601; 二 迪文屏介绍 1.选型 我用到的…