Public
Edited
Jul 8, 2024
Paused
2 forks
Importers
24 stars
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
Insert cell
click
Insert cell
Insert cell
Insert cell
// Render the current mask as a 500px-wide image. When useGrayScale is on,
// pass the grayscale ramp from the `color` cell; otherwise fall back to the
// default mask color inside arrayToImageData.
drawMasks([onnxMaskToImage(output.data, output.dims[2], output.dims[3], useGrayScale ? color : undefined )], 500)
Insert cell
Insert cell
Insert cell
Insert cell
// Vector-trace the current mask: rasterize it to ImageData, let ImageTracer
// convert it to SVG markup (2 colors, short-line filtering), yield the
// element, then strip fills with d3 so only the outline paths remain.
svg = {
const imageData = arrayToImageData(
output.data,
output.dims[2],
output.dims[3],
undefined
);
const options = {
numberofcolors: 2,
strokewidth: 2,
linefilter: true
};
const svg = ImageTracer.imagedataToSVG(imageData, options);
const target = html`${svg}`;
// Yield first so Observable attaches the element, then post-process it.
yield target;

d3.select(target).selectAll("path").attr("fill", "none");
}
Insert cell
// Grayscale color ramp: maps a mask logit in [-30, 30] linearly to a gray
// level 0-255, returned as an opaque {r,g,b,a} pixel. Consumed by the
// drawMasks call when useGrayScale is enabled.
color = {
const c = d3.scaleLinear().domain([-30, 30]).range([0, 255]);
return (d) => {
const g = c(d);
return {r:g, g:g, b:g, a:255}
}
}
Insert cell
update = {
// Replace the mask every time the click changes (this cell re-runs
// reactively because it references both `click` and `output`).
click.maskEle.innerHTML = "";
click.maskEle.appendChild(onnxMaskToImage(output.data, output.dims[2], output.dims[3]));
}
Insert cell
// Create the ONNX Runtime inference session from the SAM decoder model file.
model = ort.InferenceSession.create(modelFile)
Insert cell
// Assemble the model input feed from the current click prompt, the
// precomputed image embedding tensor, and the image scaling info.
feed = getModelParams({
clicks: [click],
tensor: tensorEmbedding,
modelScale
})
Insert cell
// Run the SAM decoder; resolves to a map of output name -> ort.Tensor.
results = model.run(feed)
Insert cell
// The model's first output: the predicted mask tensor (dims used elsewhere
// as output.dims[2] x output.dims[3]).
output = results[model.outputNames[0]]
Insert cell
// Load the precomputed image embedding (.npy) as a float32 ort.Tensor.
tensorEmbedding = loadNpyTensor(embedding, "float32")
Insert cell
// Natural image size plus the scale factor needed to fit SAM's 1024px input.
modelScale = handleImageScale(img)
Insert cell
// The image to segment: a custom upload when provided, else the bundled
// dogs.jpg attachment. (Relies on && binding tighter than ||.)
img = customImage && customImage.image() || FileAttachment("dogs.jpg").image()
Insert cell
// URL of the quantized SAM ONNX decoder model attachment.
modelFile = FileAttachment("sam_onnx_quantized_example.onnx").url()
Insert cell
// Embedding URL: the custom upload's embedding when provided, else the
// bundled image_embedding.npy matching dogs.jpg.
embedding = customImageEmbedding && customImageEmbedding.url() || FileAttachment("image_embedding.npy").url()
Insert cell
Insert cell
// Load onnxruntime-web 1.18 from a CDN and point its WASM artifact loader
// at jsdelivr (skypack does not serve the .wasm files — see linked issue).
ort = {
const ort = await import('https://cdn.skypack.dev/onnxruntime-web@1.18?min');
// await import(
// "https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/esm/ort.min.js"
// );
// const ort = await require("https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js");

// https://github.com/microsoft/onnxruntime/issues/13933
ort.env.wasm.wasmPaths = "https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/";

return ort;
}
Insert cell
// https://github.com/aplbrain/npyjs

