Skip to content

Commit

Permalink
fix(extrapolation): support zero variance inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
pedrorrivero committed Feb 5, 2024
1 parent 6ca8866 commit c1f535f
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 3 deletions.
2 changes: 1 addition & 1 deletion zne/extrapolation/exponential_extrapolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def _extrapolate_zero(
self._model,
x_data,
y_data,
sigma=sigma_y,
sigma=self._compute_sigma(y_data, sigma_y),
absolute_sigma=True,
p0=[2 ** (-i) for i in range(self.num_terms * 2 + 1)],
bounds=([-inf] + [-inf, 0] * self.num_terms, inf), # Note: only decay considered
Expand Down
18 changes: 17 additions & 1 deletion zne/extrapolation/extrapolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
from collections import namedtuple
from collections.abc import Sequence

from numpy import array
from numpy import float_ as npfloat
from numpy import mean, ndarray
from numpy import isclose, mean, ndarray, ones

from zne.types import Metadata
from zne.utils.strategy import strategy
Expand Down Expand Up @@ -161,6 +162,21 @@ def _model(self, x, *coefficients) -> ndarray: # pylint: disable=invalid-name
################################################################################
## AUXILIARY
################################################################################
def _compute_sigma(
self,
y_data: tuple[float, ...],
sigma_y: tuple[float, ...],
) -> ndarray:
"""Compute sensible sigma values for curve fitting.
This implementation bypasses zero effective variance which would
lead to numerical errors in the curve fitting procedure.
"""
values = array(y_data)
errors = array(sigma_y)
relative_errors = errors / values
return ones(errors.shape) if any(isclose(relative_errors, 0)) else errors

def _build_metadata(
self,
x_data: ndarray,
Expand Down
2 changes: 1 addition & 1 deletion zne/extrapolation/polynomial_extrapolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _infer(self, target: float, regression_data: _RegressionData) -> ReckoningRe
self._model,
regression_data.x_data,
regression_data.y_data,
sigma=regression_data.sigma_y,
sigma=self._compute_sigma(regression_data.y_data, regression_data.sigma_y),
absolute_sigma=True,
p0=zeros(self.degree + 1), # Note: Initial point determines number of d.o.f.
)
Expand Down

0 comments on commit c1f535f

Please sign in to comment.