用 TensorFlow.js 在浏览器中训练一个计算机视觉模型(手写数字分类器)

news2024/10/7 10:22:30

文章目录

  • Building a CNN in JavaScript
  • Using Callbacks for Visualization
  • Training with the MNIST Dataset
  • Running Inference on Images in TensorFlow.js
  • References


我们在《在浏览器中运行 TensorFlow.js 来训练模型并给出预测结果(Iris 数据集)》中已经对 TensorFlow.js 的使用有了大致的了解,现在我们进一步来看如何训练一个图片数据集,并做一些可视化工作。文章代码可从《AI and Machine Learning for Coders》一书 GitHub 找到。

在使用浏览器时,每当我们在一个 URL 上打开一个资源时,就会建立一个 HTTP 连接。我们用这个连接把命令传给服务器,然后服务器就会把结果回传。当涉及到计算机视觉时,我们通常会有大量的训练数据。例如,MNIST 和 Fashion MNIST,尽管它们已经是非常小型的图片数据集,但它们仍然包含了 70,000 张图片,这将是 70,000 个 HTTP 连接!这显然会造成大量的开销,稍后我们看如何处理这个问题。

Building a CNN in JavaScript

我们看在 keras 中定义的如下 针对手写数字数据集的 CNN 模型如何在 JavaScript 中定义:

model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(64, (3, 3),activation='relu', 
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation=tf.nn.relu),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])

我们分别来看看卷积层、池化层、全连接层是如何在 JavaScript 中定义的:

我们首先将模型定义为 sequential

model = tf.sequential();

第一个卷积层:

model.add(tf.layers.conv2d({inputShape: [28, 28, 1],
                            kernelSize: 3,
                            filters: 64,
                            activation: 'relu'}));

第一个池化层:

model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));

第一个全连接层:

model.add(tf.layers.dense({units: 128, activation: 'relu'}));

因此完整的 JavaScript 定义为:

    model = tf.sequential();
        
    model.add(tf.layers.conv2d({inputShape: [28, 28, 1],
                                kernelSize: 3,
                                filters: 64,
                                activation: 'relu'}));
        
    model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));
        
    model.add(tf.layers.conv2d({kernelSize: 3,
                                filters: 64,
                                activation: 'relu'}));
        
    model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));
        
    model.add(tf.layers.flatten());
        
    model.add(tf.layers.dense({units: 128, activation: 'relu'}));
        
    model.add(tf.layers.dense({units: 10, activation: 'softmax'}));

编译模型时的语法:

model.compile({optimizer: tf.train.adam(),
               loss: 'categoricalCrossentropy',
               metrics: ['accuracy']});

Using Callbacks for Visualization

我们直接使用《在浏览器中运行 TensorFlow.js 来训练模型并给出预测结果(Iris 数据集)》的现成代码来进行演示,代码如下:

<html>
<head></head>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
    
    <script lang="js">
    
        async function run(){
            const csvUrl = 'iris.csv';
            const trainingData = tf.data.csv(csvUrl, {
                columnConfigs: {
                    species: {
                        isLabel: true
                    }
                }
            });
            
            const convertedData = trainingData.map(({xs, ys}) => {
                const labels = [
                    ys.species == 'setosa' ? 1 : 0,
                    ys.species == 'virginica' ? 1: 0,
                    ys.species == 'versicolor' ? 1 : 0
                ]
                return {xs: Object.values(xs), ys: Object.values(labels)};
            }).batch(10);
        
            const numOfFeatures = (await trainingData.columnNames()).length - 1;
        
            const model = tf.sequential();
            model.add(tf.layers.dense({inputShape: [numOfFeatures],
                                       activation: "sigmoid", units: 5}));
        
            model.add(tf.layers.dense({activation: "softmax", units: 3}));
        
            model.compile({loss: "categoricalCrossentropy",
                           optimizer: tf.train.adam(0.06)});
        
            await model.fitDataset(convertedData,
                                   {epochs:100,
                                    callbacks:{
                                        onEpochEnd: async(epoch, logs) =>{
                                            console.log("Epoch: " + epoch + " Loss: " + logs.loss);
                                    }
                                }});
            const testVal = tf.tensor2d([4.4, 2.9, 1.4, 0.2], [1, 4]);
            alert(model.predict(testVal));
        
        }
        
        run();
        
    </script>
    
<body>
    <h1>Iris Classifier</h1>
</body>
</html>