// Apache 2.0 License
// Parser/loader for NumPy .npy files (format spec v1.0).
// Returns { dtype, data (typed array view), shape, fortranOrder }.
class npyjs {
  constructor(opts) {
    if (opts) {
      console.error([
        "No arguments accepted to npyjs constructor.",
        "For usage, go to https://github.com/jhuapl-boss/npyjs."
      ].join(" "));
    }

    // numpy "descr" string -> typed-array mapping. "<" marks little-endian,
    // "|" marks byte-order-irrelevant; size is in bits.
    this.dtypes = {
      "<u1": {
        name: "uint8",
        size: 8,
        arrayConstructor: Uint8Array,
      },
      "|u1": {
        name: "uint8",
        size: 8,
        arrayConstructor: Uint8Array,
      },
      "<u2": {
        name: "uint16",
        size: 16,
        arrayConstructor: Uint16Array,
      },
      "|i1": {
        name: "int8",
        size: 8,
        arrayConstructor: Int8Array,
      },
      "<i2": {
        name: "int16",
        size: 16,
        arrayConstructor: Int16Array,
      },
      "<u4": {
        name: "uint32",
        size: 32,
        // FIX: was Int32Array, which misreads uint32 values >= 2^31
        // as negative numbers.
        arrayConstructor: Uint32Array,
      },
      "<i4": {
        name: "int32",
        size: 32,
        arrayConstructor: Int32Array,
      },
      "<u8": {
        name: "uint64",
        size: 64,
        arrayConstructor: BigUint64Array,
      },
      "<i8": {
        name: "int64",
        size: 64,
        arrayConstructor: BigInt64Array,
      },
      "<f4": {
        name: "float32",
        size: 32,
        arrayConstructor: Float32Array
      },
      "<f8": {
        name: "float64",
        size: 64,
        arrayConstructor: Float64Array
      },
    };
  }

  // Parse a complete .npy file held in an ArrayBuffer.
  parse(arrayBufferContents) {
    // const version = arrayBufferContents.slice(6, 8); // Uint8-encoded
    // FIX: npy v1.0 stores the header length as a little-endian uint16 at
    // bytes 8-9; reading a single byte (getUint8) silently truncated any
    // header longer than 255 bytes (large shapes / long dtype descrs).
    const headerLength = new DataView(arrayBufferContents.slice(8, 10)).getUint16(0, true);
    const offsetBytes = 10 + headerLength;

    const hcontents = new TextDecoder("utf-8").decode(
      new Uint8Array(arrayBufferContents.slice(10, 10 + headerLength))
    );
    // The header is a Python dict literal; massage it into JSON:
    // quotes, booleans, and the shape tuple -> array.
    const header = JSON.parse(
      hcontents
        .toLowerCase() // True -> true
        .replace(/'/g, '"')
        .replace("(", "[")
        .replace(/,*\),*/g, "]")
    );
    const shape = header.shape;
    const dtype = this.dtypes[header.descr];
    // View over the payload; requires offsetBytes to be aligned to the
    // element size (guaranteed by the spec's 16-byte header padding).
    const nums = new dtype["arrayConstructor"](
      arrayBufferContents,
      offsetBytes
    );
    return {
      dtype: dtype.name,
      data: nums,
      shape,
      fortranOrder: header.fortran_order
    };
  }

  // Fetch `filename`, parse it, and either return the result or pass it
  // to `callback` when one is given. `fetchArgs` is forwarded to fetch().
  async load(filename, callback, fetchArgs) {
    /*
    Loads an array from a stream of bytes.
    */
    fetchArgs = fetchArgs || {};
    const resp = await fetch(filename, { ...fetchArgs });
    const arrayBuf = await resp.arrayBuffer();
    const result = this.parse(arrayBuf);
    if (callback) {
      return callback(result);
    }
    return result;
  }
}
Insert cell
// underscore.js utility library. NOTE(review): `_` is not referenced by any
// visible cell — confirm it is used elsewhere before removing.
_ = require("underscore")
Insert cell
Insert cell
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.

// Apache-2.0 license

