Public
Edited
Jul 7
1 star
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
tf = require("@tensorflow/tfjs@4/dist/tf.min.js")
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
mnistData = {
const mnistData = await (
await (
await FileAttachment("mnist_train_test_data@1.json.zip").zip()
).file("mnist_train_test_data.json")
).json();
mnistData.train.X = mnistData.train.X.slice(0, train_size);
mnistData.train.y = mnistData.train.y.slice(0, train_size);
mnistData.test.X = mnistData.test.X.slice(0, test_size);
mnistData.test.y = mnistData.test.y.slice(0, test_size);
return mnistData;
}
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
// Number of training samples kept from the attachment.
// Should not exceed 65,000
train_size = 1_000
Insert cell
// Number of test samples kept from the attachment.
// Should not exceed 5,000
test_size = 1_000
Insert cell
Insert cell
// Training features as a rank-2 tensor: one 784-value row per sample (the model
// reshapes each row to 28x28x1).
XTrain = tf.tensor2d(mnistData.train.X, [train_size, 784])
Insert cell
// One-hot training labels, shape [samples, 10], for the categorical-crossentropy loss.
// tf.tidy disposes the intermediate int32 label tensor and keeps only the returned
// one-hot tensor, so re-running this cell does not leak backend memory.
yTrain = tf.tidy(() => tf.oneHot(tf.tensor1d(mnistData.train.y, "int32"), 10))
Insert cell
// Test features as a rank-2 tensor: one 784-value row per sample, fed to model.predict.
XTest = tf.tensor2d(mnistData.test.X, [test_size, 784])
Insert cell
// yTest is not needed: the evaluation cells (`bad` and the accuracy cells)
// compare predicted_digits directly against the plain label array
// mnistData.test.y, so no one-hot test tensor is ever consumed.
// yTest = tf.oneHot(tf.tensor1d(mnistData.test.y, "int32"), 10)
Insert cell
Insert cell
Insert cell
Insert cell
model = {
const model = tf.sequential();

model.add(tf.layers.reshape({ targetShape: [28, 28, 1], inputShape: [784] }));

model.add(
tf.layers.conv2d({
kernelSize: 3,
filters: 16,
activation: "relu",
kernelInitializer: "varianceScaling"
})
);

model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));

model.add(
tf.layers.conv2d({
kernelSize: 3,
filters: 32,
activation: "relu",
kernelInitializer: "varianceScaling"
})
);

model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));

model.add(tf.layers.flatten());

model.add(
tf.layers.dense({
units: 64,
activation: "relu",
kernelInitializer: "varianceScaling"
})
);

// Optional: dropout to reduce overfitting
model.add(tf.layers.dropout({ rate: 0.25 }));

model.add(
tf.layers.dense({
units: 10,
activation: "softmax",
kernelInitializer: "varianceScaling"
})
);

model.compile({
optimizer: tf.train.adam(0.0005),
loss: "categoricalCrossentropy",
metrics: ["accuracy"]
});

return model;
}
Insert cell
Insert cell
Insert cell
{
if (fit_it) {
const start = new Date();
await model.fit(XTrain, yTrain, {
epochs: 5,
batchSize: 64,
shuffle: true,
validationSplit: 0.1,
callbacks: {
onEpochEnd: (epoch, logs) => {
console.log(
`Epoch ${epoch + 1}: loss = ${logs.loss.toFixed(
4
)}, acc = ${logs.acc.toFixed(4)}`
);
}
}
});
const finish = new Date();
return `Model fit in ${(finish - start) / 1000} seconds.`;
} else {
return null;
}
}
Insert cell
Insert cell
Insert cell
Insert cell
predictions = evaluate ? model.predict(XTest) : null
Insert cell
incorrect_digits = evaluate
? mnistData.test.X.filter((y, i) => bad.indexOf(i) > -1)
: null
Insert cell
// Test accuracy derived from the misclassified images. While evaluation is off,
// incorrect_digits is null, so return null instead of throwing on `.length`.
evaluate ? 1 - incorrect_digits.length / test_size : null
Insert cell
bad = evaluate
? d3
.zip(predicted_digits, mnistData.test.y)
.map((a, i) => [i, a])
.filter((A) => A[1][0] != A[1][1])
.map((A) => A[0])
: null
Insert cell
evaluate
? d3.zip(predicted_digits, mnistData.test.y).filter((a) => a[0] == a[1])
.length / test_size
: null
Insert cell
predicted_digits = evaluate ? predictions.argMax(-1).data() : null
Insert cell
Insert cell
Insert cell
{
  // Saves the current model through the browser download handler when the
  // `download` control (defined elsewhere; presumably a click counter) is above 0.
  if (download > 0) {
    // model.save() is asynchronous. Awaiting it makes a failed save surface as an
    // error on this cell instead of an unhandled promise rejection.
    await model.save("downloads://cnn-model");
  }
}
Insert cell
Insert cell
// Alternative, larger CNN kept for reference: three conv blocks with batch
// normalization and "same" padding, a 128-unit dense layer, and 0.4 dropout.
// To try it, uncomment it AND comment out the active `model` cell above;
// Observable rejects two cells that define the same name.
// model = {
// const model = tf.sequential();

// // Reshape input from [784] to [28, 28, 1]
// model.add(tf.layers.reshape({ targetShape: [28, 28, 1], inputShape: [784] }));

// // First conv block
// model.add(tf.layers.conv2d({
// kernelSize: 3,
// filters: 32,
// activation: "relu",
// kernelInitializer: "varianceScaling",
// padding: "same"
// }));
// model.add(tf.layers.batchNormalization());
// model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));

// // Second conv block
// model.add(tf.layers.conv2d({
// kernelSize: 3,
// filters: 64,
// activation: "relu",
// kernelInitializer: "varianceScaling",
// padding: "same"
// }));
// model.add(tf.layers.batchNormalization());
// model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));

// // Third conv block (more capacity)
// model.add(tf.layers.conv2d({
// kernelSize: 3,
// filters: 128,
// activation: "relu",
// kernelInitializer: "varianceScaling",
// padding: "same"
// }));
// model.add(tf.layers.batchNormalization());
// model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));

// // Flatten and dense layers
// model.add(tf.layers.flatten());

// model.add(tf.layers.dropout({ rate: 0.4 })); // dropout after conv stack

// model.add(tf.layers.dense({
// units: 128,
// activation: "relu",
// kernelInitializer: "varianceScaling"
// }));
// model.add(tf.layers.batchNormalization());
// model.add(tf.layers.dropout({ rate: 0.4 })); // dropout after dense

// // Output layer
// model.add(tf.layers.dense({
// units: 10,
// activation: "softmax",
// kernelInitializer: "varianceScaling"
// }));

// // Compile the model
// model.compile({
// optimizer: tf.train.adam(0.0005),
// loss: "categoricalCrossentropy",
// metrics: ["accuracy"]
// });

// return model;
// }

Insert cell

Purpose-built for displays of data

Observable is your go-to platform for exploring data and creating expressive data visualizations. Use reactive JavaScript notebooks for prototyping and a collaborative canvas for visual data exploration and dashboard creation.
Learn more