为了使用可视化工具 tfjs-vis,我们需要在代码中添加如下的 script 标签:

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis"></script>

并用 tfvis.show 定义一个回调,来在训练时进行可视化:

const metrics = ['loss', 'accuracy'];
            
const container = {name: 'Model Training', 
                   styles: {height: '640px'},
                   tab: 'Training Progress'};
            
const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);

将原有代码中的回调替换成 fitCallbacks,现在我们的完整代码为:

<html>
<head></head>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
    
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis"></script>
    
    <script lang="js">
    
        async function run(){
            const csvUrl = 'iris.csv';
            const trainingData = tf.data.csv(csvUrl, {
                columnConfigs: {
                    species: {
                        isLabel: true
                    }
                }
            });
            
            const convertedData = trainingData.map(({xs, ys}) => {
                const labels = [
                    ys.species == 'setosa' ? 1 : 0,
                    ys.species == 'virginica' ? 1: 0,
                    ys.species == 'versicolor' ? 1 : 0
                ]
                return {xs: Object.values(xs), ys: Object.values(labels)};
            }).batch(10);
        
            const numOfFeatures = (await trainingData.columnNames()).length - 1;
        
            const model = tf.sequential();
            model.add(tf.layers.dense({inputShape: [numOfFeatures],
                                       activation: "sigmoid", units: 5}));
        
            model.add(tf.layers.dense({activation: "softmax", units: 3}));
        
            model.compile({loss: "categoricalCrossentropy",
                           optimizer: tf.train.adam(0.06)});
            
            const metrics = ['loss', 'accuracy'];
            
            const container = {name: 'Model Training', 
                               styles: {height: '640px'},
                               tab: 'Training Progress'};
            
            const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);
        
            await model.fitDataset(convertedData,
                                   {epochs: 50,
                                    callbacks: fitCallbacks});
            

        
        }
        
        run();
        
    </script>
    
<body>
    <h1>Iris Classifier</h1>
</body>
</html>

运行之后,会有如下结果:

在这里插入图片描述


Training with the MNIST Dataset

我们先如《在浏览器中运行 TensorFlow.js 来训练模型并给出预测结果(Iris 数据集)》一样新建一个项目,并在保持项目的本地路径复制一个 script.js 副本,将其命名为 data.js。下面列出的代码都会放在 data.js 中。

在 TensorFlow.js 中,处理数据训练的一个特殊的方法是将所有的图像附加在一起,成为一个单一的图像,通常称为 sprite sheet,而不是逐个下载每个图像。这种技术通常在游戏开发中使用,游戏的图形被存储在一个文件中,而不是多个小文件,以提高文件存储效率。如果我们把训练用的所有图片都存储在一个文件中,我们只需要打开一个 HTTP 连接,就可以一次性下载所有图片。例如,MNIST 的 sprite sheet 如下图所示:

在这里插入图片描述
这幅图片的维度为 65000×784(28×28),也就是说,我们只需逐行读取该图片文件,就能得到一张张 28×28 像素的图片。

我们可以在 JavaScript 中先将图像加载,然后定义一个画布(canvas),在从原始图像中提取出各个“线条”(行)后,在画布上画出这些“线条”。然后,画布上的字节可以被提取到一个数据集中用于训练。下面我们看具体流程:

训练集测试集比例为 5:1

const IMAGE_SIZE = 784;
const NUM_CLASSES = 10;
const NUM_DATASET_ELEMENTS = 65000;
        
const TRAIN_TEST_RATIO = 5/6;
        
const NUM_TRAIN_ELEMENTS = Math.floor(TRAIN_TEST_RATIO * NUM_DATASET_ELEMENTS);
        
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;

定义 canvas:

const img = new Image();
const canvas = document.createElement('canvas');
const ctx = canvas.getContext('2d');

图片地址:

img.src = "https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png";

一旦图像被加载,我们就可以设置一个 buffer 来容纳其中的字节。该图像是一个 PNG 文件,每个像素有 4 个字节,所以需要为 buffer 预留 65,000×768×4 个字节。我们不需要逐个图像的分割文件,而是可以分块(chunks)分割。通过指定 chunkSize,我们可以一次取五千张图片:

img.onload = () => {
    img.width = img.naturalWidth;
    img.height = img.naturalHeight;
            
    const datasetBytesBuffer = 
          new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);
            
     const chunkSize = 5000;
     canvas.width = img.width;
     anvas.height = chunkSize;

