Published
Edited
Jul 21, 2019
Insert cell
// Notebook title cell.
md`# RNN Update in 2D V2`
Insert cell
/**
 * Sample a grid of 2-D hidden states and apply one RNN step to each.
 *
 * NOTE: the parameters are forwarded to rnn_gen as rnn_gen(u, w, 0), i.e. in
 * swapped order relative to their names here. Callers (see `chart`) pass
 * their matrices in swapped order too, so the two swaps cancel out.
 *
 * @param w - matrix passed as rnn_gen's second argument (multiplies h)
 * @param u - matrix passed as rnn_gen's first argument (multiplies x)
 * @param x - input vector fed to every step
 * @returns {{links: Array, points: Array}} one link {pre, post} per grid
 *   state, and two points per grid state: the "pre" state, then its "post".
 */
function generate_data(w, u, x) {
  const step = rnn_gen(u, w, math.zeros(2));
  // Grid coordinates in [-1, 1) with a 0.195 step along each axis.
  const grid = d3.range(-1, 1, .195);

  const links = [];
  const points = [];
  for (const a of grid) {
    for (const b of grid) {
      const pre = math.matrix([a, b]);
      const post = step(x, pre);
      links.push({"pre": pre, "post": post});
      points.push({"h": pre, "type": "pre"});
      points.push({"h": post, "type": "post"});
    }
  }
  return {"links": links, "points": points};
}
Insert cell
chart = {
const svg = d3.select(DOM.svg(width, height));
var x = math.zeros(2);
draw_axes(svg, scales);
svg.selectAll("g")
.data(["links", "points", "weights_select", "weights", "weights_text", "x_input"], (d) => d).enter()
.append("g")
.attr("id", (d) => d);
var u = math.zeros(2, 2),
w = math.zeros(2, 2),
all_data = generate_data(u, w, x);
// draw the links
svg.select("#links")
.selectAll("path")
.data(all_data["links"]).enter()
.append("path")
.attrs({
"d": (d) => "M" + scales.h0(d["pre"]._data[0]) + "," + scales.h1(d["pre"]._data[1]) + "L" + scales.h0(d["post"]._data[0]) + "," + scales.h1(d["post"]._data[1]),
"stroke": "#a4a4a4",
"stroke-width": 0.8
});
svg.select("#points")
.selectAll("circle")
.data(all_data["points"]).enter()
.append("circle")
.attrs({
"cx": (d) => scales.h0(d.h._data[0]),
"cy": (d) => scales.h1(d.h._data[1]),
"r": 3,
"fill-opacity": .6,
"fill": (d) => scales.type(d.type)
});
svg.select("#x_input")
.selectAll("circle")
.data([x]).enter()
.append("circle")
.attrs({
"cx": (d) => scales.x0(d._data[0]),
"cy": (d) => scales.x1(d._data[1]),
"r": 10,
"fill": "black"
})
.call(d3.drag().on("drag", update_x))
svg.select("#weights_select")
.selectAll("circle")
.data(rows(w, u)).enter()
.append("circle")
.attrs({
"cx": (d) => scales.wselect_x(d.value._data[0][0]) + scales.wtype(d.weight),
"cy": (d) => scales.wselect_y(d.value._data[0][1]),
"r": 7,
"fill": (d) => scales.weight_fill(d.weight)
})
.call(d3.drag().on("drag", (d) => update_weights(d)));
svg.select("#weights")
.selectAll("rect")
.data(entries(w, u)).enter()
.append("rect")
.attrs({
"x": (d) => scales.wtype(d.weight) + scales.col_offset(d.col),
"y": (d) => scales.row_offset(d.row),
"width": scales.col_offset(1) - scales.col_offset(0),
"height": scales.row_offset(1) - scales.row_offset(0),
"fill": (d) => scales.weight_fill(d.weight),
"fill-opacity": (d) => .5 * scales.weight_opacity(Math.abs(d.value)),
});
svg.select("#weights_text")
.selectAll("text")
.data(entries(w, u)).enter()
.append("text")
.attrs({
"x": (d) => scales.wtype(d.weight) + scales.col_offset(d.col),
"y": (d) => scales.row_offset(d.row),
"fill": (d) => scales.weight_fill(d.weight),
"fill-opacity": (d) => scales.weight_fill(Math.abs(d.value)),
"font-size": 15,
"alignment-baseline": "hanging"
})
.text((d) => Number(d.value).toFixed(3))

function update_x() {
if (d3.event.x > width || d3.event.x < scales.x0.range()[0] + 10) {
return;
}
if (d3.event.y > scales.x1.range()[0] - 10 || d3.event.y < 0) {
return;
}
x = math.matrix(
[scales.x0.invert(d3.event.x),
scales.x1.invert(d3.event.y)]);
all_data = generate_data(u, w, x);
update_figure(svg, all_data["points"], all_data["links"], x, w, u)
}
function update_weights(d) {
var cur = [scales.wselect_x.invert(d3.event.x - scales.wtype(d.weight)),
scales.wselect_y.invert(d3.event.y)];
if (cur[0] < scales.wselect_x.domain()[0] || cur[0] > scales.wselect_x.domain()[1]) {
return;
}
if (cur[1] < scales.wselect_y.domain()[0] || cur[1] > scales.wselect_y.domain()[1]) {
return;
}

if (d.weight == "w") {
w.subset(math.index(d.row, [0, 1]), cur);
} else {
u.subset(math.index(d.row, [0, 1]), cur);
}
all_data = generate_data(u, w, x);
update_figure(svg, all_data["points"], all_data["links"], x, w, u)
}

return svg.node()
}
Insert cell
/**
 * Append all static axes to `svg`, in this order: the hidden-state (h0/h1)
 * axes, the input (x0/x1) axes, then one left/bottom axis pair for the "u"
 * weight panel followed by one for the "w" weight panel.
 *
 * @param svg - d3 selection of the root <svg>
 * @param scales - scale set produced by generate_scales()
 */
