Published
Edited
Feb 11, 2020
1 fork
Importers
1 star
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
sinReward = function (s) {
  // Reward is the element-wise sine of the sampled state tensor `s`.
  return tf.sin(s);
}
Insert cell
trainer = {
const runEnv = (policy, reward, particles) => {
let s = policy.sample(particles);
let deviation = s.sub(policy.mean); // the amount of deviation from the mean
// return [s, deviation, tf.tensor(s.dataSync().map(s => reward.js(s)))];
return [s, deviation, reward.tf(s)];
}

var trainer = async(runInfo, callback) => {
const optimizer = runInfo.optimizer(runInfo.learningRate);
const policy = runInfo.policy;
const reward = runInfo.reward;
const base = tf.variable(tf.tensor1d([0]));
const beta = tf.scalar(runInfo.baselineDecayRate); // how fast does the baseline change
for (let i = 0; i < runInfo.nbOptSteps; i++) {
var [states, rewards] = tf.tidy(() => {
/// sample actions from the environment which is just a simple 1D function in our case
/// ** note that the actions are taken w.r.t. current policy **
const [states, deviation, rewards] = runEnv(policy, runInfo.reward, runInfo.nbParticles);
/// train the policy
optimizer.minimize(
() => runInfo.loss.loss(policy, reward.tf, states, rewards, deviation, base),
false, /// returnCost? no, we're not using it
policy.getTFParams /// varList? only the policy parameters and not the baseline param(s)
);
if( runInfo.useBaseline ) {
/// train the baseline
base.assign(base.mul(beta.sub(1).neg()).add(rewards.mean().mul(beta)));
}
return [states, rewards];
});
/// train loss (if it has function approximators)
if( runInfo.loss.fit ) {
await runInfo.loss.fit(states, rewards);
}
/// a (sort of hacky) way to stop the training before it has finished (e.g. the stop button)
if( runInfo.runId != currentRunId.value )
break;

/// call the `callback` function if present: can be used to slow down training or step it manually
if( typeof(callback) == 'function' ) {
await callback(i, states, rewards, policy, base);
}
/// freeing-up memory
//states.dispose();
//rewards.dispose();
}
/// freeing-up memory
base.dispose();
beta.dispose();
}
return trainer
}
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell

One platform to build and deploy the best data apps

Experiment and prototype by building visualizations in live JavaScript notebooks. Collaborate with your team and decide which concepts to build out.
Use Observable Framework to build data apps locally. Use data loaders to build in any language or library, including Python, SQL, and R.
Seamlessly deploy to Observable. Test before you ship, use automatic deploy-on-commit, and ensure your projects are always up-to-date.
Learn more