/**
 * Builds a two-panel SVG visualising a gradient-descent run on `f`.
 *
 * Left panel: filled contour plot of `f` with x/y axes and the optimizer's
 * trajectory drawn as a thick green line.
 * Right panel: the per-iteration `gap` value plotted against `step`, with axes
 * and an "iteration" label.
 *
 * Relies on notebook-level names `d3`, `padding`, `size`, `width`, `height`,
 * `generate_contours` and `run_gradient_descent`.
 *
 * @param {Function} f - objective passed to `generate_contours` / `run_gradient_descent`.
 * @param {Function} fprime - gradient of `f`, passed to `run_gradient_descent`.
 * @param {Function} x2 - x scale of the contour panel (trajectory `d.x` and bottom axis).
 * @param {Function} y2 - y scale of the contour panel (trajectory `d.y` and left axis).
 * @param {number[]} thresholds - contour levels; their extent drives the Magma colour scale.
 * @param {Function} x2b - x scale of the gap panel (maps `d.step`).
 * @param {Function} y2b - y scale of the gap panel (maps `d.gap`).
 * @param {number[]} [x_init=[1, 1]] - starting point handed to `run_gradient_descent`.
 * @returns {SVGSVGElement} the SVG node, with `value` set to `[]`.
 */
gradient_descent_widget = (f, fprime, x2, y2, thresholds, x2b, y2b, x_init = [1, 1]) => {
  // Pixel scales used only to sample `f` for the contour grid over [-2, 2]^2.
  const x3 = d3.scaleLinear([-2, 2], [-3 * padding, size]);
  const y3 = d3.scaleLinear([-2, 2], [size, -2 * padding]);
  const color = d3.scaleSequential(d3.extent(thresholds), d3.interpolateMagma);
  const axisFont = "14px sans-serif";

  // Run the optimizer once; both panels plot the same trajectory.
  // NOTE(review): assumes each record carries {x, y, step, gap} — confirm
  // against run_gradient_descent.
  const trajectory = run_gradient_descent(f, fprime, x_init);

  const svg = d3.create("svg")
    .attr("viewBox", [0, 0, width, height / 2 + 2 * padding])
    .style("display", "block")
    // CSS lengths need units; a unitless "10 10" is invalid and gets ignored.
    .style("margin", "10px 10px");

  // Left panel: contours of f, axes, and the descent path.
  const drawContourPanel = (panel) => {
    panel.append("g")
      .attr("fill", "none")
      .attr("stroke", "#fff")
      .attr("stroke-opacity", 0.5)
      .selectAll("path")
      .data(generate_contours(f, x3, y3, thresholds, size - 3 * padding, size - 2 * padding))
      .join("path")
      .attr("fill", (d) => color(d.value))
      .attr("d", d3.geoPath())
      .attr("transform", `translate(${padding},${padding})`);

    panel.append("g")
      .attr("transform", `translate(0,${y2(-2)})`)
      .call(d3.axisBottom(x2).ticks(5, "f"))
      .style("font", axisFont);

    panel.append("g")
      .attr("transform", `translate(${x2(-2)},0)`)
      .call(d3.axisLeft(y2).ticks(5, "f"))
      .style("font", axisFont);

    panel.append("path")
      .attr("fill", "none")
      .attr("stroke", "#22ff22")
      .attr("stroke-width", 7)
      .attr("d", d3.line((d) => x2(d.x), (d) => y2(d.y))(trajectory));
  };

  // Right panel: gap vs. iteration.
  const drawGapPanel = (panel) => {
    panel.append("path")
      .attr("fill", "none")
      .attr("stroke", "#000000")
      .attr("stroke-width", 3)
      .attr("d", d3.line((d) => x2b(d.step), (d) => y2b(d.gap))(trajectory));

    // Bottom axis sits at gap = 1e-10, presumably the lower end of y2b's domain.
    panel.append("g")
      .attr("transform", `translate(0,${y2b(1e-10)})`)
      .call(d3.axisBottom(x2b).ticks(5, "f"))
      .style("font", axisFont);

    // NOTE(review): x2b maps iteration steps, yet the axis is placed at
    // x2b(-2) — looks copied from the contour panel; confirm it is intended.
    // The explicit tickFormat overrides the "f" specifier passed to ticks().
    panel.append("g")
      .attr("transform", `translate(${x2b(-2)},0)`)
      .call(d3.axisLeft(y2b).ticks(5, "f").tickFormat(d3.format(".1e")))
      .style("font", axisFont);

    panel.append("text")
      .attr("class", "x label")
      .attr("text-anchor", "middle")
      .attr("x", size / 2)
      .attr("y", size + 20)
      .text("iteration")
      .style("font", "18px sans-serif");
  };

  // Two side-by-side cells: [0, 0] is the contour panel, [1, 0] the gap panel.
  svg.append("g")
    .selectAll("g")
    .data(d3.cross(d3.range(2), d3.range(1)))
    .join("g")
    .attr("transform", ([i, j]) => `translate(${i * size},${j * size})`)
    .each(function ([i]) {
      const panel = d3.select(this);
      if (i === 0) {
        drawContourPanel(panel);
      } else if (i === 1) {
        drawGapPanel(panel);
      }
    });

  // Observable `viewof` convention: expose a value on the returned node.
  svg.property("value", []);
  return svg.node();
}