下面我们通过一个 for 循环来将图片读入 buffer 中,因为图片为灰度图,所以 R\G\B 三个通道的值都是一样的,我们任意选择其中之一:

for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
    const datasetBytesView = new Float32Array(
        datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
        IMAGE_SIZE * chunkSize);
            
    ctx.drawImage(
        img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width, chunkSize);
                
    const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
                
    for (let j = 0; j < imageData.data.length / 4; j++) {
        datasetBytesView[j] = imageData.data[j * 4] / 255;
    }
}
            
this.datasetImages = new Float32Array(datasetBytesBuffer);

和图片类似,标签也是被存储在一个单独的文件中。这是一个二进制文件,对标签进行了稀疏编码。每个标签由 10 个字节表示,其中一个字节的值为 01,代表某个类别。因此,除了逐行下载和解码图像的字节外,我们还需要对标签进行解码。我们使用 arrayBuffer 将标签解码成整数数组。

const labelsRequest = fetch("https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8");
        
const [imgResponse, labelsResponse] = 
       await Promise.all([imgRequest, labelsRequest]);
        
this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

然后我们就可以划分训练集和测试集:

this.trainImages = 
	this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
this.testImages = 
	this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
        
this.trainLabels = 
	this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
this.testLabels = 
	this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);

和常规流程一样,我们也可以对数据集进行分批打包(batch):

nextBatch(batchSize, data, index) {
	const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
	const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);
	
	for (let i = 0; i < batchSize; i++) {
		const idx = index();
		
		const image =
			data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
		batchImagesArray.set(image, i * IMAGE_SIZE);
		
		const label =
			data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
		batchLabelsArray.set(label, i * NUM_CLASSES);
	}
	
	const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
	const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);
	
	return {xs, labels};
}

然后,训练数据可以使用下面这个批处理函数来返回所需批次大小的且打乱顺序的训练批次:

nextTrainBatch(batchSize) {
	return this.nextBatch(
		batchSize, [this.trainImages, this.trainLabels], () => {
			this.shuffledTrainIndex =
				(this.shuffledTrainIndex + 1) % this.trainIndices.length;
			return this.trainIndices[this.shuffledTrainIndex];
		});
}

测试集数据的处理方式是完全一样的,我们在下面的完整代码中给出。

整个 data.js 文件中代码如下:我们定义了一个类 MnistData 用来封装我们刚刚定义的所有函数方法。之后在 index.html 中,我们直接从 data.js 文件中导入该类,并进行实例化 const data = new MnistData();,就可以直接调用类中的方法。

const IMAGE_SIZE = 784;
const NUM_CLASSES = 10;
const NUM_DATASET_ELEMENTS = 65000;

const TRAIN_TEST_RATIO = 5 / 6;

const NUM_TRAIN_ELEMENTS = Math.floor(TRAIN_TEST_RATIO * NUM_DATASET_ELEMENTS);
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;

const MNIST_IMAGES_SPRITE_PATH =
    'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
const MNIST_LABELS_PATH =
    'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';

export class MnistData {
  constructor() {
    this.shuffledTrainIndex = 0;
    this.shuffledTestIndex = 0;
  }

