model_predictions = {
++modelTracker.count
if (!inputData) return {preds:[], losses: []}
const numObs = inputData.inputs.shape[0];
const w1_init_var = 2/(2 + hiddenNodes);
const W1 = dl.variable(dl.randomNormal([2,hiddenNodes], 0, Math.sqrt(w1_init_var)));
const b1 = dl.variable(dl.zeros([1,hiddenNodes]));
const w2_init_var = 2/(hiddenNodes + 2);
const W2 = dl.variable(dl.randomNormal([hiddenNodes,1], 0, Math.sqrt(w2_init_var)));
const b2 = dl.variable(dl.zeros([1,1]));
let hidden_act_func;
switch(activation_type){
case 'relu':
hidden_act_func = dl.relu;
break;
case 'sigmoid':
hidden_act_func = dl.sigmoid;
break;
case 'tanh':
hidden_act_func = dl.tanh;
break;
}
const classifier = xs => dl.tidy(() => {
const first_layer = hidden_act_func(xs.matMul(W1).add(b1))
const second_layer = first_layer.matMul(W2).add(b2).sigmoid()
return second_layer.asScalar()
})
const one = dl.scalar(1);
const negativeOne = dl.scalar(-1);
const loss = (prediction, truth) => dl.tidy(() => truth.mul(dl.log(prediction))
.add(
one.sub(truth).mul(dl.log(one.sub(prediction)))
).mul(negativeOne))
const optimizer = dl.train.adam();
async function train(inputs, classes){
return dl.tidy(()=>{
let loss_sum = dl.scalar(0);
for(let i=0; i<numObs; i++){
loss_sum = loss_sum.add(
optimizer.minimize(() => loss(
classifier(twoDSlice(inputs, i)),
oneDSlice(classes, i)
).asScalar(), true)
)
}
return loss_sum.data()
})
}
const losses = [],
runNumAtStart = modelTracker.count,
modelOutput = new Array(gridRes*gridRes);
let currentPred, currentLoss, xInp, yInp;
for (let iter = 0; iter < numberEpochs; iter++) {
if(runNumAtStart !== modelTracker.count) break;
currentLoss = await train(inputData.inputs, inputData.groups);
losses.push( currentLoss[0]);
await dl.nextFrame();
for(let i=0; i<gridRes; i++){
xInp = xside[i];
for(let j=0; j<gridRes; j++){
yInp = yside[j];
currentPred = await dl.tidy(() => classifier(dl.tensor([xInp, yInp], [1,2])).data())
modelOutput[j*gridRes + i] = currentPred[0]
}
}
yield {preds:modelOutput, losses}
}
W1.dispose(); b1.dispose();
W2.dispose(); b2.dispose();
one.dispose(); negativeOne.dispose();
}