Public
Edited
Oct 5, 2023
Insert cell
Insert cell
tf = require('@tensorflow/tfjs@4.11.0')
Insert cell
// NOTE(review): despite the OpenCV-style name, this loads @tensorflow/tfjs-core, which does
// not provide OpenCV drawing functions such as circle(); no live code visible here calls it.
cv2 = require('@tensorflow/tfjs-core@4.11.0')
Insert cell
// Text field; its current value (`input`) is rendered by the `canvas` cell.
viewof input = Inputs.text({label: "Input", value: "Hi", placeholder:"Input string"})
Insert cell
canvas = {
const width= 40;
const height= 40
const context = DOM.context2d(width, height);
context.fillStyle = "black";
context.fillRect(0, 0, width, height);
context.fillStyle = "white";
context.textBaseline = 'middle';
context.textAlign = 'center';

context.font = "37px sans-serif";
context.fillText(input, width/2, height/2);
// .measureText(textString ).width;
// context.fillText(input.value, 0, 0);
// context.canvas.style.background = "hsl(216deg 20% 90%)";
return context.canvas;
}

Insert cell
Insert cell
/**
 * Pairwise squared Euclidean distances between two point sets.
 *
 * Inputs may be tf.Tensors or nested arrays (tfjs converts TensorLike values).
 * Assumes rank-2 inputs of shape [n, d] and [m, d].
 *
 * @param {tf.Tensor2D|number[][]} x - First point set, shape [n, d].
 * @param {tf.Tensor2D|number[][]} y - Second point set, shape [m, d].
 * @returns {tf.Tensor2D} D of shape [n, m] with D[i][j] = Σ_k (x[i][k] - y[j][k])².
 */
function pDist(x, y) {
  // tf.tidy frees the [n, m, d] intermediates; only the returned tensor survives.
  return tf.tidy(() => {
    // [n, 1, d] - [1, m, d] broadcasts to [n, m, d].
    const diff = tf.sub(tf.expandDims(x, 1), tf.expandDims(y, 0));
    // Sum squared differences over the feature (last) axis.
    return tf.sum(tf.square(diff), -1);
  });
}
Insert cell
/**
 * One log-domain Sinkhorn iteration for entropic optimal transport.
 *
 * With cost matrix C of shape [n, m] it computes
 *   g_j = logSumExp_i(-f_i - C_ij)   (last-axis reduction of -f - Cᵀ, shape [m])
 *   f_i = logSumExp_j(-g_j - C_ij)   (last-axis reduction of -g - C,  shape [n])
 *
 * The previous implementation concatenated -f with Cᵀ instead of subtracting them. That
 * throws for a rank-1 f and would compute the wrong quantity even if it didn't.
 *
 * @param {tf.Tensor2D} C - Cost matrix, shape [n, m].
 * @param {tf.Tensor1D|number} f - Current potentials, shape [n] (a scalar broadcasts).
 * @returns {Promise<[tf.Tensor1D, tf.Tensor1D]>} Updated `[f, g]`.
 */
async function sinkhornStep(C, f) {
  // tf.tidy disposes the [m, n] / [n, m] intermediates; only the returned pair survives.
  return tf.tidy(() => {
    // -f ([n]) broadcasts along the last axis of Cᵀ ([m, n]).
    const g = tf.logSumExp(tf.sub(tf.neg(f), tf.transpose(C)), -1);
    // -g ([m]) broadcasts along the last axis of C ([n, m]).
    const fNext = tf.logSumExp(tf.sub(tf.neg(g), C), -1);
    return [fNext, g];
  });
}
Insert cell
/**
 * Entropic optimal transport between two uniformly weighted point sets, computed
 * with log-domain Sinkhorn iterations.
 *
 * The returned plan is P_ij = exp(-f_i - g_j - C_ij) / n. Because the last f is
 * logSumExp_j(-g_j - C_ij), every row of P sums to 1/n.
 *
 * @param {tf.Tensor2D} C - Cost matrix, shape [n, m].
 * @param {tf.Tensor1D|number|null} [f=null] - Initial potentials; null/undefined starts
 *   from zeros. A caller-supplied tensor is never disposed.
 * @param {number} [niter=1000] - Number of Sinkhorn iterations (must be >= 1).
 * @returns {Promise<[tf.Tensor2D, tf.Tensor1D, tf.Tensor1D]>} `[P, f, g]`.
 * @throws {RangeError} If niter < 1 (g would never be computed).
 */
