/**
 * Render a yes/no decision tree as an SVG using d3.
 *
 * Nodes are laid out left-to-right with d3.tree(). Each depth column is then
 * shifted right by the running total of the widest label measured at the
 * preceding depths, so labels don't run into the next column. Non-root nodes
 * show "<type> <number>", an optional statement and one line per action.
 * Nodes with `data.returnToCondition` get a dashed arrow back to the
 * condition node with that number; when several match, the one with the
 * shortest tree path is used. Every link gets a Y/N label.
 *
 * Requires `d3`, `elementSize` and `treeLink` to be in scope.
 * NOTE(review): node ids appear to be paths ending in "<y|n>/<number>"
 * (e.g. "/y/1/n/12"); the branch letter is read from that position — confirm.
 *
 * @param {object} root - d3 hierarchy root; laid out in place by d3.tree().
 * @param {object} [options]
 * @param {string} [options.fontFamily] - font used for rendering and measuring labels.
 * @param {number} [options.fontSize=12] - label font size.
 * @param {boolean} [options.panZoom=false] - attach d3 pan/zoom to the SVG
 *   instead of a fixed transform.
 * @returns {{tree: SVGSVGElement, d3Vars: {svg: object, node: object, link: object}}}
 */
function graphDecisionTree(root, {
  fontFamily = 'Arial,Helvetica Neue,Helvetica,sans-serif',
  fontSize = 12,
  panZoom = false
} = {}) {
  const dx = 40
  const dy = 120
  const marginLeft = -dy + 2.5

  // Branch letter ('y' or 'n') of the edge leading into `d`: the character
  // just before the "/<number>" suffix of its id. Handles multi-digit numbers.
  const branchOf = d => {
    const numDigits = (Math.log10(d.data.number) + 1) | 0
    return d.id.substr(-numDigits - 2, 1)
  }
  // The node's y/n path encoded as a '1'/'0' string (compared lexicographically).
  const verticalPriority = d => d.id.replace(/[^yn]/g, '').replace(/y/g, '1').replace(/n/g, '0')
  // Extra sibling spacing proportional to the number of action lines.
  const actionSeparation = d => d.data.actions ? (d.data.actions.length + 1) / 1.8 : 1

  const tree = d3.tree()
    .nodeSize([dx, dy])
    .separation((a, b) => {
      const aPriority = verticalPriority(a)
      const bPriority = verticalPriority(b)
      if (aPriority > bPriority) return actionSeparation(a)
      if (aPriority < bPriority) return actionSeparation(b)
      return 1.25
    })
  root = tree(root)

  // Widest label per depth; the root has no label and contributes 0.
  const maxWidthByDepth = new Map()
  root.descendants().forEach(d => {
    if (d.depth < 1) {
      d.textWidth = 0
      maxWidthByDepth.set(d.depth, 0)
      return
    }
    const statementText = d.data.statement || ''
    const typeText = d.data.type + ' ' + d.data.number
    const actionText = d.data.actions
      ? d.data.actions.reduce((a, b) => a.statement.length > b.statement.length ? a : b, {statement: ''}).statement
      : ''
    const largestText = [statementText, typeText, actionText].reduce((a, b) => a.length > b.length ? a : b, '')
    d.textWidth = elementSize(d3.create('svg:text').text(largestText).node(), '0,0,500,500', fontFamily).width
    const widest = maxWidthByDepth.get(d.depth)
    if (!widest || d.textWidth > widest) maxWidthByDepth.set(d.depth, d.textWidth)
  })

  // Turn per-depth widths into cumulative sums. descendants() is breadth-first,
  // so depths were inserted (and are iterated) in increasing order.
  maxWidthByDepth.forEach((value, key) => {
    if (key < 2) return
    maxWidthByDepth.set(key, maxWidthByDepth.get(key - 1) + value)
  })

  let x0 = Infinity
  let x1 = -Infinity
  let maxWidth = 0
  root.descendants().forEach(d => {
    x0 = Math.min(x0, d.x)
    x1 = Math.max(x1, d.x)
    if (d.depth < 1) return
    d.y += maxWidthByDepth.get(d.depth - 1)
    maxWidth = Math.max(maxWidth, d.y)
  })

  const svg = d3.create('svg')
    .attr('viewBox', [0, 0, maxWidth + 20, x1 - x0 + dx * 2])

  const defs = svg.append('svg:defs')
  const appendArrowMarker = (id, stroke, width = '10', height = '10') => {
    defs.append('svg:marker')
      .attr('id', id)
      .attr('markerWidth', width)
      .attr('markerHeight', height)
      .attr('refX', '10')
      .attr('refY', '4')
      .attr('orient', 'auto')
      .attr('markerUnits', 'userSpaceOnUse')
      .append('path')
      .attr('d', 'M0,0 L10,4 L0,8')
      .attr('stroke', stroke)
      .attr('fill', 'transparent')
  }
  appendArrowMarker('arrow', '#999')
  appendArrowMarker('arrow-selected', 'green')
  // Reserved for highlighting a traced return path (see TODO below).
  appendArrowMarker('arrow-red', 'red', '12', '8')

  const g = svg.append('g')
    .attr('font-family', fontFamily)
    .attr('font-size', fontSize)

  if (panZoom) {
    const zoom = d3.zoom().on('zoom', e => g.attr('transform', e.transform))
    svg.call(zoom)
      .call(zoom.translateBy, marginLeft, dx - x0)
  } else {
    g.attr('transform', `translate(${marginLeft},${dx - x0})`)
  }

  // links (those leaving the root are hidden)
  const link = g.append('g')
    .attr('fill', 'none')
    .attr('stroke-width', 1.5)
    .selectAll('path')
    .data(root.links())
    .join('path')
    .attr('stroke', 'lightgray')
    .attr('d', d => treeLink(d.source.textWidth + 20)(d))
    .style('visibility', d => d.source.depth ? 'visible' : 'hidden')

  const node = g.append('g')
    .attr('stroke-linejoin', 'round')
    .attr('stroke-width', 3)
    .selectAll('g')
    .data(root.descendants())
    .join('g')
    .attr('transform', d => `translate(${d.y},${d.x})`)

  node.append('circle')
    .attr('fill', d => d.children ? '#000' : 'salmon')
    .attr('r', 2.5)
    .style('opacity', d => !d.depth ? 0 : 1)

  // node text: "<type> <number>", statement, then one line per action
  const nodeLabel = node.append('text')
    .attr('dx', 10)
  nodeLabel.append('tspan')
    .text(d => d.id !== '/' ? d.data.type + ' ' + d.data.number : '')
  nodeLabel.append('tspan')
    .attr('x', 10)
    .attr('dy', '1.2em')
    .text(d => d.data.statement || '')
  nodeLabel.each(function (d) {
    const textSel = d3.select(this)
    d.data.actions?.forEach(action => {
      textSel.append('tspan')
        .attr('x', 10)
        .attr('dy', '1.2em')
        .text(action.statement)
    })
  })

  // Grow the viewBox height by the last node's label so it is not clipped.
  const lastLabelBox = elementSize(
    node.filter(':last-child').select('text').node().cloneNode(true),
    svg.attr('viewBox'),
    fontFamily)
  const vb = svg.attr('viewBox').split(',')
  vb[3] = Number(vb[3]) + lastLabelBox.height
  svg.attr('viewBox', vb.toString())

  // Label background, sized from detached clones of the label / node group.
  node.append('rect')
    .attr('width', function () {
      const textBox = elementSize(this.parentNode.querySelector('text').cloneNode(true), svg.attr('viewBox'), fontFamily)
      return textBox.width + 10
    })
    .attr('height', function () {
      const groupBox = elementSize(this.parentNode.cloneNode(true), svg.attr('viewBox'), fontFamily)
      return groupBox.height + 5
    })
    .attr('fill', 'aliceblue')
    .attr('rx', '3px')
    .style('visibility', d => !d.depth ? 'hidden' : 'visible')
    .attr('x', 5)
    .attr('y', '-13.3')
    .lower()

  // Draw a dashed, hoverable arrow from `e` back to condition node `target`.
  function drawReturnLine(e, target) {
    const isNoBranch = branchOf(e) === 'n'
    const actionCount = e.data.actions?.length ?? 0
    // Bend halfway between the two nodes when that is possible; otherwise
    // bend past the source's action lines ('n') or just above it ('y').
    const meetHalfway = isNoBranch
      ? e.x < target.x && target.x - e.x < 45
      : e.x > target.x
    const bendX = meetHalfway ? (e.x + target.x) / 2
      : isNoBranch ? e.x + actionCount * 15
        : e.x - 15
    const mid1 = [e.y, bendX]
    const mid2 = [meetHalfway ? target.y : target.y + 1, bendX]
    const points = Math.abs(e.x - target.x) < 25
      ? [[e.y, e.x], mid1, [target.y + target.textWidth + 20, bendX]]
      : e.y === target.y
        ? [[e.y, e.x], [target.y, target.x]]
        : [[e.y, e.x], mid1, mid2, [target.y + 1, target.x]]
    const line = d3.line()(points)

    const container = g.append('g')
      .attr('fill', 'none')
    const returnLine = container.append('path')
      .attr('stroke-width', 0.75)
      .attr('stroke-dasharray', '5,2')
      .attr('marker-end', 'url(#arrow)')
      .attr('d', line)

    // TODO: highlight a traced return path with stroke 'red' / url(#arrow-red).
    const setHover = hover => {
      returnLine
        .attr('stroke', hover ? 'green' : '#999')
        .attr('stroke-width', hover ? 1.25 : 0.75)
        .attr('marker-end', hover ? 'url(#arrow-selected)' : 'url(#arrow)')
      // Hovered line goes on top; otherwise it sits behind everything else in g.
      if (hover) container.raise()
      else container.lower()
    }

    // Wide, fully transparent twin path so the thin dashed line is easy to hover.
    container.append('path')
      .attr('d', line)
      .attr('stroke', 'black')
      .attr('stroke-width', 10)
      .attr('stroke-opacity', '0%')
      .on('mouseover', () => setHover(true))
      .on('mouseout', () => setHover(false))
    container.append('title')
      .text('return to condition ' + e.data.returnToCondition)
    setHover(false)
  }

  // return to condition
  root.each(e => {
    if (!e.data.returnToCondition) return
    const candidates = root.descendants().filter(f => f.data.type === 'condition'
      && f.data.number === e.data.returnToCondition)
    if (candidates.length === 0) {
      console.warn(`graphDecisionTree: node ${e.id} returns to unknown condition ${e.data.returnToCondition}`)
      return
    }
    // If the same condition number occurs more than once, use the nearest by tree path.
    const pathLengths = candidates.map(f => e.path(f).length)
    const target = candidates[pathLengths.indexOf(Math.min(...pathLengths))]
    drawReturnLine(e, target)
  })

  // Y/N labels, placed 70% of the way along each link
  const label = g.append('g')
    .selectAll('g')
    .data(root.links())
    .join('g')
    .style('visibility', d => !d.source.depth ? 'hidden' : 'visible')
    .attr('transform', d => {
      const midy = 0.3 * (d.source.y + d.source.textWidth) + 0.7 * d.target.y
      const midx = 0.3 * d.source.x + 0.7 * d.target.x
      return `translate(${midy},${midx})`
    })
  label.append('text')
    .attr('class', 'link-label')
    .text(d => branchOf(d.target) === 'y' ? 'Y' : 'N')
    .attr('dy', d => branchOf(d.target) === 'n' ? '1.6em' : '-0.8em')
    .attr('dx', '-0.8em')

  // Keep nodes above links, labels and (non-hovered) return lines.
  d3.select(node.node().parentNode).raise()
  return {tree: svg.node(), d3Vars: {svg, node, link}}
}