function* slicedOptimalTransport(
src,
tgt,
{ maxIterations = 100, batchSize = 4, tolerance = 1 } = {}
) {
if (src.length !== tgt.length)
throw new Error("Source size must equal target size");
mutable history = [];
const N = src.length >> 2;
const index = new Uint32Array(N);
const srcProjection = new Float32Array(N);
const tgtProjection = new Float32Array(N);
const adjustment = new Float32Array(N * 3);
let delta = Infinity;
let iteration = 0;
while (delta > tolerance && ++iteration <= maxIterations) {
adjustment.fill(0);
for (let batchIndex = 0; batchIndex < batchSize; batchIndex++) {
let [v0, v1, v2] = vec3normalize([], [randn(), randn(), randn()]);
for (let i = 0, i4 = 0; i < N; i++, i4 += 4) {
index[i] = i;
srcProjection[i] = v0 * src[i4] + v1 * src[i4 + 1] + v2 * src[i4 + 2];
tgtProjection[i] = v0 * tgt[i4] + v1 * tgt[i4 + 1] + v2 * tgt[i4 + 2];
}
sort(srcProjection, index, 0, N - 1);
tgtProjection.sort();
for (let j = 0; j < N; j++) {
const projectedDiff = tgtProjection[j] - srcProjection[j];
const i3 = index[j] * 3;
adjustment[i3 + 0] += v0 * projectedDiff;
adjustment[i3 + 1] += v1 * projectedDiff;
adjustment[i3 + 2] += v2 * projectedDiff;
}
}
delta = 0;
for (let i3 = 0, i4 = 0; i4 < N * 4; i3 += 3, i4 += 4) {
const dr = adjustment[i3] / batchSize;
const dg = adjustment[i3 + 1] / batchSize;
const db = adjustment[i3 + 2] / batchSize;
src[i4] += dr;
src[i4 + 1] += dg;
src[i4 + 2] += db;
delta += dr * dr + dg * dg + db * db;
}
delta = Math.sqrt(delta / N);
mutable history = mutable history.concat([{ iteration, delta }]);
yield { delta, src };
}
return { delta, src };
}