Published
Edited
Aug 18, 2019
2 stars
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
function image_scales() {
  // Every scale in the table is linear; this helper keeps the table compact.
  const linear = (domain, range) => d3.scaleLinear().domain(domain).range(range);

  return {
    // Image pixel coordinates (0..28) -> screen; the image spans a quarter of the width.
    "x": linear([0, 28], [0, width / 4]),
    "y": linear([0, 28], [0, height]),
    "zoom": linear([0, 1], [0, height / 4]),
    // Pixel intensity -1..1 -> white..black.
    "fill": linear([-1, 1], ["white", "black"]),
    // Encoding vector index (0..256) -> three quarters of the width.
    "x_enc": linear([0, 256], [0, 3 * width / 4]),
    "enc_fill": linear([0, 3], ["white", "#234253"]),
    // These two ranges are inverted: larger values map to smaller screen y.
    "y_hat": linear([0, 1], [height / 4, 0]),
    "classes": linear([0, 10], [0, width / 5]),
    "vtime": linear([0, 5], [height / 4, 0]),
    "phi_time": linear([0, 2], [0, width / 8]),
  };
}
Insert cell
Insert cell
// Scratch check: round 0.31322 to one decimal place.
Math.round(10 * 0.31322) / 10
Insert cell
// Number of glimpses per image: `forwards` returns this many model states.
n_glimpses = 3
Insert cell
chart = {
const svg = d3.select(DOM.svg(width, height));
var scales = image_scales();
var states = forwards(f0, x, math.zeros(2), math.zeros(64), n_glimpses)
var encs = flatten_encs(states)
var phi = flatten_phi(states)
var locs = flatten_mu(states, [14, 14])
svg.selectAll("g")
.data(["image", "phi", "g", "h", "y_hat", "l", "mu_links", "mu"]).enter()
.append("g")
.attr("id", (d) => d);
translate_elems(svg, scales, states[0]["phi"][0].size());
initialize_view(svg, image, encs, phi, scales, locs, update_glimpse)
var old_pos = [0, 0];
function update_glimpse() {
var pos = d3.mouse(this),
pos = [scales.x.invert(pos[0]) / 14 - 1, scales.y.invert(pos[1]) / 14 - 1],
pos = [Math.round(pos[0] * 50) / 50, Math.round(pos[1] * 50) / 50]; // snap to grid
if (pos == old_pos) return;
old_pos = pos

var cur_time = d3.select(this).attr("time");
if (cur_time != -1) return;
states = forwards(f0, x, math.matrix(pos), math.zeros(64), n_glimpses)
encs = flatten_encs(states);
phi = flatten_phi(states);
locs = flatten_mu(states, denormalize(pos, 28))
update_phi(svg, scales, phi)
update_encs(svg, scales, encs)
update_locs(svg, scales, locs)
}
return svg.node()
}
Insert cell
/**
 * Collect the `mu` and `l` locations from a list of model states into
 * per-key lists of `{ time, value }` records, converted with `denormalize`.
 *
 * Both lists start with `{ time: -1, value: mu0 }` for the starting location.
 * The last state in `x_list` is not included: the loop stops at `length - 1`.
 * Presumably this is because that location is never used for a glimpse; confirm.
 *
 * @param {Array<Object>} x_list - model states; each has `mu` and `l` objects exposing `._data`.
 * @param {*} mu0 - starting location, stored as-is (not denormalized).
 * @param {number} [size=28] - size passed to `denormalize`.
 * @returns {{mu: Array<{time: number, value: *}>, l: Array<{time: number, value: *}>}}
 */
function flatten_mu(x_list, mu0, size = 28) {
  const keys = ["mu", "l"];
  const result = {
    "mu": [{ "time": -1, "value": mu0 }],
    "l": [{ "time": -1, "value": mu0 }],
  };
  for (let time = 0; time < x_list.length - 1; time++) {
    for (const key of keys) {
      result[key].push({ "time": time, "value": denormalize(x_list[time][key]._data, size) });
    }
  }
  return result;
}
Insert cell
/**
 * Create the initial SVG elements for every layer of the view.
 *
 * Each layer is drawn into the `<g>` whose id matches its selector:
 * image pixels, phi cells, glimpse encodings (#g), hidden states (#h),
 * the `l` and `mu` locations, and the polyline linking the `mu` points.
 *
 * @param svg - d3 selection of the root SVG.
 * @param image - pixel records `{ i, j, value }`.
 * @param encs - `{ g, h }` arrays of cell records `{ i, time, value }`.
 * @param phi - phi cell records `{ i, j, time, zoom, value }`.
 * @param scales - the scale table (see `image_scales`).
 * @param locs - `{ l, mu }` arrays of `{ time, value: [x, y] }` records.
 * @param update_fun - drag handler attached to the `mu` circles.
 */