  async load() {
    // Make a request for the MNIST sprited image.
    const img = new Image();
    const canvas = document.createElement('canvas');
    const ctx = canvas.getContext('2d');
    const imgRequest = new Promise((resolve, reject) => {
      img.crossOrigin = '';
      img.onload = () => {
        img.width = img.naturalWidth;
        img.height = img.naturalHeight;

        const datasetBytesBuffer =
            new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);

        const chunkSize = 5000;
        canvas.width = img.width;
        canvas.height = chunkSize;

        for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
          const datasetBytesView = new Float32Array(
              datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
              IMAGE_SIZE * chunkSize);
          ctx.drawImage(
              img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
              chunkSize);

          const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

          for (let j = 0; j < imageData.data.length / 4; j++) {
            // All channels hold an equal value since the image is grayscale, so
            // just read the red channel.
            datasetBytesView[j] = imageData.data[j * 4] / 255;
          }
        }
        this.datasetImages = new Float32Array(datasetBytesBuffer);

        resolve();
      };
      img.src = MNIST_IMAGES_SPRITE_PATH;
    });

    const labelsRequest = fetch(MNIST_LABELS_PATH);
    const [imgResponse, labelsResponse] =
        await Promise.all([imgRequest, labelsRequest]);

    this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

    // Create shuffled indices into the train/test set for when we select a
    // random dataset element for training / validation.
    this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
    this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);

    // Slice the the images and labels into train and test sets.
    this.trainImages =
        this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.trainLabels =
        this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
    this.testLabels =
        this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
  }

  nextTrainBatch(batchSize) {
    return this.nextBatch(
        batchSize, [this.trainImages, this.trainLabels], () => {
          this.shuffledTrainIndex =
              (this.shuffledTrainIndex + 1) % this.trainIndices.length;
          return this.trainIndices[this.shuffledTrainIndex];
        });
  }

  nextTestBatch(batchSize) {
    return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => {
      this.shuffledTestIndex =
          (this.shuffledTestIndex + 1) % this.testIndices.length;
      return this.testIndices[this.shuffledTestIndex];
    });
  }

  nextBatch(batchSize, data, index) {
    const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
    const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);

    for (let i = 0; i < batchSize; i++) {
      const idx = index();

      const image =
          data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
      batchImagesArray.set(image, i * IMAGE_SIZE);

      const label =
          data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
      batchLabelsArray.set(label, i * NUM_CLASSES);
    }

    const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
    const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);

    return {xs, labels};
  }
}

下面定义的回调以及训练代码,我们之后会直接封装到 index.html 中的 train 函数当中。

还记得我们刚刚使用的可视化回调吗?我们这里再将它定义出来:

const metrics = ['loss', 'val_loss', 'accuracy', 'val_accuracy'];
const container = { name: 'Model Training', styles: { height: '640px' },
					tab: 'Training Progress' };
const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);

调用函数生成训练、测试数据集:

const [trainXs, trainYs] = tf.tidy(() => {
	const d = data.nextTrainBatch(TRAIN_DATA_SIZE);
	return [
		d.xs.reshape([TRAIN_DATA_SIZE, 28, 28, 1]),
		d.labels
	];
});

const [testXs, testYs] = tf.tidy(() => {
	const d = data.nextTestBatch(TEST_DATA_SIZE);
	return [
		d.xs.reshape([TEST_DATA_SIZE, 28, 28, 1]),
		d.labels
	];
});

注意这里 tf.tidy 的使用。在 TensorFlow.js 中,它将帮助我们清理所有中间张量,除了那些函数返回的张量。在使用 TensorFlow.js 时,这对防止浏览器中的内存泄漏至关重要。

现在万事俱备,我们就可以进行训练啦!

return model.fit(trainXs, trainYs, {
	batchSize: BATCH_SIZE,
	validationData: [testXs, testYs],
	epochs: 20,
	shuffle: true,
	callbacks: fitCallbacks
});

大家这时可能还跑不出下面的训练过程,别着急,我们一会给出 index.html 文件的完整代码,大家到时直接运行即可。

在这里插入图片描述


Running Inference on Images in TensorFlow.js

推断时我们需要一张测试图片,我们可以直接创建一个画布对象,然后让用户使用鼠标在画布上写出要判断的数字:

rawImage = document.getElementById('canvasimg');
ctx = canvas.getContext("2d");
ctx.fillStyle = "black";
ctx.fillRect(0,0,280,280);

在用户通过下面的 draw 函数写好数字之后:

function draw(e) {
	if(e.buttons!=1) return;
	ctx.beginPath();
	ctx.lineWidth = 24;
	ctx.lineCap = 'round';
	ctx.strokeStyle = 'white';
	ctx.moveTo(pos.x, pos.y);
	setPosition(e);
	ctx.lineTo(pos.x, pos.y);
	ctx.stroke();
	rawImage.src = canvas.toDataURL('image/png');
}

我们从画布上抓取像素,并处理成模型可以处理的输入张量:

var raw = tf.browser.fromPixels(rawImage,1);

var resized = tf.image.resizeBilinear(raw, [28,28]);

var tensor = resized.expandDims(0);

之后我们就可以进行预测:

var prediction = model.predict(tensor);
var pIndex = tf.argMax(prediction, 1).dataSync();

index.thml 文件完整代码:

<html>
<head>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis"></script>

