data.js
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

// Number of pixels per image; every image is stored as one flat row.
const IMAGE_SIZE = 784;
// Number of label slots per image (labels are read NUM_CLASSES bytes at a time).
const NUM_CLASSES = 10;
const NUM_DATASET_ELEMENTS = 65000;

const TRAIN_TEST_RATIO = 5 / 6;

const NUM_TRAIN_ELEMENTS = Math.floor(TRAIN_TEST_RATIO * NUM_DATASET_ELEMENTS);
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;

const MNIST_IMAGES_SPRITE_PATH =
    'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
const MNIST_LABELS_PATH =
    'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';

// Number of sprite rows (one image per row) decoded per canvas pass.
const SPRITE_CHUNK_ROWS = 5000;

/**
 * Decodes the loaded sprite into one Float32Array of pixel intensities scaled
 * to [0, 1], reading the sprite through the canvas in fixed-height chunks.
 *
 * @param {HTMLImageElement} img Fully loaded sprite image.
 * @param {HTMLCanvasElement} canvas Scratch canvas used for decoding.
 * @param {CanvasRenderingContext2D} ctx 2D context of `canvas`.
 * @returns {Float32Array} NUM_DATASET_ELEMENTS * IMAGE_SIZE values.
 */
function decodeSprite(img, canvas, ctx) {
  img.width = img.naturalWidth;
  img.height = img.naturalHeight;

  const datasetBytesBuffer =
      new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);

  canvas.width = img.width;
  canvas.height = SPRITE_CHUNK_ROWS;

  for (let i = 0; i < NUM_DATASET_ELEMENTS / SPRITE_CHUNK_ROWS; i++) {
    // View onto the slice of the shared buffer that this chunk fills.
    const datasetBytesView = new Float32Array(
        datasetBytesBuffer, i * IMAGE_SIZE * SPRITE_CHUNK_ROWS * 4,
        IMAGE_SIZE * SPRITE_CHUNK_ROWS);
    ctx.drawImage(
        img, 0, i * SPRITE_CHUNK_ROWS, img.width, SPRITE_CHUNK_ROWS, 0, 0,
        img.width, SPRITE_CHUNK_ROWS);

    const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

    for (let j = 0; j < imageData.data.length / 4; j++) {
      // All channels hold an equal value since the image is grayscale, so
      // just read the red channel.
      datasetBytesView[j] = imageData.data[j * 4] / 255;
    }
  }

  return new Float32Array(datasetBytesBuffer);
}

/**
 * A class that fetches the sprited MNIST dataset and returns shuffled batches.
 *
 * NOTE: This will get much easier. For now, we do data fetching and
 * manipulation manually.
 */
export class MnistData {
  constructor() {
    this.shuffledTrainIndex = 0;
    this.shuffledTestIndex = 0;
  }

  /**
   * Downloads the image sprite and the labels, then splits both into train
   * and test partitions and creates shuffled index tables for each.
   *
   * @returns {Promise<void>} Resolves once all dataset fields are populated.
   * @throws {Error} If the sprite cannot be loaded or decoded, or the labels
   *     request fails or returns a non-2xx status.
   */
  async load() {
    // Make a request for the MNIST sprited image.
    const img = new Image();
    const canvas = document.createElement('canvas');
    const ctx = canvas.getContext('2d');

    const imgRequest = new Promise((resolve, reject) => {
      img.crossOrigin = '';
      // Without an error handler a failed download would leave load() pending
      // forever.
      img.onerror = () => {
        reject(new Error(
            `Failed to load MNIST sprite from ${MNIST_IMAGES_SPRITE_PATH}`));
      };
      img.onload = () => {
        try {
          this.datasetImages = decodeSprite(img, canvas, ctx);
          resolve();
        } catch (err) {
          // Exceptions thrown in an event handler do not reach the promise,
          // so forward them explicitly.
          reject(err);
        }
      };
      img.src = MNIST_IMAGES_SPRITE_PATH;
    });

    const labelsRequest = fetch(MNIST_LABELS_PATH).then((response) => {
      if (!response.ok) {
        throw new Error(
            `Failed to fetch MNIST labels from ${MNIST_LABELS_PATH} ` +
            `(HTTP ${response.status})`);
      }
      return response.arrayBuffer();
    });

    const [, labelsBuffer] = await Promise.all([imgRequest, labelsRequest]);

    this.datasetLabels = new Uint8Array(labelsBuffer);

    // Create shuffled indices into the train/test set for when we select a
    // random dataset element for training / validation.
    this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
    this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);