function initialize_view(svg, image, encs, phi, scales, locs, update_fun) {
  // Screen size of one image pixel; shared by the image and phi grids.
  const cellWidth = scales.x(1) - scales.x(0);
  const cellHeight = scales.y(1) - scales.y(0);
  // Screen position of a location record ({ value: [x, y] } in image coordinates).
  const locX = (d) => scales.x(d.value[0]);
  const locY = (d) => scales.y(d.value[1]);

  // Draw the input image, one rect per pixel.
  svg.select("#image")
    .selectAll("rect")
    .data(image).enter()
    .append("rect")
    .attrs({
      "x": (d) => scales.x(d.i),
      "y": (d) => scales.y(d.j),
      "width": cellWidth,
      "height": cellHeight,
      "fill": (d) => scales.fill(d.value),
    });

  // Draw the phis from the current attended region.
  svg.select("#phi")
    .selectAll("rect")
    .data(phi).enter()
    .append("rect")
    .attrs({
      "x": (d) => scales.x(d.i) / 2 + scales.phi_time(d.time),
      "y": (d) => scales.y(d.j) / 2 + scales.zoom(d.zoom),
      "width": cellWidth,
      "height": cellHeight,
      "fill": (d) => scales.fill(d.value),
    });

  // Encoding cells: x by vector index, y by time step.
  const drawEncodingCells = (selector, cells) => {
    svg.select(selector)
      .selectAll("rect")
      .data(cells).enter()
      .append("rect")
      .attrs({
        "x": (d) => scales.x_enc(d.i),
        "y": (d) => scales.vtime(d.time),
        "width": scales.x_enc(1) - scales.x_enc(0),
        "height": scales.vtime(0) - scales.vtime(1),
        "fill": (d) => scales.enc_fill(d.value),
      });
  };
  drawEncodingCells("#g", encs["g"]); // glimpse encodings
  drawEncodingCells("#h", encs["h"]); // hidden states

  // The `l` locations as small red dots, keyed by time step.
  svg.select("#l")
    .selectAll("circle")
    .data(locs["l"], (d) => d.time).enter()
    .append("circle")
    .attrs({
      "cx": locX,
      "cy": locY,
      "r": 2,
      "fill": "red",
    });

  // The `mu` locations as large blue draggable handles. Each carries its time
  // step in a "time" attribute so the drag handler can tell them apart.
  const drag = d3.drag().on("drag", update_fun);
  svg.select("#mu")
    .selectAll("circle")
    .data(locs["mu"], (d) => d.time).enter()
    .append("circle")
    .attrs({
      "cx": locX,
      "cy": locY,
      "r": 8,
      "fill": "blue",
      "time": (d) => d.time,
    })
    .call(drag);

  // Polyline through the `mu` points. It must use scales.y for the y
  // coordinate so the path passes through the circles drawn above.
  const mu_line = d3.line().x(locX).y(locY);

  svg.select("#mu_links")
    .selectAll("path")
    .data([locs["mu"]]).enter()
    .append("path")
    .attrs({
      "d": mu_line,
      "fill": "none",
      "stroke-width": 2,
      "stroke": "blue",
    });

  // Logits view is disabled. NOTE(review): it references `flat_states`,
  // which is not defined in this function.
  /* svg.select("#y_hat")
  .selectAll("rect")
  .data(flat_states["logit"].filter((d) => d.time == 0)).enter()
  .append("rect")
  .attrs({
  "x": (d) => scales.classes(d.i),
  "y": (d) => scales.y_hat.range()[0] - scales.y_hat(d.value),
  "width": scales.classes(1) - scales.classes(0),
  "height": (d) => scales.y_hat(d.value),
  "fill": "#B1434E"
  });
  */
}
Insert cell
Insert cell
Insert cell
/**
 * Move the existing `l`/`mu` circles and the `mu` polyline to new locations.
 *
 * Only elements already bound by `initialize_view` are updated (no
 * enter/exit), matched to the new data by time step.
 *
 * @param svg - d3 selection of the root SVG.
 * @param scales - the scale table (see `image_scales`).
 * @param locs - `{ l, mu }` arrays of `{ time, value: [x, y] }` records.
 */
