model = {
  const hiddenNodes = 64
  // Set up all our weights layer by layer, with Xavier-style initialization
  // (variance 2 / (fan_in + fan_out)) for each weight matrix.
  const w1_init_var = 2 / (1 + hiddenNodes);
  const W1 = dl.variable(dl.randomNormal([1, hiddenNodes], 0, Math.sqrt(w1_init_var)));
  const b1 = dl.variable(dl.zeros([1, hiddenNodes]));
  const w2_init_var = 2 / (hiddenNodes + hiddenNodes);
  const W2 = dl.variable(dl.randomNormal([hiddenNodes, hiddenNodes], 0, Math.sqrt(w2_init_var)));
  const b2 = dl.variable(dl.zeros([1, hiddenNodes]));
  const w3_init_var = 2 / (hiddenNodes + 1);
  const W3 = dl.variable(dl.randomNormal([hiddenNodes, 1], 0, Math.sqrt(w3_init_var)));
  const b3 = dl.variable(dl.zeros([1, 1]));

  // The model itself: two tanh hidden layers followed by a linear output.
  const model = xs => dl.tidy(() =>
    xs.mul(W1).add(b1).tanh()
      .matMul(W2).add(b2).tanh()
      .matMul(W3).add(b3).asScalar()
  )
  // Mean squared error loss.
  const loss = (ypred, y) => ypred.sub(y).square().mean()
  const optimizer = dl.train.adam();

  // Runs one optimizer step per observation in the passed batch.
  async function train_on_batch(x, y){
    const size_of_batch = x.shape[0];
    for(let i = 0; i < size_of_batch; i++){
      optimizer.minimize(() => loss(
        model(extractObs(x, i)),
        extractObs(y, i)
      ));
    }
  }

  // Runs the model on every observation in x and returns plain numbers.
  async function predict(x){
    const size_of_batch = x.shape[0];
    const prediction = new Array(size_of_batch);
    for(let i = 0; i < size_of_batch; i++){
      prediction[i] = (await model(extractObs(x, i)).data())[0]
    }
    return prediction
  }

  return {model, train_on_batch, predict, weights: {W1, b1, W2, b2, W3, b3}}
}
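The cell above leans on an extractObs helper defined in a cell that isn't shown here. As a rough sketch only (an assumption about its behavior, not the notebook's actual definition), it presumably slices the i-th observation out of a 1-D batch tensor and shapes it so it can flow through the dense layers:

// Hypothetical sketch of the hidden extractObs helper (assumption):
// pull observation i from a 1-D batch tensor and reshape it to [1, 1]
// so it broadcasts cleanly through the fully connected layers.
extractObs_sketch = (batch, i) => dl.tidy(() =>
  batch.slice(i, 1).reshape([1, 1])
)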
gen_task = () => {
  // Each task is a sine wave with a randomly sampled phase and amplitude.
  const phase = dl.scalar(d3.randomUniform(0, 2 * Math.PI)())
  const amplitude = dl.scalar(d3.randomUniform(0.1, 5)())
  return (x) => x.add(phase).sin().mul(amplitude)
}
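To see what a sampled task looks like, here is a quick usage sketch (illustrative only, not a cell from the original notebook): each call to gen_task() returns a closure mapping a tensor of x-values to amplitude * sin(x + phase).

// Illustrative usage only: sample one random sine task and evaluate it.
example_task_values = {
  const f = gen_task()                    // random phase and amplitude
  const ys = f(dl.tensor1d([-1, 0, 1]))   // evaluate the task at a few points
  return Array.from(ys.dataSync())        // plain numbers for inspection
}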
x_all = linspace(-5, 5, 50)
f_plot = gen_task()
xtrain_plot = sample(x_all.dataSync(), ntrain)
ytrain_plot = f_plot(xtrain_plot)
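xtrain_plot depends on sample and ntrain, both defined in cells not shown above. A hedged guess at what sample might do (draw ntrain random x-values and hand them back as a tensor so f_plot can consume them) could look like:

// Hypothetical sketch of the hidden `sample` helper (assumption): shuffle the
// candidate x-values and keep the first n, wrapped back up as a tensor.
sample_sketch = (values, n) =>
  dl.tensor1d(d3.shuffle(Array.from(values)).slice(0, n))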
plot_data = {
  const weightnames = ['W1', 'b1', 'W2', 'b2', 'W3', 'b3'];

  for(let iteration = 2; iteration < niterations + 1; iteration++){
    // snapshot the meta-weights before adapting to the sampled task
    const weights_before = {};
    weightnames.forEach(name => {
      weights_before[name] = model.weights[name].clone()
    });

    // generate a task
    const f = gen_task()
    const y_all = f(x_all)

    // do SGD on the current task
    for(let i = 0; i < innerepochs; i++){
      await model.train_on_batch(x_all, y_all)
    }

    // compute the (linearly annealed) outer step size for this iteration
    const currentstepsize = outerstepsize * (1 - iteration / niterations);

    const weights_after = model.weights;

    // Reptile meta-update: move each meta-weight a fraction of the way
    // toward the task-adapted weights.
    weightnames.forEach(name => {
      model.weights[name] = dl.variable(
        weights_before[name].add(
          weights_after[name].sub(weights_before[name]).mul(dl.scalar(currentstepsize))
        )
      )
    })

    // every few iterations, try fine-tuning the model on the few-shot task
    if(iteration % updateFrequency === 0){
      // save a snapshot of the weights again before training on the smaller data
      const weights_before_2 = {};
      weightnames.forEach(name => {
        weights_before_2[name] = model.weights[name].clone()
      })

      let tuning_preds;
      let preds_by_iteration = [];
      // inner tuning iteration loop
      for(let j = 0; j < 33; j++){
        if(j > 0){ // skip training on the 0th iteration so the untuned fit is recorded
          await model.train_on_batch(xtrain_plot, ytrain_plot)
        }
        if(j % 16 === 0){ // record predictions at tuning steps 0, 16, and 32
          tuning_preds = (await model.predict(x_all)).map((d, i) => ({
            x: x_all_plain[i],
            pred: d,
            iteration: j,
          }));

          preds_by_iteration = [...preds_by_iteration, ...tuning_preds]
        }
      }

      // restore the meta-weights so fine-tuning doesn't leak into meta-training
      weightnames.forEach(name => {
        model.weights[name] = dl.variable(weights_before_2[name])
      })

      yield {iteration, predictions: preds_by_iteration}
    }
  }
}
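The outer loop above is the Reptile meta-update: after a few epochs of SGD on a freshly sampled sine task, the meta-weights are nudged a fraction of the way toward the task-adapted weights, and that fraction is annealed linearly toward zero over the run. Written out for a single weight tensor (an illustrative restatement of the update, not a cell the notebook needs):

// Reptile outer update for one tensor: phi <- phi + stepsize * (W_adapted - phi),
// where stepsize = outerstepsize * (1 - iteration / niterations).
reptile_update = (phi, wAdapted, stepsize) =>
  phi.add(wAdapted.sub(phi).mul(dl.scalar(stepsize)))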
// Evenly spaced grid helper: d3.range excludes `stop`, so this yields
// `steps` points covering [start, stop).
linspace = (start, stop, steps) => dl.tensor(d3.range(start, stop, (stop - start) / steps))