From c1f535ff7ac1236a94be243b59549e9160cd2481 Mon Sep 17 00:00:00 2001 From: Pedro Rivero Date: Mon, 5 Feb 2024 17:36:05 -0500 Subject: [PATCH] fix(extrapolation): support zero variance inputs --- zne/extrapolation/exponential_extrapolator.py | 2 +- zne/extrapolation/extrapolator.py | 18 +++++++++++++++++- zne/extrapolation/polynomial_extrapolator.py | 2 +- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/zne/extrapolation/exponential_extrapolator.py b/zne/extrapolation/exponential_extrapolator.py index 8e69398..e8360fd 100644 --- a/zne/extrapolation/exponential_extrapolator.py +++ b/zne/extrapolation/exponential_extrapolator.py @@ -139,7 +139,7 @@ def _extrapolate_zero( self._model, x_data, y_data, - sigma=sigma_y, + sigma=self._compute_sigma(y_data, sigma_y), absolute_sigma=True, p0=[2 ** (-i) for i in range(self.num_terms * 2 + 1)], bounds=([-inf] + [-inf, 0] * self.num_terms, inf), # Note: only decay considered diff --git a/zne/extrapolation/extrapolator.py b/zne/extrapolation/extrapolator.py index 539c6fa..57b19c8 100644 --- a/zne/extrapolation/extrapolator.py +++ b/zne/extrapolation/extrapolator.py @@ -18,8 +18,9 @@ from collections import namedtuple from collections.abc import Sequence +from numpy import array from numpy import float_ as npfloat -from numpy import mean, ndarray +from numpy import isclose, mean, ndarray, ones from zne.types import Metadata from zne.utils.strategy import strategy @@ -161,6 +162,21 @@ def _model(self, x, *coefficients) -> ndarray: # pylint: disable=invalid-name ################################################################################ ## AUXILIARY ################################################################################ + def _compute_sigma( + self, + y_data: tuple[float, ...], + sigma_y: tuple[float, ...], + ) -> ndarray: + """Compute sensible sigma values for curve fitting. + + This implementation bypasses zero effective variance which would + lead to numerical errors in the curve fitting procedure. + """ + values = array(y_data) + errors = array(sigma_y) + relative_errors = errors / values + return ones(errors.shape) if any(isclose(relative_errors, 0)) else errors + def _build_metadata( self, x_data: ndarray, diff --git a/zne/extrapolation/polynomial_extrapolator.py b/zne/extrapolation/polynomial_extrapolator.py index e991f79..6d9fca7 100644 --- a/zne/extrapolation/polynomial_extrapolator.py +++ b/zne/extrapolation/polynomial_extrapolator.py @@ -95,7 +95,7 @@ def _infer(self, target: float, regression_data: _RegressionData) -> ReckoningRe self._model, regression_data.x_data, regression_data.y_data, - sigma=regression_data.sigma_y, + sigma=self._compute_sigma(regression_data.y_data, regression_data.sigma_y), absolute_sigma=True, p0=zeros(self.degree + 1), # Note: Initial point determines number of d.o.f. )