function update_locs(svg, scales, locs) {
  // Screen position of a location record ({ value: [x, y] } in image coordinates).
  const locX = (d) => scales.x(d.value[0]);
  const locY = (d) => scales.y(d.value[1]);

  svg.select("#l")
    .selectAll("circle")
    .data(locs["l"], (d) => d.time)
    .attrs({ "cx": locX, "cy": locY });

  svg.select("#mu")
    .selectAll("circle")
    .data(locs["mu"], (d) => d.time)
    .attrs({ "cx": locX, "cy": locY });

  // The line must use scales.y for y so it stays on top of the circles.
  const mu_line = d3.line().x(locX).y(locY);

  svg.select("#mu_links")
    .selectAll("path")
    .data([locs["mu"]])
    .attr("d", mu_line);
}
Insert cell
/**
 * Recolour the existing glimpse-encoding (#g) and hidden-state (#h) cells.
 * Only the "fill" attribute is rebound; cell positions are left untouched.
 *
 * @param svg - d3 selection of the root SVG.
 * @param scales - the scale table; `enc_fill` maps a cell value to a colour.
 * @param flat_encs - `{ g, h }` arrays of cell records with a `value` field.
 */
function update_encs(svg, scales, flat_encs) {
  const recolour = (d) => scales.enc_fill(d.value);
  const layers = [["#g", "g"], ["#h", "h"]];
  for (const [selector, key] of layers) {
    svg.select(selector)
      .selectAll("rect")
      .data(flat_encs[key])
      .attrs({ "fill": recolour });
  }
  // TODO: the logits (#y_hat) update is currently disabled.
}
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
/**
 * Build the recurrent attention model's step function from its weights.
 *
 * The four sub-networks (sensor, rnn, locator, classifier) are created once,
 * and the returned step function closes over them.
 *
 * @param weights - `{ glimpse, core, locator, classifier }` weight sets.
 * @param hyper - hyper-parameters `h_g`, `h_l`, `g`, `k`, `s`, `c`, `std`.
 * @returns {function(x, l_t_prev, h_t_prev)} one model step.
 */
function recurrent_attention_gen(weights, hyper) {
  const nets = {};
  nets.sensor = glimpse_gen(weights.glimpse, hyper.h_g, hyper.h_l, hyper.g, hyper.k, hyper.s, hyper.c);
  nets.rnn = core_gen(weights.core);
  nets.locator = location_gen(weights.locator, hyper.std);
  nets.classifier = action_gen(weights.classifier);

  // One step: run the sensor at l_t_prev, update the hidden state, then
  // produce the locator output (l, mu), its log_pi score and the logits.
  return function f(x, l_t_prev, h_t_prev) {
    const glimpse = nets.sensor(x, l_t_prev);
    const h_t = nets.rnn(glimpse.g, h_t_prev);
    const locs = nets.locator(h_t);
    const std = hyper.std;
    // log_pi sums lqnorm over the two coordinates of l given mu.
    const logPi0 = lqnorm(locs.l._data[0], locs.mu._data[0], std);
    const logPi1 = lqnorm(locs.l._data[1], locs.mu._data[1], std);
    const logits = nets.classifier(h_t);

    return {
      "phi": glimpse.phi,
      "g": glimpse.g,
      "h": h_t,
      "l": locs.l,
      "mu": locs.mu,
      "log_pi": logPi0 + logPi1,
      "logits": logits,
    };
  };
}
Insert cell
/**
 * Unroll the model step `f` over `n_glimpses` glimpses of input `x`.
 *
 * The first step starts from location `l0` and hidden state `h0`. Each later
 * step is fed the previous state's `l` and `h`. At least one state is always
 * returned.
 *
 * @returns {Array<Object>} the state produced by each step, in order.
 */
function forwards(f, x, l0, h0, n_glimpses) {
  const trajectory = [f(x, l0, h0)];
  while (trajectory.length < n_glimpses) {
    const { l, h } = trajectory[trajectory.length - 1];
    trajectory.push(f(x, l, h));
  }
  return trajectory;
}
Insert cell
Insert cell
Insert cell
Insert cell

Purpose-built for displays of data

Observable is your go-to platform for exploring data and creating expressive data visualizations. Use reactive JavaScript notebooks for prototyping and a collaborative canvas for visual data exploration and dashboard creation.
Learn more