function draw_axes(svg, scales) {
  // Append a <g> shifted by (dx, dy) and render `axis` into it.
  const add_axis = (dx, dy, axis) =>
    svg.append("g")
      .attr("transform", "translate(" + dx + "," + dy + ")")
      .call(axis);

  add_axis(scales.h0(0), 0, d3.axisLeft().scale(scales.h1));
  add_axis(0, scales.h1(0), d3.axisBottom().scale(scales.h0));
  add_axis(scales.x0(0), 0, d3.axisLeft().scale(scales.x1).ticks(4));
  add_axis(0, scales.x1(0), d3.axisBottom().scale(scales.x0).ticks(4));

  for (const name of ["u", "w"]) {
    const panel_left = scales.wtype(name);
    add_axis(panel_left + scales.wselect_x(0), 0,
             d3.axisLeft().scale(scales.wselect_y).ticks(4));
    add_axis(panel_left, scales.wselect_y(0),
             d3.axisBottom().scale(scales.wselect_x).ticks(4));
  }
}
Insert cell
// Scale set read by chart and update_figure (chart also passes it to draw_axes).
scales = generate_scales();
Insert cell
/**
 * Build one record per matrix row for the draggable weight handles.
 * For each row index i, emits the w row first and then the u row.
 *
 * @param w - matrix whose `_size[0]` gives the row count
 * @param u - matrix with (presumably) the same row count as w
 * @returns {Array<{value, weight: string, row: number}>} `value` is
 *   whatever math.row returns for that row.
 */
function rows(w, u) {
  const row_count = w._size[0];
  return Array.from({length: row_count}, (_, i) => [
    {"value": math.row(w, i), "weight": "w", "row": i},
    {"value": math.row(u, i), "weight": "u", "row": i}
  ]).flat();
}
Insert cell
/**
 * Flatten two same-shaped matrices into one record per cell.
 * Cells are visited row-major over w's size; for each (row, col), the w
 * entry comes first and the u entry immediately after it.
 *
 * @param w - matrix exposing `_size` and `_data`
 * @param u - matrix indexed with w's dimensions
 * @returns {Array<{value: number, row: number, col: number, weight: string}>}
 */
function entries(w, u) {
  const [n_rows, n_cols] = w._size;
  const sources = [["w", w], ["u", u]];
  const cells = [];
  for (let row = 0; row < n_rows; row++) {
    for (let col = 0; col < n_cols; col++) {
      for (const [name, matrix] of sources) {
        cells.push({"value": matrix._data[row][col], "row": row, "col": col, "weight": name});
      }
    }
  }
  return cells;
}
Insert cell
/**
 * Re-bind fresh data to the elements that `chart` created, and move or restyle
 * them in place. No elements are added or removed.
 *
 * @param svg - d3 selection of the root <svg> built by chart
 * @param points - point records from generate_data
 * @param links - link records from generate_data
 * @param x - current input vector (a math.js matrix with `_data`)
 * @param w - weight matrix shown in the "w" panel
 * @param u - weight matrix shown in the "u" panel
 */
