bert_plot = {
let svg = d3.create('svg').attr('width', width).attr('height', height)
let plot_g = svg.append('g').attr('transform', `translate(${margin},${margin})`)
plot_g.selectAll('text')
.data(reconstituted_tokens)
.enter()
.append('text')
.classed('token', true)
.text(d => d)
.attr('x', (d,i) => token_scale(i)).attr('text-anchor', 'middle')
.attr('y', -5)
.attr('font-size', '14px')
.attr('font-weight', 'bold')
.attr('fill', d3.hcl(0,0,50))
plot_g.selectAll('path')
.data(neuron_data)
.enter()
.append('path')
.classed('neuron', true)
.attr('d', d => neuron_line(d.values))
.attr('fill', 'none')
.attr('stroke', d3.hcl(210,20,40)).attr('stroke-width', .6).attr('stroke-opacity', .2)
plot_g.append('g')
.call(d3.axisLeft(neuron_scale).ticks(4)).call(axis_tweak)
let selected_word_ids = [], neuron_interval = []
function filter_neurons() {
plot_g.selectAll('.neuron')
.attr('stroke', d3.hcl(210,20,40)).attr('stroke-width', .6).attr('stroke-opacity', .2)
.filter(d => {
if(selected_word_ids.length == 0 || neuron_interval.length == 0)
return false
let neuron_values = d.values
for(let i = 0; i < selected_word_ids.length; i++) {
let neuron_value = neuron_values[selected_word_ids[i]].value
if(neuron_value < neuron_interval[0] || neuron_value > neuron_interval[1])
return false
}
return true
})
.attr('stroke', d3.hcl(330,50,30)).attr('stroke-width', 1.5).attr('stroke-opacity', .4).raise()
}
function handle_x_brush(e) {
let interval = d3.event.selection
selected_word_ids = []
plot_g.selectAll('.token').each((_,i) => {
if(token_scale(i) >= interval[0] && token_scale(i) <= interval[1])
selected_word_ids.push(i)
})
plot_g.selectAll('.token')
.attr('fill', d3.hcl(0,0,50))
.filter((_,i) => token_scale(i) >= interval[0] && token_scale(i) <= interval[1])
.attr('fill', d3.hcl(330,50,30))
filter_neurons()
}
let x_brush = d3.brushX().extent([[-5,-18],[plot_width,0]])
x_brush.on('brush', handle_x_brush)
plot_g.append('g').call(x_brush)
function handle_y_brush(e) {
let interval = d3.event.selection
neuron_interval[0] = neuron_scale.invert(interval[1])
neuron_interval[1] = neuron_scale.invert(interval[0])
filter_neurons()
}
let y_brush = d3.brushY().extent([[-14,0],[-2,plot_height]])
y_brush.on('brush', handle_y_brush)
plot_g.append('g').call(y_brush)
return svg.node()
}