Published
Edited
Feb 24, 2018
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
model_predictions = {
++modelTracker.count
const numObs = inputData.xs.size,
ymin = inputData.ys.min(),
ymax = inputData.ys.max(),
yrange = ymax.sub(ymin),
ynormalized = inputData.ys.sub(ymin).div(yrange);

// We start by initializing weights and biases for our two layers:
// layer from 1d input to hidden layer
const w1_init_var = 2/(1 + hidden_size);
const W1 = dl.variable(dl.randomNormal([1,hidden_size], 0, Math.sqrt(w1_init_var)));
const b1 = dl.variable(dl.zeros([1,hidden_size]));
// layer from hidden layer to 1d output.
const w2_init_var = 2/(hidden_size + 1)
const W2 = dl.variable(dl.randomNormal([hidden_size, 1], 0, Math.sqrt(w2_init_var)));
const b2 = dl.variable(dl.zeros([1,1]));

// Next, we define out model. It takes a scalar input (because out input is 1d)
// and spits out a scalar output.
const netModel = x => dl.tidy(() => {
const first_layer = act_func(x.mul(W1).add(b1))
const second_layer = first_layer.matMul(W2).add(b2)
return second_layer.asScalar()
})

// The objective function: mean squared error between prediction and actual.
const loss = (preds, label) => preds.sub(label).square().mean()
// set our optimizer based on cell above
let optimizer;
switch(optimizer_type){
case 'sgd':
optimizer = dl.train.momentum(learningRate, momentum);
break;
case 'adam':
optimizer = dl.train.adam();
break;
case 'adamax':
optimizer = dl.train.adamax();
break;
}

// Train the model.
async function train(xs, ys, iterations){
for(let i=0; i<numObs; i++){
optimizer.minimize(() => loss(
netModel(extractObs(xs, i)),
extractObs(ys, i)
));
}
}
const runNumAtStart = modelTracker.count;
for (let iter = 0; iter < iterations; iter++) {
if(runNumAtStart !== modelTracker.count) break; // checks to see if we have reinstantiated new run

await train(inputData.xs, ynormalized, iterations);
await dl.nextFrame();
yield {iter, preds: []}
console.log(`Epoch ${iter} | Run ${modelTracker.count}`)
}
const modelOutput = [];
for(let i=0; i<numObs; i++){
modelOutput.push(netModel(extractObs(inputData.xs, i)).mul(yrange).add(ymin).dataSync()[0]);
}
yield {iter: iterations, preds:modelOutput}
try {
yield invalidation;
} finally {
W1.dispose(); b1.dispose();
W2.dispose; b2.dispose();
}
}
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell

Purpose-built for displays of data

Observable is your go-to platform for exploring data and creating expressive data visualizations. Use reactive JavaScript notebooks for prototyping and a collaborative canvas for visual data exploration and dashboard creation.
Learn more