class UnscentedKalmanFilter {
  /**
   * Unscented Kalman filter (UKF) on TensorFlow.js tensors, using the scaled
   * unscented transform with 2n + 1 sigma points and
   * lambda = alpha^2 * (n + kappa) - n.
   *
   * Shapes used throughout: the state is a column vector [n, 1], the
   * covariance is [n, n], and the weight rows Wm / Wc are [1, 2n + 1].
   *
   * @param {number} n - Dimension of the state vector.
   * @param {tf.Tensor2D} initialState - Initial state estimate, shape [n, 1].
   * @param {tf.Tensor2D} initialCovariance - Initial state covariance, shape [n, n].
   * @param {number} alpha - Spread of the sigma points around the mean.
   * @param {number} beta - Term added to the central covariance weight.
   * @param {number} kappa - Secondary scaling parameter.
   * @throws {RangeError} If alpha^2 * (n + kappa) is not positive, which would
   *   make the weights undefined and the scaled covariance non-factorable.
   */
  constructor(n, initialState, initialCovariance, alpha, beta, kappa) {
    this.n = n;
    this.x = initialState;
    this.P = initialCovariance;
    this.alpha = alpha;
    this.beta = beta;
    this.kappa = kappa;
    // Computed once so the sigma-point spread always matches Wm / Wc.
    this.lambda = alpha ** 2 * (n + kappa) - n;

    const scale = n + this.lambda; // equals alpha^2 * (n + kappa)
    if (!(scale > 0)) {
      throw new RangeError(
        `UnscentedKalmanFilter: alpha^2 * (n + kappa) must be positive, got ${scale}`,
      );
    }

    const centralMeanWeight = this.lambda / scale;
    const outerWeight = 1 / (2 * scale);
    const outerWeights = new Array(2 * n).fill(outerWeight);
    this.Wm = tf.tensor2d([[centralMeanWeight, ...outerWeights]], [1, 2 * n + 1]);
    // Only the central covariance weight differs from the mean weights:
    // Wc0 = Wm0 + (1 - alpha^2 + beta); Wci = Wmi for i > 0.
    this.Wc = tf.tensor2d(
      [[centralMeanWeight + (1 - alpha ** 2 + beta), ...outerWeights]],
      [1, 2 * n + 1],
    );
  }

  /**
   * Generates the 2n + 1 sigma points for the current estimate.
   *
   * Column 0 is x. Columns 1..n are x plus the columns of L, and columns
   * n+1..2n are x minus the columns of L. L is the lower Cholesky factor of
   * (n + lambda) * P.
   *
   * @returns {tf.Tensor2D} Sigma points as columns, shape [n, 2n + 1].
   */
  sigmaPoints() {
    return tf.tidy(() => {
      const sqrtScaledP = this.choleskyDecomposition(
        tf.mul(this.P, this.n + this.lambda),
      );
      // x is a single column, so it broadcasts across every column of L.
      const positiveSigmaPoints = tf.add(this.x, sqrtScaledP);
      const negativeSigmaPoints = tf.sub(this.x, sqrtScaledP);
      return tf.concat([this.x, positiveSigmaPoints, negativeSigmaPoints], 1);
    });
  }

  /**
   * Computes the weighted mean and covariance of a set of sigma points.
   *
   * @param {tf.Tensor2D} sigmaPoints - Points stored as columns, shape [m, 2n + 1].
   * @param {tf.Tensor2D} Wm - Mean weights, shape [1, 2n + 1].
   * @param {tf.Tensor2D} Wc - Covariance weights, shape [1, 2n + 1].
   * @returns {[tf.Tensor2D, tf.Tensor2D]} `[mean, covariance]` with shapes
   *   [m, 1] and [m, m].
   */
  unscentedTransform(sigmaPoints, Wm, Wc) {
    return tf.tidy(() => {
      // Calculate mean (keepDims so it stays an [m, 1] column)
      const mean = tf.sum(tf.mul(sigmaPoints, Wm), 1, true);
      // Calculate covariance. The [m, 1] mean broadcasts over the columns directly.
      const residual = tf.sub(sigmaPoints, mean);
      // Weight one factor rather than scaling both by sqrt(Wc): Wc[0] is often
      // negative, and its square root would be NaN.
      const covariance = tf.matMul(tf.mul(residual, Wc), residual, false, true);
      return [mean, covariance];
    });
  }

  /**
   * Time update: propagates the sigma points through the process model and
   * replaces `this.x` / `this.P` with the predicted mean and covariance.
   *
   * The previous `this.x` / `this.P` tensors are not disposed, because callers
   * may still hold references to them.
   *
   * @param {(sigmaPoints: tf.Tensor2D) => tf.Tensor2D} f - Process model. It is
   *   applied to the [n, 2n + 1] sigma-point matrix and must return an
   *   [n, 2n + 1] matrix.
   * @param {tf.Tensor2D} Q - Process noise covariance, shape [n, n].
   */
  predict(f, Q) {
    const [x, P] = tf.tidy(() => {
      // Predict sigma points
      const predictedSigmaPoints = f(this.sigmaPoints());
      const [predictedMean, predictedCovariance] = this.unscentedTransform(
        predictedSigmaPoints,
        this.Wm,
        this.Wc,
      );
      // Add process noise
      return [predictedMean, tf.add(predictedCovariance, Q)];
    });
    this.x = x;
    this.P = P;
  }