async function Sinkhorn(C, f = null, niter = 1000) {
  if (!(niter >= 1)) {
    throw new RangeError(`Sinkhorn: niter must be >= 1, got ${niter}`);
  }
  const n = C.shape[0];
  const fInitial = f;

  let fCur = f == null ? tf.zeros([n], 'float32') : f;
  let g = null;
  for (let i = 0; i < niter; i++) {
    const [fNext, gNext] = await sinkhornStep(C, fCur);
    // Free the previous iterate, but never a tensor the caller passed in.
    if (fCur !== fInitial) tf.dispose(fCur);
    if (g !== null) g.dispose();
    fCur = fNext;
    g = gNext;
  }

  // P = exp(-f[:, None] - g[None, :] - C) / n, built with tf ops. JS unary minus and
  // `+` do not work on tensors; they yield NaN or string concatenation.
  const P = tf.tidy(() =>
    tf.exp(
      tf.sub(tf.sub(tf.neg(tf.expandDims(fCur, 1)), tf.expandDims(g, 0)), C)
    ).div(n)
  );

  return [P, fCur, g];
}
Insert cell
// Side length, in pixels, of the square image produced by drawPoints.
VIDEO_SIZE = 512;
Insert cell
/**
 * Render points as filled white discs on a black VIDEO_SIZE × VIDEO_SIZE RGB image.
 *
 * Points are mapped to pixels via p * (0.9 * w / 2) + w / 2, so coordinates in
 * [-1, 1] span the central 90% of the image. x selects the column and y the row.
 * Discs that fall partly outside the image are clipped.
 *
 * The previous version only logged the scaled coordinates and returned a blank image;
 * its drawing code was commented out.
 *
 * @param {tf.Tensor2D|number[][]} p - Points as [x, y] pairs, shape [N, 2].
 * @param {number} [radius=3] - Disc radius in pixels. NOTE(review): 3 px corresponds to
 *   the commented-out cv2 call (radius 12 with shift=2, i.e. 2 fractional bits); confirm.
 * @returns {Promise<tf.Tensor3D>} int32 image of shape [w, w, 3] with values 0 or 255.
 */
async function drawPoints(p, radius = 3) {
  const w = VIDEO_SIZE;

  const scaled = tf.tidy(() => tf.add(tf.mul(p, w * 0.9 * 0.5), w / 2));
  const coords = scaled.arraySync();
  scaled.dispose();

  // tf.buffer starts zero-filled, i.e. black. tfjs has no uint8 dtype (`tf.uint8` was
  // undefined, so the old image was silently float32). int32 holds 0..255 pixel values.
  const img = tf.buffer([w, w, 3], 'int32');
  const reach = Math.ceil(radius);
  const r2 = radius * radius;

  for (const [px, py] of coords) {
    // Truncate, matching the former tf.cast(..., 'int32').
    const cx = Math.trunc(px);
    const cy = Math.trunc(py);
    const yEnd = Math.min(w - 1, cy + reach);
    const xEnd = Math.min(w - 1, cx + reach);
    for (let y = Math.max(0, cy - reach); y <= yEnd; y++) {
      for (let x = Math.max(0, cx - reach); x <= xEnd; x++) {
        const dx = x - cx;
        const dy = y - cy;
        if (dx * dx + dy * dy > r2) continue;
        for (let c = 0; c < 3; c++) img.set(255, y, x, c);
      }
    }
  }

  return img.toTensor();
}
Insert cell
canvasEnd = {
// const width= 100;
const height= 400
const context = DOM.context2d(width, height);
context.fillStyle = "black";
context.fillRect(0, 0, width, height);
// context.fillStyle = "white";
// context.textBaseline = 'middle';
// context.textAlign = 'center';
// return generatePoints()
// const t = tf.linspace(0, 2 * Math.PI, 256);
// const x = tf.pow(tf.sin(t), 3) * 16;
// const y = tf.cos(t) * 13 - tf.cos(2 * t) * 5 - tf.cos(3 * t) * 2 - tf.cos(4 * t);
// // const pos0 = tf.concat([x.reshape([-1, 1]), y.reshape([-1, 1])], 1);
// return [x, y]
// }
// const pos0 = tf.concat([x, y]);

const pts = generatePoints('2');
const pts2 = generatePoints('7');

// const distances = pDist(pts, pts2)/(0.01)**2;
const distances = tf.pow(tf.div(pDist(pts, pts2),0.01), 2);
console.log('distances:', distances)
const [P, f, g] = await Sinkhorn(distances, 0, 20)
//const img = await drawPoints(pts);
// const img = drawPoints(pts);
context.fillStyle = "white";
const s = context.canvas.height*0.6
const dotSize = 3
pts.map(([x, y])=> {
// console.log(x,y)
// context.fillRect(s/2 + x*s,
// s/2 + y * s, dotSize, dotSize);
// context.fillRect(context.canvas.width/(2*window.devicePixelRatio) + x*s,
// context.canvas.height/(2*window.devicePixelRatio) + y * s, dotSize, dotSize);

context.beginPath();
context.arc(context.canvas.width/(2*window.devicePixelRatio) + x*s,
context.canvas.height/(2*window.devicePixelRatio) + y * s, dotSize, 0, 2 * Math.PI);
context.fill();
});
// return img;
// // Display the image
// const image = new Image();
// image.src = await img.toDataURL();
// image.onload = () => context.canvas.drawImage(image, 0, 0, VIDEO_SIZE, VIDEO_SIZE);

// // context.font = "70px sans-serif";
// // context.fillText(input, width/2, height/2);
// // // .measureText(textString ).width;
// // // context.fillText(input.value, 0, 0);
// // // context.canvas.style.background = "hsl(216deg 20% 90%)";
return context.canvas;

}
Insert cell

One platform to build and deploy the best data apps

Experiment and prototype by building visualizations in live JavaScript notebooks. Collaborate with your team and decide which concepts to build out.
Use Observable Framework to build data apps locally. Use data loaders to build in any language or library, including Python, SQL, and R.
Seamlessly deploy to Observable. Test before you ship, use automatic deploy-on-commit, and ensure your projects are always up-to-date.
Learn more