    // Slice the images and labels into train and test sets.
    this.trainImages =
        this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.trainLabels =
        this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
    this.testLabels =
        this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
  }

  /**
   * Returns the next `batchSize` training examples in shuffled order.
   *
   * @param {number} batchSize Number of examples to return.
   * @returns {{xs: *, labels: *}} 2D tensors of shape [batchSize, IMAGE_SIZE]
   *     and [batchSize, NUM_CLASSES].
   */
  nextTrainBatch(batchSize) {
    return this.nextBatch(
        batchSize, [this.trainImages, this.trainLabels], () => {
          // The cursor advances before it is read and wraps at the end.
          this.shuffledTrainIndex =
              (this.shuffledTrainIndex + 1) % this.trainIndices.length;
          return this.trainIndices[this.shuffledTrainIndex];
        });
  }

  /**
   * Returns the next `batchSize` test examples in shuffled order.
   *
   * @param {number} batchSize Number of examples to return.
   * @returns {{xs: *, labels: *}} 2D tensors of shape [batchSize, IMAGE_SIZE]
   *     and [batchSize, NUM_CLASSES].
   */
  nextTestBatch(batchSize) {
    return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => {
      // The cursor advances before it is read and wraps at the end.
      this.shuffledTestIndex =
          (this.shuffledTestIndex + 1) % this.testIndices.length;
      return this.testIndices[this.shuffledTestIndex];
    });
  }

  /**
   * Gathers `batchSize` examples into contiguous arrays and wraps them in
   * tensors.
   *
   * @param {number} batchSize Number of examples to gather.
   * @param {Array} data Pair `[images, labels]` of flat typed arrays.
   * @param {function(): number} index Returns the element index to read next.
   * @returns {{xs: *, labels: *}} The image and label tensors.
   */
  nextBatch(batchSize, data, index) {
    const [images, labelBytes] = data;
    const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
    const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);

    for (let i = 0; i < batchSize; i++) {
      const idx = index();

      // subarray() is a view; set() performs the only copy that is needed.
      const image =
          images.subarray(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
      batchImagesArray.set(image, i * IMAGE_SIZE);

      const label = labelBytes.subarray(
          idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
      batchLabelsArray.set(label, i * NUM_CLASSES);
    }

    const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
    const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);

    return {xs, labels};
  }
}
script.js
import {MnistData} from './data.js';

// Side length in pixels of the drawing canvas and of the model's input image.
const CANVAS_SIZE = 280;
const MODEL_INPUT_SIZE = 28;

let canvas;
let ctx;
let saveButton;
let clearButton;
let rawImage;
let model;
// Last known pointer position, in canvas coordinates.
const pos = {x: 0, y: 0};

/**
 * Builds and compiles the convolutional classifier.
 *
 * @returns {*} The compiled sequential model.
 */
function getModel() {
  const net = tf.sequential();

  net.add(tf.layers.conv2d({
    inputShape: [MODEL_INPUT_SIZE, MODEL_INPUT_SIZE, 1],
    kernelSize: 3,
    filters: 8,
    activation: 'relu',
  }));
  net.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));
  net.add(tf.layers.conv2d({filters: 16, kernelSize: 3, activation: 'relu'}));
  net.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));
  net.add(tf.layers.flatten());
  net.add(tf.layers.dense({units: 128, activation: 'sigmoid'}));
  net.add(tf.layers.dense({units: 10, activation: 'softmax'}));

  net.compile({
    optimizer: tf.train.adam(),
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
  });

  return net;
}

/**
 * Trains `model` on one batch drawn from `data`, validating on a second
 * batch, and streams the metrics to the tfjs-vis panel.
 *
 * @param {*} model Compiled model to fit.
 * @param {MnistData} data Loaded dataset.
 * @returns {Promise<*>} The result of `model.fit`.
 */
async function train(model, data) {
  const metrics = ['loss', 'val_loss', 'acc', 'val_acc'];
  const container = {name: 'Model training', styles: {height: '640px'}};
  const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);

  const BATCH_SIZE = 512;
  const TRAIN_DATA_SIZE = 5500;
  const TEST_DATA_SIZE = 1000;

  const [trainXs, trainYs] = tf.tidy(() => {
    const d = data.nextTrainBatch(TRAIN_DATA_SIZE);
    return [
      d.xs.reshape([TRAIN_DATA_SIZE, MODEL_INPUT_SIZE, MODEL_INPUT_SIZE, 1]),
      d.labels,
    ];
  });

  const [testXs, testYs] = tf.tidy(() => {
    const d = data.nextTestBatch(TEST_DATA_SIZE);
    return [
      d.xs.reshape([TEST_DATA_SIZE, MODEL_INPUT_SIZE, MODEL_INPUT_SIZE, 1]),
      d.labels,
    ];
  });

  try {
    return await model.fit(trainXs, trainYs, {
      batchSize: BATCH_SIZE,
      validationData: [testXs, testYs],
      epochs: 20,
      shuffle: true,
      callbacks: fitCallbacks,
    });
  } finally {
    // Tensors returned from tf.tidy are owned by the caller; release them
    // whether or not fitting succeeded.
    tf.dispose([trainXs, trainYs, testXs, testYs]);
  }
}