</head>
<body>
    <h1>Handwriting Classifier!</h1>
    <canvas id="canvas" width="280" height="280" style="position:absolute;top:100;left:100;border:8px solid;"></canvas>
    <img id="canvasimg" style="position:absolute;top:10%;left:52%;width=280;height=280;display:none;">
    <input type="button" value="classify" id="sb" size="48" style="position:absolute;top:400;left:100;">
    <input type="button" value="clear" id="cb" size="23" style="position:absolute;top:400;left:180;">
    <script src="data.js" type="module">
    </script>
    
    
</body>
    
    <script type="module">
        
        import {MnistData} from './data.js';
        var canvas, ctx, saveButton, clearButton;
        var pos = {x:0, y:0};
        var rawImage;
        var model;
	
        function getModel() {
	       model = tf.sequential();

	       model.add(tf.layers.conv2d({inputShape: [28, 28, 1], kernelSize: 3, filters: 8, activation: 'relu'}));
	       model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));
	       model.add(tf.layers.conv2d({filters: 16, kernelSize: 3, activation: 'relu'}));
	       model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));
	       model.add(tf.layers.flatten());
	       model.add(tf.layers.dense({units: 128, activation: 'relu'}));
	       model.add(tf.layers.dense({units: 10, activation: 'softmax'}));

	       model.compile({optimizer: tf.train.adam(), loss: 'categoricalCrossentropy', metrics: ['accuracy']});

	       return model;
        }

        async function train(model, data) {
	       const metrics = ['loss', 'val_loss', 'accuracy', 'val_accuracy'];
	       const container = { name: 'Model Training', styles: { height: '640px' }, tab: 'Training Progress'};
	       const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);
  
	       const BATCH_SIZE = 512;
	       const TRAIN_DATA_SIZE = 5500;
	       const TEST_DATA_SIZE = 1000;

	       const [trainXs, trainYs] = tf.tidy(() => {
		      const d = data.nextTrainBatch(TRAIN_DATA_SIZE);
		      return [
			     d.xs.reshape([TRAIN_DATA_SIZE, 28, 28, 1]),
			     d.labels
		      ];
	       });

	       const [testXs, testYs] = tf.tidy(() => {
		      const d = data.nextTestBatch(TEST_DATA_SIZE);
		      return [
                d.xs.reshape([TEST_DATA_SIZE, 28, 28, 1]),
			    d.labels
		      ];
	       });

	       return model.fit(trainXs, trainYs, {
		      batchSize: BATCH_SIZE,
		      validationData: [testXs, testYs],
		      epochs: 20,
		      shuffle: true,
		      callbacks: fitCallbacks
	       });
        }

        function setPosition(e){
	       pos.x = e.clientX-100;
	       pos.y = e.clientY-100;
        }
    
        function draw(e) {
	       if(e.buttons!=1) return;
	       ctx.beginPath();
	       ctx.lineWidth = 24;
	       ctx.lineCap = 'round';
	       ctx.strokeStyle = 'white';
	       ctx.moveTo(pos.x, pos.y);
	       setPosition(e);
	       ctx.lineTo(pos.x, pos.y);
	       ctx.stroke();
	       rawImage.src = canvas.toDataURL('image/png');
        }
    
        function erase() {
	       ctx.fillStyle = "black";
	       ctx.fillRect(0,0,280,280);
        }
    
        function save() {
	       var raw = tf.browser.fromPixels(rawImage,1);
	       var resized = tf.image.resizeBilinear(raw, [28,28]);
	       var tensor = resized.expandDims(0);
           var prediction = model.predict(tensor);
           var pIndex = tf.argMax(prediction, 1).dataSync();
    
	       alert(pIndex);
        }
    
        function init() {
	       canvas = document.getElementById('canvas');
	       rawImage = document.getElementById('canvasimg');
	       ctx = canvas.getContext("2d");
	       ctx.fillStyle = "black";
	       ctx.fillRect(0,0,280,280);
	       canvas.addEventListener("mousemove", draw);
	       canvas.addEventListener("mousedown", setPosition);
	       canvas.addEventListener("mouseenter", setPosition);
	       saveButton = document.getElementById('sb');
	       saveButton.addEventListener("click", save);
	       clearButton = document.getElementById('cb');
	       clearButton.addEventListener("click", erase);
        }


        async function run() {  
	       const data = new MnistData();
	       await data.load();
	       const model = getModel();
	       tfvis.show.modelSummary({name: 'Model Architecture'}, model);
	       await train(model, data);
	       init();
	       alert("Training is done, try classifying your handwriting!");
        }
        
        run();
    
    </script>
    
    

</html>

