Public
Edited
Feb 24, 2024
1 star
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
plot_data = {
const finished = model.history.length === model.max_steps;
const h = finished ? 0.1 : 0.25;
const xMin = d3.min(X, (i) => i[0]) - 1;
const xMax = d3.max(X, (i) => i[0]) + 1;
const yMin = d3.min(X, (i) => i[1]) - 1;
const yMax = d3.max(X, (i) => i[1]) + 1;

const xx = d3.range(xMin, xMax, h);
const yy = d3.range(yMin, yMax, h);

const Xmesh = [];
for (let i = 0; i < xx.length; i++) {
for (let j = 0; j < yy.length; j++) {
Xmesh.push([xx[i], yy[j]]);
}
}

const inputs = Xmesh.map((xrow) => xrow.map((x) => new Value(x)));
const scores = inputs.map((input) => model.model.call(input)[0]);
const Z = scores.map((s) => (s.data > 0 ? 1 : -1));

const decision_boundary = Xmesh.map((point, i) => ({
x: point[0],
y: point[1],
color: scores[i].data
// color: Z[i]
}));
const original_data = X.map((point, i) => ({
x: point[0],
y: point[1],
color: y[i]
}));

return { step: h / 1.9, original_data, decision_boundary };
}
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
/**
 * Computes the SVM "max-margin" loss plus L2 regularization, and the
 * classification accuracy, of `model` on (X, y) — optionally on a random
 * mini-batch.
 *
 * @param model - object exposing `call(inputs)` (its first returned element is
 *   taken as the score Value) and `parameters()` (array of Values).
 * @param X - array of feature rows (arrays of numbers).
 * @param y - labels aligned with X; a prediction is counted correct when the
 *   label's sign matches the score's sign (presumably labels are ±1 — confirm).
 * @param batchSize - null (default) to use every row; otherwise that many rows
 *   are sampled at random without replacement.
 * @param alpha - L2 regularization strength (default 1e-4).
 * @returns {{scores, losses, dataLoss, regLoss, totalLoss, accuracy}} where
 *   `accuracy` is the fraction of correctly classified rows in the batch.
 */
function loss(model, X, y, batchSize = null, alpha = 1e-4) {
  let Xb;
  let yb;
  // inline DataLoader
  if (batchSize === null) {
    Xb = X;
    yb = y;
  } else {
    // Shuffle a copy of the row indexes and keep the first `batchSize`.
    const ri = d3.shuffle(Array.from(X.keys())).slice(0, batchSize);
    Xb = ri.map((i) => X[i]);
    yb = ri.map((i) => y[i]);
  }
  const inputs = Xb.map((xrow) => xrow.map((x) => new Value(x)));

  // forward the model to get scores
  const scores = inputs.map((input) => model.call(input)[0]);

  // svm "max-margin" loss: relu(1 - y_i * score_i), averaged over the batch
  const losses = yb.map((yi, i) =>
    new Value(1).sub(new Value(yi).mul(scores[i])).relu()
  );
  const dataLoss = losses
    .reduce((acc, l) => acc.add(l), new Value(0))
    .div(new Value(losses.length));

  // L2 regularization: alpha * sum(p^2) over all parameters
  const regLoss = new Value(alpha).mul(
    model.parameters().reduce((acc, p) => acc.add(p.mul(p)), new Value(0))
  );
  const totalLoss = dataLoss.add(regLoss);

  // also get accuracy: a row is correct when label and score share a sign
  const accuracy = yb.map((yi, i) => (yi > 0) === (scores[i].data > 0));

  return {
    scores,
    losses,
    dataLoss,
    regLoss,
    totalLoss,
    accuracy: accuracy.filter(Boolean).length / accuracy.length
  };
}
Insert cell
model = {
retrain_button;

const max_steps = 100;
const history = [];

// default config in micrograd is [16, 16, 1]: 2 internal layers + 1 result
const model = new MLP(2, [16, 16, 1]);
const r = [];

const progress = { model, max_steps, history, msg: "..." };

// optimization
for (let step = 0; step < max_steps; step++) {
// forward
let { totalLoss, accuracy } = loss(model, X, y);

// backward
model.zeroGrad();
totalLoss.backward();

// update (sgd)
let learningRate = 1.0 - (0.9 * step) / 100;
model.parameters().forEach((p) => {
p.data -= learningRate * p.grad;
});

history.push({ step, loss: totalLoss.data, accuracy });

if (step % 1 === 0) {
const msg = `step ${step}/${max_steps} loss ${totalLoss.data.toFixed(
3
)}, accuracy ${accuracy * 100}%`;
r.push(msg);
progress.msg = msg;
yield progress;
}
}
return progress;
}
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell

Purpose-built for displays of data

Observable is your go-to platform for exploring data and creating expressive data visualizations. Use reactive JavaScript notebooks for prototyping and a collaborative canvas for visual data exploration and dashboard creation.
Learn more