diff --git a/src/filters.py b/src/filters.py index a1d0e156b83a77229cc03db41b6c1a6df43d0e0f..bcd2c09577b44ab65708283756c8eddaf7c7f15a 100644 --- a/src/filters.py +++ b/src/filters.py @@ -1,5 +1,6 @@ """Contains filter implementations. """ +import pdb import numpy as np import scipy.stats as stats import jax @@ -111,10 +112,11 @@ class IMM: """ # Calc. mixing probabilities mu = self.transition_prob*self.mode_probabilities[-1][:, None] - mu = mu/np.sum(mu, axis=0) # Normalize + mu = mu/np.sum(mu, axis=1, keepdims=True) # Normalize # Mixed state - xp = np.hstack([x@mu[:, [k]] for k in range(self.nmodes)]) + xp = np.hstack([x@mu[[k], :].T for k in range(self.nmodes)]) Pp = np.zeros(P.shape) + for i in range(self.nmodes): xdiff = x-xp[:, [i]] covxp = np.stack([xdiff[:, [k]]@xdiff[:, [k]].T for k in range(self.nmodes)], axis=2)