当运行之后,我们等模型训练完毕,右侧就会出现下面的界面:
在这里插入图片描述

我们直接用鼠标在黑色画布上随便写一个数字,然后点击 classify:

在这里插入图片描述
在这里插入图片描述


References

AI and Machine Learning for Coders by Laurence Moroney.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/66214.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数字源表如何测试MOS管?

MOSFET(金属—氧化物半导体场效应晶体管)是 一种利用电场效应来控制其电流大小的常见半导体器件&#xff0c;可 以 广 泛 应 用 在 模 拟 电 路 和 数 字 电 路 当 中 。 MOSFET可以由硅制作&#xff0c;也可以由石墨烯&#xff0c;碳纳米管 等材料制作&#xff0c;是材料及器件…

集成电路模拟版图入门-转行模拟版图基础学习笔记(二)

在众多IC岗位中&#xff0c;模拟版图确实属于容易入门&#xff0c;吸引来很多想要转行IC行业的朋友&#xff0c;但需要掌握的知识点和技巧并不比设计少&#xff0c;属于门槛简单&#xff0c;上手不易&#xff0c;想要自学模拟版图似乎比较困难。 之前为大家分享过移知学员的模…

(十四)笔记.net学习之RabbitMQ工作模式

RabbitMQ在.net中简单使用一、简单模式1.生产者2.消费者二、工作队列模式1.工作队列模式介绍2.生产者发送消息3.消费能力三、发布/订阅模式1.介绍2.生产者3.消费者四、Routing路由模式1.介绍2.生产着3.消费者五、topic 主题模式1.介绍2. 生产者3.消费者一、简单模式 1.生产者 …

MyBatis系列第1篇:MyBatis未出世之前我们那些痛苦的经历

这么多技术&#xff0c;为什么我们选择的是mybatis 不知道大家是否还记得使用jdbc如何操作数据库&#xff1f; 加载驱动、获取连接、拼接sql、执行sql、获取结果、解析结果、关闭数据库&#xff0c;这些操作是纯jdbc的方式必经的一些过程&#xff0c;每次操作数据库都需要写这…

三面:请设计一个虚拟DOM算法吧

一、问题剖析 这不是前几天面试官开局面试官就让我设计一个路由&#xff0c;二面过了&#xff0c;结果今天来个三面。 问你道简单的送分题&#xff1a;设计一个虚拟DOM算法&#xff1f; 好家伙&#xff0c;来吧&#xff0c;先进行问题剖析&#xff0c;谁让我们是卑微的打工人…

学习python基础知识

1、Python 基础语法 计算机组成&#xff1a;硬件、软件、计算机运行程序方式、Python 语言的特点、应用领域、Python IDE、程序注释&#xff1a;单行注释、多行注释&#xff1b;变量的作用、定义、 命名规则、变量的数据类型、查看变量类型、输入和输入函数、算术运算符、赋值…

gazebo中添加动态障碍物

文章目录gazebo 教程gazebo 添加动态障碍物gazebo添加动态障碍物插件gazebo中动态障碍物实时posegazebo 教程 gazebo github https://github.com/gazebosim/gazebo-classic/tree/gazebo9gazebo tutorials https://classic.gazebosim.org/tutorials运行一个空白环境 <sdf v…

深入了解Java中的SQL注入

深入了解Java中的SQL注入 本文以代码实例复现了Java中JDBC及Mybatis框架采用预编译和非预编译时可能存在SQL注入的几种情况&#xff0c;并给予修复建议。 JDBC 首先看第一段代码&#xff0c;使用了远古时期的JDBC并且并没有使用预编译。这种简单的字符串拼接就存在SQL注入 …

一云七芯!ZStack 祝贺上海市金融信创联合攻关基地荣获一等奖

2022年11月初&#xff0c;由上海市总工会、中共上海市经济和信息化工作委员会、上海市经济信息化委员会主办的2022上海城市数字化转型 “智慧工匠”选树、“领军先锋”评选活动信创应用竞赛决赛暨颁奖典礼中&#xff0c;“一云七芯适配验证云平台及服务解决方案”获得信创应用案…

GitHub2022年度前100的Java面试真题高频知识点汇总

前言 这是我在工作、面试中学习并总结到的一些知识点&#xff0c;都是一些比较典型的、面试常常被问到的问题。 如果你平时没有注意去总结的话&#xff0c;那么当你面试被问到的时候可能会是一脸懵圈&#xff0c;就算这个问题你知道怎么回事&#xff0c;但是你平时没有认真总…

