Skip to content

Commit

Permalink
nocopy ops
Browse files Browse the repository at this point in the history
  • Loading branch information
fredyshox committed Aug 31, 2024
1 parent 05fd602 commit 294373a
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions rednose/helpers/ekf_sym.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def augment(self):
assert self.P.shape == (self.dim_err, self.dim_err)

def state(self):
return np.array(self.x).flatten()
return np.array(self.x).ravel()

def covs(self):
return self.P
Expand Down Expand Up @@ -461,27 +461,26 @@ def predict(self, t):
self.normalize_quaternions()
self.filter_time = t

def predict_and_update_batch(self, t, kind, z, R, extra_args=[[]], augment=False): # pylint: disable=dangerous-default-value
def predict_and_update_batch(self, t, kind, z, R, xk_km1=None, Pk_km1=None, xk_k=None, Pk_k=None, y=None, extra_args=[[]], augment=False): # pylint: disable=dangerous-default-value
# TODO handle rewinding at this level"

# rewind
if self.filter_time is not None and t < self.filter_time:
if len(self.rewind_t) == 0 or t < self.rewind_t[0] or t < self.rewind_t[-1] - self.max_rewind_age:
self.logger.error(f"observation too old at {t:.3f} with filter at {self.filter_time:.3f}, ignoring")
return None
return False
rewound = self.rewind(t)
else:
rewound = []

ret = self._predict_and_update_batch(t, kind, z, R, extra_args, augment)
self._predict_and_update_batch(t, kind, z, R, extra_args, xk_km1, Pk_km1, xk_k, Pk_k, y, augment)

# optional fast forward
for r in rewound:
self._predict_and_update_batch(*r)
return True

return ret

def _predict_and_update_batch(self, t, kind, z, R, extra_args, augment=False):
def _predict_and_update_batch(self, t, kind, z, R, extra_args, xk_km1=None, Pk_km1=None, xk_k=None, Pk_k=None, y=None, augment=False):
"""The main kalman filter function
Predicts the state and then updates a batch of observations
dim_x: dimensionality of the state space
Expand All @@ -497,6 +496,8 @@ def _predict_and_update_batch(self, t, kind, z, R, extra_args, augment=False):
assert z.shape[0] == R.shape[0]
assert z.shape[1] == R.shape[1]
assert z.shape[1] == R.shape[2]
assert y is None or z.shape == y.shape
# assert stuff

# initialize time
if self.filter_time is None:
Expand All @@ -507,10 +508,13 @@ def _predict_and_update_batch(self, t, kind, z, R, extra_args, augment=False):
assert dt >= 0
self.x, self.P = self._predict(self.x, self.P, dt)
self.filter_time = t
xk_km1, Pk_km1 = np.copy(self.x).flatten(), np.copy(self.P)

if xk_km1 is not None:
np.copyto(xk_km1, self.x.ravel())
if Pk_km1 is not None:
np.copyto(Pk_km1, self.P)

# update batch
y = []
for i in range(len(z)):
# these are from the user, so we canonicalize them
z_i = np.array(z[i], dtype=np.float64, order='F')
Expand All @@ -519,17 +523,20 @@ def _predict_and_update_batch(self, t, kind, z, R, extra_args, augment=False):
# update
self.x, self.P, y_i = self._update(self.x, self.P, kind, z_i, R_i, extra_args=extra_args_i)
self.normalize_quaternions()
y.append(y_i)
xk_k, Pk_k = np.copy(self.x).flatten(), np.copy(self.P)
if y is not None:
np.copyto(y_i, y[i])

if xk_k is not None:
np.copyto(xk_k, self.x.ravel())
if Pk_k is not None:
np.copyto(Pk_k, self.P)

if augment:
self.augment()

# checkpoint
self.checkpoint((t, kind, z, R, extra_args))

return xk_km1, xk_k, Pk_km1, Pk_k, t, kind, y, z, extra_args

def _predict_python(self, x, P, dt):
x_new = np.zeros(x.shape, dtype=np.float64)
self.f(x, dt, x_new)
Expand Down

0 comments on commit 294373a

Please sign in to comment.