// Helper function for handling image scaling needed for SAM
handleImageScale = (image) => {
  // SAM expects its input resized so the longest side equals 1024 px.
  // Report the image's natural dimensions plus the scale factor to get there.
  const LONG_SIDE_LENGTH = 1024;
  const { naturalWidth: width, naturalHeight: height } = image;
  return {
    height,
    width,
    samScale: LONG_SIDE_LENGTH / Math.max(height, width)
  };
};
Insert cell

// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.

// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

loadNpyTensor = async (tensorFile, dType) => {
// Fetch + parse the .npy file, then wrap its data and shape in an
// ort.Tensor of the caller-specified dtype (assumed to match the file's
// actual contents — no conversion is performed).
let npLoader = new npyjs();
const npArray = await npLoader.load(tensorFile);
const tensor = new ort.Tensor(dType, npArray.data, npArray.shape);
return tensor;
}
Insert cell
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.

// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

// Modified: John Alexis Guerra Gomez

/**
 * Assemble the input feed for the SAM decoder ONNX model.
 *
 * clicks     - array of {x, y, clickType} prompt points in original image
 *              coordinates (scaled here by modelScale.samScale).
 * tensor     - precomputed image embedding (ort.Tensor).
 * modelScale - {height, width, samScale} as produced by handleImageScale.
 *
 * Returns the feed object keyed by the model's input names, or undefined
 * when no point tensors could be built (clicks was falsy).
 */
getModelParams = ({ clicks, tensor, modelScale }) => {
const imageEmbedding = tensor;
let pointCoords;
let pointLabels;
let pointCoordsTensor;
let pointLabelsTensor;

// Check there are input click prompts
if (clicks) {
let n = clicks.length;

// If there is no box input, a single padding point with
// label -1 and coordinates (0.0, 0.0) should be concatenated
// so initialize the array to support (n + 1) points.
pointCoords = new Float32Array(2 * (n + 1));
pointLabels = new Float32Array(n + 1);

// Add clicks and scale to what SAM expects
for (let i = 0; i < n; i++) {
pointCoords[2 * i] = clicks[i].x * modelScale.samScale;
pointCoords[2 * i + 1] = clicks[i].y * modelScale.samScale;
pointLabels[i] = clicks[i].clickType;
}

// Add in the extra point/label when only clicks and no box
// The extra point is at (0, 0) with label -1
pointCoords[2 * n] = 0.0;
pointCoords[2 * n + 1] = 0.0;
pointLabels[n] = -1.0;

// Create the tensor
pointCoordsTensor = new ort.Tensor("float32", pointCoords, [1, n + 1, 2]);
pointLabelsTensor = new ort.Tensor("float32", pointLabels, [1, n + 1]);
}
// Original (pre-resize) image size, as the model expects it.
const imageSizeTensor = new ort.Tensor("float32", [
modelScale.height,
modelScale.width,
]);

// NOTE(review): this guard runs after imageSizeTensor is built; harmless,
// but it means a falsy `clicks` still allocates one tensor before bailing.
if (pointCoordsTensor === undefined || pointLabelsTensor === undefined)
return;

// There is no previous mask, so default to an empty tensor
const maskInput = new ort.Tensor(
"float32",
new Float32Array(256 * 256),
[1, 1, 256, 256]
);
// There is no previous mask, so default to 0
const hasMaskInput = new ort.Tensor("float32", [0]);

return {
image_embeddings: imageEmbedding,
point_coords: pointCoordsTensor,
point_labels: pointLabelsTensor,
orig_im_size: imageSizeTensor,
mask_input: maskInput,
has_mask_input: hasMaskInput,
// return_single_mask: false
};
};
Insert cell
onnxMaskToImage = {
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.

// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

// Convert the onnx model mask output to an HTMLImageElement
// (rasterize via arrayToImageData, then snapshot through a canvas data URL).
// `color` is optional and is forwarded to arrayToImageData.
return function onnxMaskToImage(input, width, height, color) {
return imageDataToImage(arrayToImageData(input, width, height, color));
};
}
Insert cell
// Canvas elements can be created from ImageData
function imageDataToCanvas(imageData) {
  // Create an off-screen <canvas> sized to the pixels and paint them in.
  const { width, height } = imageData;
  const canvas = document.createElement("canvas");
  canvas.width = width;
  canvas.height = height;
  const context = canvas.getContext("2d");
  context?.putImageData(imageData, 0, 0);
  return canvas;
}
Insert cell
// Use a Canvas element to produce an image from ImageData
function imageDataToImage(imageData) {
  // Paint the pixels to a canvas, then snapshot it into an <img> as a data URL.
  const image = new Image();
  image.src = imageDataToCanvas(imageData).toDataURL();
  return image;
}
Insert cell
// Convert the onnx model mask prediction to ImageData
// Convert the onnx model mask prediction to ImageData.
// `color` maps a mask value to {r,g,b,a}; by default every above-threshold
// pixel gets the notebook's maskColor.
function arrayToImageData(
  input,
  width,
  height,
  color = () => hexToRgbA(maskColor)
) {
  // One RGBA quad per input element; pixels at or below the threshold stay
  // fully transparent because the buffer starts zeroed.
  const pixels = new Uint8ClampedArray(4 * width * height).fill(0);
  input.forEach((value, i) => {
    // Thresholding at maskThreshold mirrors predictor.model.mask_threshold
    // in the Python SAM code.
    if (value > maskThreshold) {
      const rgba = color(value);
      const base = 4 * i;
      pixels[base] = rgba.r;
      pixels[base + 1] = rgba.g;
      pixels[base + 2] = rgba.b;
      pixels[base + 3] = rgba.a;
    }
  });
  // NOTE(review): width/height are passed to ImageData in swapped order
  // relative to this function's parameter names; callers appear to pass the
  // mask dims so it comes out consistent — confirm before "fixing".
  return new ImageData(pixels, height, width);
}
Insert cell
// Adapted from https://stackoverflow.com/questions/21646738/convert-hex-to-rgba