【win11内存占用高优化】未运行程序,系统内存占用50以上

这里写自定义目录标题前言打开控制面板找到电源键功能找到快速启动选项&#xff0c;取消勾选&#xff0c;确定win X以管理员身份打开powershell输入如下命令&#xff0c;回车关闭终端完成前言 windows11在未运行任何其他程序的情况下&#xff0c;内存占用超50%&#xff0c;可…

速度收藏,Fiddler详细使用教程出炉!

目录 01、抓取不同类型接口数据 02、数据模拟以及过滤规则 03、如何模拟接口响应数据 04、使用fiddler进行评论接口测试 绵薄之力【软件测试学习资源分享】 01、抓取不同类型接口数据 查看windows本机的IP 配置fiddler 需要保证要抓取的手机与电脑保持同一网段&#xff0…

转换 FLAC、APE 无损音乐格式为 iTunes 支持导入的 M4A 格式

大家知道常见的无损音乐格式有 FLAC、APE、WAV 等这些格式。其中 FLAC (Free Lossless Audio Codec) 格式因为是免费自由的压缩编码、无损压缩&#xff0c;且受到操作系统、软件及硬件的广泛支持。所以是非常流行常见的无损音乐格式。 自 2005 年 Mac OS X v10.4 开始&#xf…

《垃圾回收算法手册 自动内存管理的艺术》——其他分区策略(笔记)

文章目录十、其他分区策略10.1 大对象空间10.1.1 转轮回收器10.1.2 在操作系统支持下的对象移动10.1.3 不包含指针的对象10.2 基于对象拓扑结构的回收器10.2.1 成熟对象空间的回收10.2.2 基于对象相关性的回收10.2.3 线程本地回收10.2.4 栈上分配10.2.5 区域推断10.3 混合标记—…

磨金石教育摄影技能干货分享|传统民居摄影作品欣赏

我们知道在绘画领域有写实和写意之分&#xff0c;写实多用于人像的描绘&#xff0c;写意多用于山水田园画的创作。尤其是在中国传统绘画艺术中&#xff0c;写意简直就是创作的精髓。 写实和写意的区别在于&#xff0c;前者侧重真实还原&#xff0c;后者在于主管情感表达。 建筑…

哈希表、哈希桶(C++实现)

1. 哈希 1.1 概念 哈希&#xff08;hash&#xff0c;中文&#xff1a;散列&#xff1b;音译&#xff1a;哈希&#xff09;&#xff0c;是一种算法思想&#xff0c;又称散列算法、哈希函数、散列函数等。哈希函数能指导任何一种数据&#xff0c;构造出一种储存结构&#xff0c…

机器学习笔记之配分函数(二)——随机最大似然

机器学习笔记之配分函数——随机最大似然引言回顾&#xff1a;对数似然梯度关于∇θL(θ)\nabla_{\theta}\mathcal L(\theta)∇θ​L(θ)的简化基于MCMC求解负相关于书中图像的解释引言 上一节介绍了对包含配分函数的概率分布——使用极大似然估计求解模型参数的梯度(对数似然…

5款高效率,但是名气不大的小众软件

今天推荐5款十分小众的软件&#xff0c;但是每个都是非常非常好用的&#xff0c;用完后觉得不好用你找我。 1.多窗口文件整理——Q-Dir Q-Dir 是一款多窗口文件整理工具&#xff0c;特别适合用户频繁在各个文件夹中跳转进行复制粘贴的文件归档操作。如果你的电脑硬盘中文件已经…

MySQL 数据库的增删查改 (2)

文章目录一. 数据库约束1. 约束类型2.NULL 约束3.UNIQUE 约束4.DEFAULT 约束5. PRIMARY KEY 约束6.FOREIGN KEY 约束二.表的设计三.插入四.查询1.聚合查询2.联合查询3.合并查询本篇文章继承与 MySQL 表的增删改查(1) 一. 数据库约束 1. 约束类型 NOT NULL -- 表示某一行不能…

下载安全证书到jdk中的cacerts证书库

最近在公司遇到访问https请求&#xff0c;JDK返回异常信息的问题。返回如下&#xff1a; java.lang.Exception: java.lang.Exception: sun.security.validator.ValidatorException: PKIX path building failed: sun.security.provider.certpath.SunCertPathBuilderException: u…