Public
Edited
Sep 10, 2021
Insert cell
Insert cell
// Preview cell: render the digit at index 0 of the sprite-based dataset (label overlay is on by default).
drawDigit(0, {source:"data"})
Insert cell
/**
 * Renders one 28x28 digit onto a zoomed canvas, optionally overlaying its label.
 *
 * The canvas is returned immediately; the pixels and label are painted once
 * createImageBitmap resolves.
 *
 * @param {number} idx - Index of the digit within the selected source.
 * @param {Object} [options]
 * @param {"data"|"test"|"train"} [options.source="data"] - "data" reads the
 *   float pixels and one-hot labels held by `data`; "test"/"train" read the
 *   byte arrays testImages/testLabels and trainImages/trainLabels.
 * @param {boolean} [options.label=true] - Draw the label in red at the bottom left.
 * @param {number} [options.zoom=8] - Integer scale factor applied to the 28px image.
 * @returns {HTMLCanvasElement} The canvas owned by the drawing context.
 * @throws {RangeError} If `source` is not one of the known sources.
 */
function drawDigit(idx, {source="data", label=true, zoom=8}={}) {
  const size = 28;
  const imagesBySource = {test: testImages, train: trainImages, "data": data.datasetImages};
  // Fail with a readable message instead of "cannot read properties of undefined".
  if (!Object.keys(imagesBySource).includes(source)) {
    throw new RangeError(
      `Unknown source "${source}"; expected one of: ${Object.keys(imagesBySource).join(", ")}`
    );
  }
  const isSprite = source === "data";

  const ctx = DOM.context2d(size * zoom, size * zoom);
  // Keep hard pixel edges when the 28px bitmap is scaled up.
  ctx.imageSmoothingEnabled = false;
  const imgData = ctx.createImageData(size, size);
  const offset = idx * IMAGE_SIZE;
  // A view is enough here: the pixels are only read, never modified.
  const values = imagesBySource[source].subarray(offset, offset + IMAGE_SIZE);
  for (let k = 0; k < IMAGE_SIZE; k++) {
    // Sprite pixels are floats in [0, 1]; the other sources already hold bytes.
    const val = isSprite ? Math.floor(255 * values[k]) : values[k];
    imgData.data[4 * k + 0] = val;
    imgData.data[4 * k + 1] = val;
    imgData.data[4 * k + 2] = val;
    imgData.data[4 * k + 3] = 255;
  }

  createImageBitmap(imgData, 0, 0, size, size)
    .then((bitmap) => {
      ctx.drawImage(bitmap, 0, 0, size, size, 0, 0, size * zoom, size * zoom);
      if (!label) {
        return;
      }
      const labels = {test: testLabels, train: trainLabels, "data": data.datasetLabels}[source];
      // Sprite labels are one-hot rows of NUM_CLASSES entries; the others are plain digits.
      const theLabel = isSprite
        ? d3.maxIndex(labels.slice(idx * NUM_CLASSES, (idx + 1) * NUM_CLASSES))
        : labels[idx];
      ctx.font = "24px serif";
      ctx.fillStyle = "red";
      ctx.fillText(String(theLabel), 4, zoom * size - 4);
    })
    .catch((err) => {
      // The canvas has already been returned, so report instead of leaving
      // an unhandled rejection.
      console.error("drawDigit: failed to render digit", err);
    });
  return ctx.canvas;
}
Insert cell
Insert cell
// Run the model on a batch of 100 test samples; later cells read pred.probs and pred.entropy.
pred=doPrediction(model,data,100)
Insert cell
/**
 * Runs the model on a fresh test batch and summarises its predictions.
 *
 * @param {Object} model - Model exposing `predict(tensor)`.
 * @param {Object} data - Data source exposing `nextTestBatch(size)`, which
 *   returns `{xs, labels}` tensors.
 * @param {number} [testDataSize=10] - Number of test samples to evaluate.
 * @returns {Promise<{labels: number[], best: number[], probs: number[][], entropy: number[]}>}
 *   True class per sample, predicted class per sample, per-class probabilities,
 *   and the Shannon entropy (in bits) of each probability row.
 */
