Published
Edited
Feb 9, 2018
2 stars
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
setup_rbm = (input_size, hidden_size) => {
  // Build a fresh restricted Boltzmann machine: a visible layer, a hidden
  // layer, and the edge weights that connect them. The two layers are created
  // before the weights (same order as before) in case the initializers draw
  // from a shared random source.
  const visible = init_layer(input_size);
  const hidden = init_layer(hidden_size);
  return {
    weights: init_edge_weights(input_size, hidden_size),
    visible_nodes: visible,
    hidden_nodes: hidden,
  };
}
Insert cell
Insert cell
Insert cell
Insert cell
/**
 * Propagate `input` into one layer of the RBM and compute, for every node of
 * that destination layer, its energy, activation probability and a sampled
 * activation.
 *
 * @param {object} rbm - machine with `weights`, `visible_nodes` and
 *   `hidden_nodes`; only the destination layer's length is used.
 * @param {string} dir - 'in' propagates visible -> hidden; any other value
 *   propagates hidden -> visible.
 * @param {number[]} input - activations of the source layer, without a bias
 *   entry (a constant 1 is appended here).
 * @returns {{energy: number, p_act: number, activation: *}[]} one entry per
 *   destination node; `activation` is whatever `bernouli(p_act)` returns.
 */
function calc_energy_acts(rbm, dir, input){
  const inbound = dir === 'in';
  const target_layer = inbound ? rbm.hidden_nodes : rbm.visible_nodes;
  return target_layer.map((_node, idx) => {
    // Collect, in weight-list order, every edge value attached to this node.
    const node_weights = [];
    for (const w of rbm.weights) {
      const endpoint = inbound ? w.hidden : w.visible;
      if (endpoint === idx) node_weights.push(w.val);
    }
    // The trailing 1 pairs with the bias edge; assumes the bias edge is the
    // last matching entry in rbm.weights — TODO confirm against init_edge_weights.
    const energy = dot_prod(node_weights, [...input, 1]);
    const p_act = logistic(energy);
    return {energy, p_act, activation: bernouli(p_act)};
  });
}

Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
{
// extract activations as plain vector
const hidden_acts = updated_rbm.hidden_nodes.map(d => d.activation)
// get the new visible nodes based on mapping function
const reconstructed_visible = calc_energy_acts(updated_rbm, 'out', hidden_acts);
// reassemble RBM with new visible nodes
const reconstructed_rbm = Object.assign({}, updated_rbm, {visible_nodes: reconstructed_visible})
// plot it
return plot_rbm({rbm: reconstructed_rbm, DOM})
}
Insert cell
Insert cell
pass_observation = (obs, rbm) => {
  // Up-pass: observation -> hidden layer energies/activations.
  const hidden = calc_energy_acts(rbm, 'in', obs);
  // Down-pass: only the sampled hidden activations (not their probabilities)
  // are fed back to reconstruct the visible layer.
  const sampled = hidden.map(({activation}) => activation);
  const visible = calc_energy_acts(rbm, 'out', sampled);
  return {input: obs, hidden, visible};
}
Insert cell
calc_agreements = (weights, pass_result) => {
  // For each edge, measure how strongly its two endpoints were co-active:
  //   encode_agreement = observed visible value * hidden activation
  //   decode_agreement = hidden activation * reconstructed visible activation
  // Bias endpoints are treated as permanently on (value 1).
  // NOTE(review): both terms use the same hidden sample from the up-pass;
  // textbook CD-1 recomputes hidden values from the reconstruction for the
  // decode term — confirm this simplification is intended.
  const {input, hidden, visible} = pass_result;
  return weights.map(edge => {
    const visible_is_bias = edge.visible === 'bias';
    const in_val = visible_is_bias ? 1 : input[edge.visible];
    const h_val = edge.hidden === 'bias' ? 1 : hidden[edge.hidden].activation;
    const out_val = visible_is_bias ? 1 : visible[edge.visible].activation;
    return {
      ...edge,
      encode_agreement: in_val*h_val,
      decode_agreement: h_val*out_val,
    };
  });
}
Insert cell
update_weights = (weight_agreements) => weight_agreements.map(
  // Nudge each weight by learning_rate times (encode - decode) agreement,
  // keeping the previous value under `old` and dropping the agreement fields.
  ({visible, hidden, val, encode_agreement, decode_agreement}) => ({
    visible,
    hidden,
    old: val,
    val: val + learning_rate*(encode_agreement - decode_agreement),
  })
)
Insert cell
Insert cell
// Reconstruction error: the sum over the visible layer of the squared
// difference between each node's activation and the matching input value.
function prediction_error(input, visible_layer){
  let total = 0;
  visible_layer.forEach((node, idx) => {
    total += square(node.activation - input[idx]);
  });
  return total;
}
Insert cell
Insert cell
/**
 * Run one observation through the RBM and compute its updated edge weights.
 *
 * @param {object} rbm - machine with `weights`, `visible_nodes`, `hidden_nodes`.
 * @param {number[]} input - a single observation for the visible layer.
 * @returns {{new_weights: object[], error: number}} the weights after one
 *   update step, and this observation's reconstruction error (computed
 *   before the update).
 */
function pass_and_update(rbm, input){
  // pass input through the current network and get back activation results
  const pass_result = pass_observation(input, rbm);
  // reconstruction error of this observation under the current weights
  const error = prediction_error(input, pass_result.visible);
  // for each edge, measure how well its endpoints agreed with each other;
  // calc_agreements takes (weights, pass_result) — the input already lives
  // in pass_result.input, so it is not passed separately
  const weight_agreements = calc_agreements(rbm.weights, pass_result);
  // update weights
  const new_weights = update_weights(weight_agreements);
  return {new_weights, error};
}
Insert cell
Insert cell
training_errors_rbm = {
const input_size = sample_data[0].length,
num_obs = sample_data.length,
errors = [];
let rbm = setup_rbm(input_size, number_hidden_nodes),
epoch = 0,
epoch_error,
input,
weights_and_error;
while (epoch < epochs){
epoch_error = 0;
// feed all training examples through and update at each minibatch
for(let i=0; i<num_obs; i++){
// grab training data obs
input = sample_data[i];
// pass input through boltzmann machine and get error/new weights
weights_and_error = pass_and_update(rbm, input);
// add our error to the epochs error total
epoch_error += weights_and_error.error/num_obs;
//update the weights with new weights
rbm.weights = weights_and_error.new_weights;
}
// append this epochs average error to the array keeping track
errors.push(epoch_error);
// increment the epochs forward
epoch++
yield {errors,rbm}
}
}
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell

Purpose-built for displays of data

Observable is your go-to platform for exploring data and creating expressive data visualizations. Use reactive JavaScript notebooks for prototyping and a collaborative canvas for visual data exploration and dashboard creation.
Learn more