diff --git a/src/pupil_labs/neon_recording/stream/stream.py b/src/pupil_labs/neon_recording/stream/stream.py index 1d1bce0..e58a018 100644 --- a/src/pupil_labs/neon_recording/stream/stream.py +++ b/src/pupil_labs/neon_recording/stream/stream.py @@ -9,8 +9,15 @@ class InterpolationMethod(Enum): NEAREST = "nearest" + NEAREST_BEFORE = "nearest_before" LINEAR = "linear" + def __eq__(self, other): + if isinstance(other, str): + return self.value == other + + return super().__eq__(other) + def _record_truthiness(self): for field in self.dtype.names: @@ -39,10 +46,27 @@ def sample(self, tstamps=None, method=InterpolationMethod.NEAREST): if method == InterpolationMethod.NEAREST: return self._sample_nearest(tstamps) + if method == InterpolationMethod.NEAREST_BEFORE: + return self._sample_nearest_before(tstamps) + elif method == InterpolationMethod.LINEAR: return self._sample_linear_interp(tstamps) def _sample_nearest(self, ts): + # Use searchsorted to get the insertion points + idxs = np.searchsorted(self.ts, ts) + + # Ensure index bounds are valid + idxs = np.clip(idxs, 1, len(self.ts) - 1) + left = self.ts[idxs - 1] + right = self.ts[idxs] + + # Determine whether the left or right value is closer + idxs -= (np.abs(ts - left) < np.abs(ts - right)).astype(int) + + return self.sampler_class(self._data[idxs]) + + def _sample_nearest_before(self, ts): last_idx = len(self._data) - 1 idxs = np.searchsorted(self.ts, ts) idxs[idxs > last_idx] = last_idx