async function doPrediction(model, data, testDataSize = 10) {
  const testData = data.nextTestBatch(testDataSize);
  // Every tensor created for this call, including the batch itself, is
  // released in `finally`, so nothing leaks even if predict() throws.
  const owned = [testData.xs, testData.labels];
  const keep = (tensor) => {
    owned.push(tensor);
    return tensor;
  };
  try {
    const testxs = keep(testData.xs.reshape([testDataSize, IMAGE_DIM, IMAGE_DIM, 1]));
    const labelsT = keep(testData.labels.argMax(-1));
    const probsT = keep(model.predict(testxs));
    const bestT = keep(probsT.argMax(-1));
    const [labels, best, probs] = await Promise.all([
      labelsT.array(),
      bestT.array(),
      probsT.array(),
    ]);
    // Skip p === 0: its contribution is 0 in the limit, whereas evaluating
    // 0 * log2(0) gives NaN and would poison the whole sum.
    const entropy = probs.map((row) =>
      row.reduce((sum, p) => (p > 0 ? sum - p * Math.log2(p) : sum), 0)
    );
    return {labels, best, probs, entropy};
  } finally {
    for (const tensor of owned) {
      tensor.dispose();
    }
  }
}
Insert cell
// Class probabilities of the prediction with the highest entropy.
pred.probs[d3.maxIndex(pred.entropy)]
Insert cell
// Entropy value of the highest-entropy prediction.
pred.entropy[d3.maxIndex(pred.entropy)]
Insert cell
// Class probabilities of the prediction with the lowest entropy.
pred.probs[d3.minIndex(pred.entropy)]
Insert cell
// Reference value: log2(10) is the entropy, in bits, of a uniform distribution over 10 outcomes.
Math.log2(10)
Insert cell
Insert cell
// Train the model; the cell's value is whatever train() resolves to (the result of model.fit).
await train(model, data)
Insert cell
/**
 * Fits the model on a single batch of training samples drawn from `data`.
 *
 * @param {Object} model - Compiled model exposing `fit(xs, ys, config)`.
 * @param {Object} data - Data source exposing `nextTrainBatch(size)`, which
 *   returns `{xs, labels}` tensors.
 * @param {Object} [options]
 * @param {number} [options.batchSize=100] - Mini-batch size passed to fit.
 * @param {number} [options.trainDataSize=100] - Number of samples to draw.
 * @param {number} [options.epochs=10] - Number of passes over the drawn samples.
 * @returns {Promise<*>} Whatever `model.fit` resolves to.
 */
async function train(model, data, {batchSize = 100, trainDataSize = 100, epochs = 10} = {}) {
  // tidy() disposes the flat `xs` tensor created by nextTrainBatch; only the
  // two returned tensors survive and are released below.
  const [trainXs, trainYs] = tf.tidy(() => {
    const d = data.nextTrainBatch(trainDataSize);
    return [
      d.xs.reshape([trainDataSize, IMAGE_DIM, IMAGE_DIM, 1]),
      d.labels
    ];
  });

  try {
    return await model.fit(trainXs, trainYs, {
      batchSize,
      epochs,
      shuffle: true,
    });
  } finally {
    // The training tensors are not needed once fit has finished (or failed).
    trainXs.dispose();
    trainYs.dispose();
  }
}
Insert cell
Insert cell
// Build and compile a fresh model instance.
model=getModel()
Insert cell
/**
 * Builds and compiles the digit classifier: two conv + max-pool stages, a
 * flatten step and a softmax dense layer with NUM_CLASSES outputs.
 *
 * @returns {Object} The compiled sequential model.
 */