// Returns {r,g,b,a}
// Accepts #rgb, #rgba, #rrggbb and #rrggbbaa; alpha defaults to 0xff.
// Returns {r, g, b, a} with each channel in 0-255; throws on invalid input.
function hexToRgbA(hex) {
  let c;
  if (/^#([A-Fa-f0-9]{3,4}){1,2}$/.test(hex)) {
    c = hex.substring(1).split("");
    if (c.length == 4) {
      // #rgba -> rrggbbaa
      c = [c[0], c[0], c[1], c[1], c[2], c[2], c[3], c[3]];
    } else if (c.length == 3) {
      // #rgb -> rrggbbff
      // FIX: the implied "ff" alpha must be APPENDED, not prepended.
      // The old order ["f","f",r,r,g,g,b,b] decoded every 3-digit color
      // as r=255 with the blue channel landing in the alpha slot.
      c = [c[0], c[0], c[1], c[1], c[2], c[2], "f", "f"];
    } else if (c.length == 6) {
      // #rrggbb -> rrggbbff
      c = [c[0], c[1], c[2], c[3], c[4], c[5], "f", "f"];
    }
    c = "0x" + c.join("");
    // Unsigned shifts extract the bytes correctly even when the 32-bit
    // value has its sign bit set (e.g. "#ff...").
    return {
      r: (c >>> 24) & 255,
      g: (c >>> 16) & 255,
      b: (c >>> 8) & 255,
      a: c & 255
    };
  }
  throw new Error("Bad Hex");
}
Insert cell
// Lay the given mask images out in one HTML fragment, each at `width` px.
drawMasks = (masks, width = 50) => htl.html`${masks.map(drawOneMask(width))}`
Insert cell
// Curried helper: set the display width on a mask's element (d.img if
// present, otherwise d itself) and return it for embedding.
drawOneMask = width => d => Object.assign(d.img || d, {width})
Insert cell
// imagetracerjs: raster-to-SVG tracer used by the `svg` cell.
ImageTracer = require("imagetracerjs@1.2.6")
Insert cell

One platform to build and deploy the best data apps

Experiment and prototype by building visualizations in live JavaScript notebooks. Collaborate with your team and decide which concepts to build out.
Use Observable Framework to build data apps locally. Use data loaders to build in any language or library, including Python, SQL, and R.
Seamlessly deploy to Observable. Test before you ship, use automatic deploy-on-commit, and ensure your projects are always up-to-date.
Learn more