model = {
const hiddenNodes = 64
const w1_init_var = 2/(1 + hiddenNodes);
const W1 = dl.variable(dl.randomNormal([1,hiddenNodes], 0, Math.sqrt(w1_init_var)));
const b1 = dl.variable(dl.zeros([1,hiddenNodes]));
const w2_init_var = 2/(hiddenNodes + 1);
const W2 = dl.variable(dl.randomNormal([hiddenNodes,hiddenNodes], 0, Math.sqrt(w2_init_var)));
const b2 = dl.variable(dl.zeros([1,1]));
const W3 = dl.variable(dl.randomNormal([hiddenNodes,1], 0, Math.sqrt(w2_init_var)));
const b3 = dl.variable(dl.zeros([1,1]));
const model = xs => dl.tidy(() =>
xs.mul(W1).add(b1).tanh()
.matMul(W2).add(b2).tanh()
.matMul(W3).add(b3).asScalar()
)
const loss = (ypred, y) => ypred.sub(y).square().mean()
const optimizer = dl.train.adam();
async function train_on_batch(x, y){
const size_of_batch = x.shape[0];
for(let i=0; i<size_of_batch; i++){
optimizer.minimize(() => loss(
model(extractObs(x, i)),
extractObs(y, i)
));
}
}
async function predict(x){
const size_of_batch = x.shape[0];
const prediction = new Array(size_of_batch);
for(let i=0; i<size_of_batch; i++){
prediction[i] = (await model(extractObs(x, i)).data())[0]
}
return prediction
}
return {model, train_on_batch, predict, weights:{W1, b1, W2, b2, W3, b3}}
}