function getModel() {
  const net = tf.sequential();

  // Both convolution stages share everything except the filter count; only the
  // first one declares the input shape.
  const convStage = (filters, extra = {}) =>
    tf.layers.conv2d({
      ...extra,
      kernelSize: 5,
      filters,
      strides: 1,
      activation: 'relu',
      kernelInitializer: 'varianceScaling',
    });

  // Downsample by taking the maximum over 2x2 regions.
  const poolStage = () =>
    tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]});

  const stack = [
    convStage(8, {inputShape: [IMAGE_DIM, IMAGE_DIM, 1]}),
    poolStage(),
    // Second stage uses more filters than the first.
    convStage(16),
    poolStage(),
    // Collapse the 2D feature maps into a vector for the classifier layer.
    tf.layers.flatten(),
    // One output unit per class, normalised with softmax.
    tf.layers.dense({
      units: NUM_CLASSES,
      kernelInitializer: 'varianceScaling',
      activation: 'softmax',
    }),
  ];

  for (const layer of stack) {
    net.add(layer);
  }

  // Adam optimizer, categorical cross-entropy loss, accuracy as the metric.
  net.compile({
    optimizer: tf.train.adam(),
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
  });

  return net;
}
Insert cell
Insert cell
trainImages={
const buffer=await FileAttachment("train-images-idx3-ubyte").arrayBuffer();
// Ignore 4 x int32 header
return new Uint8Array(buffer.slice(16));
}
Insert cell
testImages={
const buffer=await FileAttachment("t10k-images-idx3-ubyte").arrayBuffer();
// Ignore 4 x int32 header and only use first 5k images.
return new Uint8Array(buffer.slice(16, 16+5000*IMAGE_SIZE));
}
Insert cell
trainLabels={
const buffer=await FileAttachment("train-labels-idx1-ubyte").arrayBuffer();
// Ignore 2 x int32 header
return new Uint8Array(buffer.slice(8));
}
Insert cell
testLabels={
const buffer=await FileAttachment("t10k-labels-idx1-ubyte").arrayBuffer();
// Ignore 2 x int32 header and only use first 5k labels.
return new Uint8Array(buffer.slice(8,5008));
}
Insert cell
data = {
const data = new MnistData();
await data.load();
return data;
}
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
/**
 * Downloads the MNIST sprite image and label file, splits them into train and
 * test partitions, and serves batches as tensors.
 *
 * After `load()` resolves the instance exposes `datasetImages` (Float32Array,
 * pixels scaled to [0, 1]), `datasetLabels` (Uint8Array, NUM_CLASSES entries
 * per image) and the `trainImages` / `testImages` / `trainLabels` /
 * `testLabels` partitions.
 */
class MnistData {

  constructor() {
    this.shuffledTrainIndex = 0;
    this.shuffledTestIndex = 0;
  }