function update_figure(svg, points, links, x, w, u) {
  const start = (d) => "M" + scales.h0(d["pre"]._data[0]) + "," + scales.h1(d["pre"]._data[1]);
  const end = (d) => "L" + scales.h0(d["post"]._data[0]) + "," + scales.h1(d["post"]._data[1]);
  const strength = (d) => scales.weight_opacity(Math.abs(d.value));

  const link_paths = svg.select("#links").selectAll("path").data(links);
  link_paths.attrs({"d": (d) => start(d) + end(d)});

  const state_dots = svg.select("#points").selectAll("circle").data(points);
  state_dots.attrs({
    "cx": (d) => scales.h0(d.h._data[0]),
    "cy": (d) => scales.h1(d.h._data[1])
  });

  const input_dot = svg.select("#x_input").select("circle").data([x]);
  input_dot.attrs({
    "cx": (d) => scales.x0(d._data[0]),
    "cy": (d) => scales.x1(d._data[1])
  });

  const row_handles = svg.select("#weights_select").selectAll("circle").data(rows(w, u));
  row_handles.attrs({
    "cx": (d) => scales.wselect_x(d.value._data[0][0]) + scales.wtype(d.weight),
    "cy": (d) => scales.wselect_y(d.value._data[0][1])
  });

  const cell_rects = svg.select("#weights").selectAll("rect").data(entries(w, u));
  cell_rects.attrs({"fill-opacity": (d) => .5 * strength(d)});

  const cell_labels = svg.select("#weights_text").selectAll("text").data(entries(w, u));
  cell_labels
    .attrs({"fill-opacity": strength})
    .text((d) => Number(d.value).toFixed(3));
}
Insert cell
/**
 * Build a single-step RNN update function.
 *
 * @param w - matrix applied to the input x
 * @param u - matrix applied to the hidden state h
 * @param b - bias added to both products
 * @returns {function(x, h)} computes tanh(w·x + u·h + b), applying tanh to
 *   each element through the result's `.map`.
 */
function rnn_gen(w, u, b) {
  return function rnn(x, h) {
    const input_term = math.multiply(w, x);
    const recurrent_term = math.multiply(u, h);
    const pre_activation = math.add(input_term, recurrent_term, b);
    return pre_activation.map(Math.tanh);
  };
}
Insert cell
// NOTE(review): this cell appears unused in this notebook. chart declares its
// own local `w`, and generate_data, rows, entries and update_figure all take
// `w` as a parameter. Confirm before removing.
w = math.zeros(2, 2)
Insert cell
/**
 * Build every d3 scale the figure uses, sized to the notebook's `width` and
 * `height`.
 *
 * Keys:
 * - h0, h1: hidden-state coordinates in [-1, 1], mapped to pixels in the left
 *   panel (h1 is inverted so larger values sit higher).
 * - x0, x1: input coordinates in [-1.2, 1.2], mapped to the top-right panel.
 * - type: fill colour for "pre" vs "post" state points.
 * - wselect_x, wselect_y: weight values in [-3, 3], mapped to a handle
 *   position inside a weight panel (offset by wtype).
 * - weight_opacity: |weight| in [0, 3], mapped to an opacity in [0.4, 1].
 * - row_offset, col_offset: pixel offsets of the 2x2 weight cell grid.
 * - weight_fill: colour for the "w" vs "u" matrix.
 * - wtype: left pixel edge of the "w" and "u" panels.
 */
function generate_scales() {
  return {
    "h0": d3.scaleLinear().domain([-1, 1]).range([0, .95 * 3 * width / 4]),
    "h1": d3.scaleLinear().domain([-1, 1]).range([height, 0]),
    "x0": d3.scaleLinear().domain([-1.2, 1.2]).range([3 * width / 4, width]),
    "x1": d3.scaleLinear().domain([-1.2, 1.2]).range([height / 2, 0]),
    // Colour literals carry no stray whitespace; "#ff9966 " previously did.
    "type": d3.scaleOrdinal().domain(["pre", "post"]).range(["#ff9966", "#9955bb"]),
    "wselect_x": d3.scaleLinear().domain([-3, 3]).range([0, width / 8]),
    "wselect_y": d3.scaleLinear().domain([-3, 3]).range([.95 * 3 * height / 4, 1.05 * height / 2]),
    "weight_opacity": d3.scaleLinear().domain([0, 3]).range([0.4, 1]),
    "row_offset": d3.scaleOrdinal().domain([0, 1]).range([3 * height / 4, 7 * height / 8]),
    "col_offset": d3.scaleOrdinal().domain([0, 1]).range([0, width / 16]),
    "weight_fill": d3.scaleOrdinal().domain(["w", "u"]).range(["#48994f", "#4f4899"]),
    "wtype": d3.scaleOrdinal().domain(["w", "u"]).range([3 * width / 4, 7 * width / 8])
  };
}
Insert cell
Insert cell
d3 = require("d3-selection", "d3-selection-multi", "d3-scale", "d3-array", "d3-drag", "d3-dispatch", "d3-axis")
Insert cell
math = require("https://cdnjs.cloudflare.com/ajax/libs/mathjs/6.0.3/math.min.js")
Insert cell

Purpose-built for displays of data

Observable is your go-to platform for exploring data and creating expressive data visualizations. Use reactive JavaScript notebooks for prototyping and a collaborative canvas for visual data exploration and dashboard creation.
Learn more