  /**
   * Measurement update: corrects `this.x` / `this.P` using the measurement `z`.
   *
   * The previous `this.x` / `this.P` tensors are not disposed, because callers
   * may still hold references to them.
   *
   * @param {(sigmaPoints: tf.Tensor2D) => tf.Tensor2D} h - Measurement model. It
   *   maps the [n, 2n + 1] sigma-point matrix to an [m, 2n + 1] matrix of
   *   predicted measurements.
   * @param {tf.Tensor2D} z - Observed measurement, shape [m, 1].
   * @param {tf.Tensor2D} R - Measurement noise covariance, shape [m, m].
   */
  update(h, z, R) {
    const [x, P] = tf.tidy(() => {
      const sigmaPoints = this.sigmaPoints();
      // Transform sigma points into measurement space
      const predictedMeasurements = h(sigmaPoints);
      const [predictedMean, predictedCovariance] = this.unscentedTransform(
        predictedMeasurements,
        this.Wm,
        this.Wc,
      );
      // Innovation covariance S = Pzz + R
      const innovationCovariance = tf.add(predictedCovariance, R);
      // Cross-covariance between state and measurement, shape [n, m]
      const residualState = tf.sub(sigmaPoints, this.x);
      const residualMeasurement = tf.sub(predictedMeasurements, predictedMean);
      const crossCovariance = tf.matMul(
        tf.mul(residualState, this.Wc),
        residualMeasurement,
        false,
        true,
      );
      // Kalman gain K = Pxz * S^-1
      const K = tf.matMul(crossCovariance, this.matrixInverse(innovationCovariance));
      const measurementResidual = tf.sub(z, predictedMean);
      return [
        tf.add(this.x, tf.matMul(K, measurementResidual)),
        // P - K S K^T
        tf.sub(this.P, tf.matMul(tf.matMul(K, innovationCovariance), K, false, true)),
      ];
    });
    this.x = x;
    this.P = P;
  }

  /**
   * Computes the lower-triangular Cholesky factor L with A + jitter*I = L Lᵀ.
   *
   * Works on a copy of A's values, so the input tensor is never modified.
   * Only the lower triangle of A is read.
   *
   * @param {tf.Tensor2D} A - Symmetric positive-definite matrix, shape [n, n].
   * @param {number} [jitter=1e-6] - Value added to each diagonal entry for
   *   numerical stability.
   * @returns {tf.Tensor2D} L, shape [n, n].
   * @throws {Error} If A (plus jitter) is not positive definite.
   */
  choleskyDecomposition(A, jitter = 1e-6) {
    const n = A.shape[0];
    const a = A.arraySync();
    const L = Array.from({ length: n }, () => new Array(n).fill(0));
    for (let i = 0; i < n; i++) {
      for (let j = 0; j <= i; j++) {
        let sum = 0;
        for (let k = 0; k < j; k++) {
          sum += L[i][k] * L[j][k];
        }
        if (i === j) {
          const diagonal = a[i][i] + jitter - sum;
          // Also catches NaN, which would otherwise propagate silently.
          if (!(diagonal > 0)) {
            throw new Error(
              `choleskyDecomposition: matrix is not positive definite (pivot ${i} = ${diagonal})`,
            );
          }
          L[i][i] = Math.sqrt(diagonal);
        } else {
          L[i][j] = (a[i][j] - sum) / L[j][j];
        }
      }
    }
    return tf.tensor2d(L, [n, n], 'float32');
  }

  /**
   * Inverts a square matrix by Gauss-Jordan elimination with partial pivoting.
   *
   * @param {tf.Tensor2D} A - Square matrix, shape [n, n].
   * @returns {tf.Tensor2D} The inverse of A, shape [n, n].
   * @throws {Error} If a zero or non-finite pivot is met (A is singular).
   */
  matrixInverse(A) {
    const n = A.shape[0];
    // Build the augmented matrix [A | I] as plain nested arrays.
    const rows = A.arraySync().map((row, i) => {
      const identityRow = new Array(n).fill(0);
      identityRow[i] = 1;
      return [...row, ...identityRow];
    });
    for (let col = 0; col < n; col++) {
      // Partial pivoting: use the row with the largest magnitude in this column.
      let pivotRow = col;
      for (let r = col + 1; r < n; r++) {
        if (Math.abs(rows[r][col]) > Math.abs(rows[pivotRow][col])) {
          pivotRow = r;
        }
      }
      if (pivotRow !== col) {
        [rows[col], rows[pivotRow]] = [rows[pivotRow], rows[col]];
      }
      const pivot = rows[col][col];
      if (pivot === 0 || !Number.isFinite(pivot)) {
        throw new Error(`matrixInverse: matrix is singular (pivot ${col} = ${pivot})`);
      }
      // Columns left of `col` are already zero in this row, so start at `col`.
      for (let k = col; k < 2 * n; k++) {
        rows[col][k] /= pivot;
      }
      for (let r = 0; r < n; r++) {
        if (r === col) {
          continue;
        }
        const factor = rows[r][col];
        if (factor === 0) {
          continue;
        }
        for (let k = col; k < 2 * n; k++) {
          rows[r][k] -= factor * rows[col][k];
        }
      }
    }
    // The right half of the reduced augmented matrix is A^-1.
    return tf.tensor2d(rows.map((row) => row.slice(n)), [n, n], 'float32');
  }
}