  /**
   * Fetches and decodes the dataset.
   *
   * @returns {Promise<void>}
   * @throws {Error} If the sprite image cannot be loaded or decoded, or the
   *   labels request does not succeed.
   */
  async load() {
    const TRAIN_TEST_RATIO = 5 / 6;
    const NUM_TRAIN_ELEMENTS = Math.floor(TRAIN_TEST_RATIO * NUM_DATASET_ELEMENTS);
    const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;
    const MNIST_IMAGES_SPRITE_PATH =
      'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
    const MNIST_LABELS_PATH =
      'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';

    // Make a request for the MNIST sprited image.
    const img = new Image();
    const canvas = document.createElement('canvas');
    const ctx = canvas.getContext('2d');
    const imgRequest = new Promise((resolve, reject) => {
      img.crossOrigin = '';
      // Without this handler a failed download would leave load() pending forever.
      img.onerror = () => {
        reject(new Error(`Failed to load MNIST sprite image: ${MNIST_IMAGES_SPRITE_PATH}`));
      };
      img.onload = () => {
        // Exceptions inside an event handler never reach the promise on their
        // own, so forward them explicitly.
        try {
          img.width = img.naturalWidth;
          img.height = img.naturalHeight;

          const datasetBytesBuffer =
            new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);

          // Decode the sprite in horizontal strips of `chunkSize` rows so the
          // canvas never has to hold the whole image at once.
          const chunkSize = 5000;
          canvas.width = img.width;
          canvas.height = chunkSize;

          for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
            const datasetBytesView = new Float32Array(
              datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
              IMAGE_SIZE * chunkSize);
            ctx.drawImage(
              img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
              chunkSize);

            const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

            for (let j = 0; j < imageData.data.length / 4; j++) {
              // All channels hold an equal value since the image is grayscale, so
              // just read the red channel.
              datasetBytesView[j] = imageData.data[j * 4] / 255;
            }
          }
          this.datasetImages = new Float32Array(datasetBytesBuffer);

          resolve();
        } catch (err) {
          reject(err);
        }
      };
      img.src = MNIST_IMAGES_SPRITE_PATH;
    });

    const labelsRequest = fetch(MNIST_LABELS_PATH);
    const [, labelsResponse] = await Promise.all([imgRequest, labelsRequest]);
    if (!labelsResponse.ok) {
      throw new Error(
        `Failed to fetch MNIST labels: ${labelsResponse.status} ${labelsResponse.statusText}`
      );
    }

    this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

    // Create shuffled indices into the train/test set for when we select a
    // random dataset element for training / validation.
    this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
    this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);

    // Partition the images and labels into train and test sets. These are
    // views onto the dataset arrays rather than copies, which avoids
    // duplicating the whole dataset in memory.
    this.trainImages =
      this.datasetImages.subarray(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.testImages = this.datasetImages.subarray(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.trainLabels =
      this.datasetLabels.subarray(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
    this.testLabels =
      this.datasetLabels.subarray(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
  }

  /**
   * Returns the next `batchSize` training samples in shuffled order.
   *
   * @param {number} batchSize
   * @returns {{xs: Object, labels: Object}} Image and label tensors.
   */
  nextTrainBatch(batchSize) {
    return this.nextBatch(
      batchSize, [this.trainImages, this.trainLabels], () => {
        this.shuffledTrainIndex =
          (this.shuffledTrainIndex + 1) % this.trainIndices.length;
        return this.trainIndices[this.shuffledTrainIndex];
      });
  }

  /**
   * Returns the next `batchSize` test samples in shuffled order.
   *
   * @param {number} batchSize
   * @returns {{xs: Object, labels: Object}} Image and label tensors.
   */
  nextTestBatch(batchSize) {
    return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => {
      this.shuffledTestIndex =
        (this.shuffledTestIndex + 1) % this.testIndices.length;
      return this.testIndices[this.shuffledTestIndex];
    });
  }

  /**
   * Assembles a batch by copying the selected images and labels into
   * contiguous arrays and wrapping them in 2D tensors.
   *
   * @param {number} batchSize - Number of samples in the batch.
   * @param {Array} data - `[images, labels]` typed arrays to read from.
   * @param {function(): number} index - Returns the sample index to use next.
   * @returns {{xs: Object, labels: Object}} Tensors of shape
   *   `[batchSize, IMAGE_SIZE]` and `[batchSize, NUM_CLASSES]`.
   */
  nextBatch(batchSize, data, index) {
    const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
    const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);

    for (let i = 0; i < batchSize; i++) {
      const idx = index();

      // set() copies the values, so a view of the source is sufficient.
      const image =
        data[0].subarray(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
      batchImagesArray.set(image, i * IMAGE_SIZE);

      const label =
        data[1].subarray(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
      batchLabelsArray.set(label, i * NUM_CLASSES);
    }

    const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
    const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);

    return {xs, labels};
  }
}
Insert cell
// Sample batch of two test examples for inspection.
B = data.nextTestBatch(2)
Insert cell
// Pass the sample batch's label tensor to pt (imported from @tmcw/pt); presumably this displays the tensor — confirm.
pt(B.labels)
Insert cell
Insert cell
import {pt} from "@tmcw/pt"
Insert cell
tf = require('@tensorflow/tfjs@3.9.0')
Insert cell

Purpose-built for displays of data

Observable is your go-to platform for exploring data and creating expressive data visualizations. Use reactive JavaScript notebooks for prototyping and a collaborative canvas for visual data exploration and dashboard creation.
Learn more