/**
 * Build an interactive correlation-matrix widget.
 *
 * Intended for the Observable runtime: relies on the `html`, `Inputs`, `Plot`
 * and `d3` globals provided there.
 *
 * The widget lists the dataset's numeric columns in a multi-select, renders a
 * heatmap of pairwise Pearson correlations for the current selection, and
 * offers CSV (matrix) and SVG (heatmap) export buttons.
 *
 * Null, undefined and blank-string cells are treated as missing: they do not
 * disqualify a column from being numeric, and correlations are computed over
 * pairwise-complete rows instead of coercing missing cells to 0.
 *
 * @param {Array<Object>} data - Row objects; column names come from the first row.
 * @param {Object} [options]
 * @param {number} [options.width=600] - Plot width in pixels.
 * @param {number} [options.height=600] - Plot height in pixels.
 * @param {string} [options.title="Correlation Matrix"] - Heading text; falsy to omit.
 * @returns {HTMLElement} Root element of the widget.
 */
function correlationMatrixWidget(data, options = {}) {
  const { width = 600, height = 600, title = "Correlation Matrix" } = options;
  const rows = data ?? [];

  /** True for cells that count as absent rather than as a value. */
  function isMissing(value) {
    return value == null || (typeof value === "string" && value.trim() === "");
  }

  /** Coerce a cell to a number; missing cells become NaN so they are skipped. */
  function toNumber(value) {
    return isMissing(value) ? NaN : +value;
  }

  /** A column is numeric if it has at least one value and every present value parses as a number. */
  function isNumericColumn(key) {
    let present = 0;
    for (const row of rows) {
      const value = row[key];
      if (isMissing(value)) continue;
      if (Number.isNaN(+value)) return false;
      present += 1;
    }
    return present > 0;
  }

  /**
   * Pearson correlation over the indices where both inputs are finite.
   * Uses the two-pass (mean-centred) form, which avoids the catastrophic
   * cancellation the single-pass sum-of-squares formula suffers on
   * large-magnitude values. Returns 0 when fewer than two complete pairs
   * exist or either variance is zero.
   */
  function pearson(x, y) {
    const xs = [];
    const ys = [];
    x.forEach((xi, i) => {
      const yi = y[i];
      if (Number.isFinite(xi) && Number.isFinite(yi)) {
        xs.push(xi);
        ys.push(yi);
      }
    });
    const n = xs.length;
    if (n < 2) return 0;
    const meanX = xs.reduce((sum, v) => sum + v, 0) / n;
    const meanY = ys.reduce((sum, v) => sum + v, 0) / n;
    let sxy = 0;
    let sxx = 0;
    let syy = 0;
    for (let i = 0; i < n; i += 1) {
      const dx = xs[i] - meanX;
      const dy = ys[i] - meanY;
      sxy += dx * dy;
      sxx += dx * dx;
      syy += dy * dy;
    }
    // Multiply the roots (not the squares) so huge variances cannot overflow to Infinity.
    const denominator = Math.sqrt(sxx) * Math.sqrt(syy);
    if (!(denominator > 0)) return 0;
    // Rounding can push |r| a hair past 1; clamp to the valid range.
    return Math.max(-1, Math.min(1, sxy / denominator));
  }

  const container = html`<div style="font-family: sans-serif; display: flex; flex-direction: column; gap: 1rem;"></div>`;
  if (title) {
    container.append(html`<h3>${title}</h3>`);
  }

  const keys = Object.keys(rows[0] ?? {});
  const numericKeys = keys.filter(isNumericColumn);
  if (numericKeys.length < 2) {
    container.append(html`<div style="color: red;">Dataset must have at least two numeric columns for a correlation matrix.</div>`);
    return container;
  }

  // Extract and coerce each numeric column once, rather than once per matrix cell.
  const columns = new Map(numericKeys.map((key) => [key, rows.map((row) => toNumber(row[key]))]));

  /**
   * Correlation matrix for `vars` as a square 2-D array (values[i][j] = r(vars[i], vars[j])).
   * A plain 2-D array is used so a column literally named "variable" cannot
   * collide with a row-label key. Only the upper triangle is computed; the
   * matrix is symmetric.
   */
  function computeCorrelationMatrix(vars) {
    const values = vars.map(() => new Array(vars.length).fill(0));
    for (let i = 0; i < vars.length; i += 1) {
      for (let j = i; j < vars.length; j += 1) {
        const r = pearson(columns.get(vars[i]), columns.get(vars[j]));
        values[i][j] = r;
        values[j][i] = r;
      }
    }
    return { vars, values };
  }

  /** CSV with a "variable" header column; d3 quotes names containing commas/quotes/newlines. */
  function matrixToCSV({ vars, values }) {
    return d3.csvFormatRows([
      ["variable", ...vars],
      ...vars.map((name, i) => [name, ...values[i]]),
    ]);
  }

  /** Flatten the matrix into one {x, y, value} record per cell for Plot. */
  function toLongData({ vars, values }) {
    return vars.flatMap((rowVar, i) =>
      vars.map((colVar, j) => ({ x: rowVar, y: colVar, value: values[i][j] })),
    );
  }

  /** Render the heatmap; returns the node produced by Plot.plot. */
  function renderPlot(longData, vars) {
    return Plot.plot({
      marks: [
        // Plot.cell is the rect variant for two band (ordinal) axes: the heatmap idiom.
        Plot.cell(longData, {
          x: "x",
          y: "y",
          fill: "value",
          title: (d) => `${d.x} vs ${d.y}: ${d.value.toFixed(2)}`,
          rx: 2,
        }),
        // Text is centred in each band by default (textAnchor and lineAnchor "middle").
        Plot.text(longData, {
          x: "x",
          y: "y",
          text: (d) => d.value.toFixed(2),
          fill: "black",
        }),
      ],
      // Correlations are signed, so use a diverging scheme over the full [-1, 1] range.
      color: {
        scheme: "RdBu",
        domain: [-1, 1],
        label: "Correlation",
      },
      // Pin both axes to the selection order so the matrix reads consistently.
      x: { domain: vars, label: null, tickRotate: -45 },
      y: { domain: vars, label: null },
      width,
      height,
      marginLeft: 100,
      marginBottom: 100,
    });
  }

  /** Trigger a browser download of `blob` under `filename`. */
  function download(blob, filename) {
    const url = URL.createObjectURL(blob);
    const anchor = html`<a href="${url}" download="${filename}"></a>`;
    anchor.style.display = "none";
    // Some browsers only honour click() on anchors attached to the document.
    document.body.append(anchor);
    anchor.click();
    anchor.remove();
    // Revoking synchronously can abort the download in some browsers; defer it.
    setTimeout(() => URL.revokeObjectURL(url), 1000);
  }

  const variableSelector = Inputs.select(numericKeys, {
    label: "Select variables for correlation",
    multiple: true,
    value: numericKeys,
  });
  const plotContainer = html`<div></div>`;
  const exportContainer = html`<div style="display: flex; gap: 1rem;"></div>`;
  const exportCSVButton = html`<button>Export CSV</button>`;
  const exportSVGButton = html`<button>Export SVG</button>`;
  exportContainer.append(exportCSVButton, exportSVGButton);
  container.append(variableSelector, plotContainer, exportContainer);

  // Latest rendered matrix and plot; read by the export handlers (null when nothing is shown).
  let currentMatrix = null;
  let currentPlot = null;

  /** Recompute and re-render for the current selection. */
  function update() {
    const selectedVars = Array.from(variableSelector.value ?? []);
    const hasSelection = selectedVars.length >= 2;
    exportCSVButton.disabled = !hasSelection;
    exportSVGButton.disabled = !hasSelection;
    if (!hasSelection) {
      currentMatrix = null;
      currentPlot = null;
      plotContainer.replaceChildren(html`<div style="color: red;">Select at least two variables.</div>`);
      return;
    }
    currentMatrix = computeCorrelationMatrix(selectedVars);
    currentPlot = renderPlot(toLongData(currentMatrix), selectedVars);
    // The legend is rendered separately so the plot node itself stays a bare <svg> for export.
    const legend = currentPlot.legend("color", { ticks: 10 });
    plotContainer.replaceChildren(...[legend, currentPlot].filter(Boolean));
  }

  exportCSVButton.addEventListener("click", () => {
    if (!currentMatrix) return;
    download(new Blob([matrixToCSV(currentMatrix)], { type: "text/csv" }), "correlation-matrix.csv");
  });

  exportSVGButton.addEventListener("click", () => {
    if (!currentPlot) return;
    const svg = currentPlot.tagName === "svg" ? currentPlot : currentPlot.querySelector("svg");
    if (!svg) {
      alert("SVG element not found in the plot.");
      return;
    }
    const svgStr = `<?xml version="1.0" standalone="no"?>\r\n${new XMLSerializer().serializeToString(svg)}`;
    download(new Blob([svgStr], { type: "image/svg+xml;charset=utf-8" }), "correlation-matrix.svg");
  });

  variableSelector.addEventListener("input", update);
  update();
  return container;
}