/**
 * Records the pointer position in canvas coordinates.
 *
 * offsetX/offsetY are measured from the padding edge of the event target, so
 * the result does not depend on where the canvas sits on the page, on its
 * border width, or on scrolling.
 *
 * @param {MouseEvent} e Mouse event dispatched on the canvas.
 */
function setPosition(e) {
  pos.x = e.offsetX;
  pos.y = e.offsetY;
}

/** Mirrors the current canvas content into the preview image element. */
function syncPreview() {
  rawImage.src = canvas.toDataURL('image/png');
}

/**
 * Draws a stroke from the previous pointer position to the current one while
 * only the primary button is held.
 *
 * @param {MouseEvent} e Mouse-move event dispatched on the canvas.
 */
function draw(e) {
  if (e.buttons !== 1) {
    return;
  }
  ctx.beginPath();
  ctx.lineWidth = 24;
  ctx.lineCap = 'round';
  ctx.strokeStyle = 'white';
  ctx.moveTo(pos.x, pos.y);
  setPosition(e);
  ctx.lineTo(pos.x, pos.y);
  ctx.stroke();
  syncPreview();
}

/** Clears the canvas to black and keeps the preview image in step with it. */
function erase() {
  ctx.fillStyle = 'black';
  ctx.fillRect(0, 0, CANVAS_SIZE, CANVAS_SIZE);
  syncPreview();
}

/**
 * Classifies the current drawing and shows the predicted class index.
 *
 * Pixels are read straight from the canvas: the preview image decodes its
 * data URL asynchronously, so it can lag behind what is on screen.
 */
function save() {
  const digit = tf.tidy(() => {
    const raw = tf.browser.fromPixels(canvas, 1);
    const resized =
        tf.image.resizeBilinear(raw, [MODEL_INPUT_SIZE, MODEL_INPUT_SIZE]);
    // Training inputs are scaled to [0, 1]; apply the same scaling here.
    const tensor = resized.div(255).expandDims(0);
    const prediction = model.predict(tensor);
    return tf.argMax(prediction, 1).dataSync()[0];
  });
  alert(digit);
}

/** Looks up the page elements, paints the canvas black and wires the events. */
function init() {
  canvas = document.getElementById('canvas');
  rawImage = document.getElementById('canvasimg');
  ctx = canvas.getContext('2d');
  ctx.fillStyle = 'black';
  ctx.fillRect(0, 0, CANVAS_SIZE, CANVAS_SIZE);
  canvas.addEventListener('mousemove', draw);
  canvas.addEventListener('mousedown', setPosition);
  canvas.addEventListener('mouseenter', setPosition);
  saveButton = document.getElementById('sb');
  saveButton.addEventListener('click', save);
  clearButton = document.getElementById('cb');
  clearButton.addEventListener('click', erase);
}

/**
 * Entry point: loads the data, trains the model, then enables the drawing UI.
 * Failures are reported instead of being left as an unhandled rejection.
 */
async function run() {
  try {
    const data = new MnistData();
    await data.load();
    model = getModel();
    tfvis.show.modelSummary({name: 'Model Architecture'}, model);
    await train(model, data);
    init();
    alert('Training is done, try classifying...');
  } catch (err) {
    console.error(err);
    alert(`Training failed: ${err.message}`);
  }
}

document.addEventListener('DOMContentLoaded', run);
mnist.htm
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="utf-8">
  <title>Handwritten character recognition</title>
  <!-- TensorFlow.js and its visualisation add-on; both are used as globals (tf, tfvis). -->
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis"></script>
</head>
<body>
  <h1>Handwritten character recognition</h1>
  <!-- Drawing surface. Lengths carry explicit px units so they are honoured in standards mode. -->
  <canvas id="canvas" width="280" height="280" style="position:absolute;top:100px;left:100px;border:8px solid;"></canvas>
  <!-- Hidden mirror of the canvas content. -->
  <img id="canvasimg" style="position:absolute;top:10%;left:52%;width:280px;height:280px;display:none;">
  <input type="button" value="classify" id="sb" size="48" style="position:absolute;top:400px;left:100px;">
  <input type="button" value="clear" id="cb" size="23" style="position:absolute;top:400px;left:180px;">
  <script src="data.js" type="module"></script>
  <script src="script.js" type="module"></script>
</body>
</html>
我试图制作一个手写数字分类器,它根据我们在网页画布上绘制的内容来识别数字。但在训练时,损失曲线停滞在1.0左右,准确率停滞在60%左右。因此,我尝试将含128个节点的全连接层的激活函数从relu改为sigmoid。即使更改之后,损失仍然停滞在1.0。请帮帮我。
data.js /** * @license * 版权所有 2018 Google LLC。保留所有权利。 * 根据 Apache 许可证 2.0 版(“许可证”)获得许可; * 除非遵守许可证,否则您不得使用此文件...