diff --git a/tweakwcs/linearfit.py b/tweakwcs/linearfit.py index 01e2ba5..53ddf7a 100644 --- a/tweakwcs/linearfit.py +++ b/tweakwcs/linearfit.py @@ -10,16 +10,8 @@ """ import logging import numbers -from packaging.version import Version import numpy as np -import astropy -from astropy.modeling.fitting import LevMarLSQFitter -if Version(astropy.__version__) >= Version('5.1'): - from astropy.modeling.fitting import fitter_to_model_params -else: - from astropy.modeling.fitting import (_fitter_to_model_params as - fitter_to_model_params) from . linalg import inv from . import __version__ # noqa: F401 @@ -27,8 +19,9 @@ __author__ = 'Mihai Cara, Warren Hack' __all__ = [ - 'iter_linear_fit', 'build_fit_matrix', 'SUPPORTED_FITGEOM_MODES', - '_LevMarLSQFitter2x2' + 'iter_linear_fit', + 'build_fit_matrix', + 'SUPPORTED_FITGEOM_MODES', ] # Supported fitgeom modes and corresponding minobj @@ -844,14 +837,3 @@ def build_fit_matrix(rot, scale=1): [-sx * np.sin(rx), sy * np.cos(ry)]]) return matrix - - -class _LevMarLSQFitter2x2(LevMarLSQFitter): - """ Performs fits of 2D vector-models to 2D reference points. """ - def objective_function(self, fps, *args): - model, weights, inputs, meas = args - fitter_to_model_params(model, fps) - if weights is None: - return np.ravel(np.subtract(model(*inputs), meas)) - else: - return np.ravel(weights * np.subtract(model(*inputs), meas)) diff --git a/tweakwcs/tests/test_linearfit.py b/tweakwcs/tests/test_linearfit.py index fafdc89..f15a5d3 100644 --- a/tweakwcs/tests/test_linearfit.py +++ b/tweakwcs/tests/test_linearfit.py @@ -7,11 +7,21 @@ from itertools import product import math import sys -import pytest -import numpy as np + +import astropy +from astropy.modeling.fitting import LevMarLSQFitter from astropy.modeling.models import Shift, Rotation2D +import numpy as np +from packaging.version import Version +import pytest from tweakwcs import linearfit, linalg +if Version(astropy.__version__) >= Version('5.1'): + from astropy.modeling.fitting import fitter_to_model_params +else: + from astropy.modeling.fitting import (_fitter_to_model_params as + fitter_to_model_params) + _LARGE_SAMPLE_SIZE = 1000 @@ -427,6 +437,17 @@ def test_iter_rscale_invalid_scale(): linearfit.fit_rscale(np.zeros((4, 2)), np.zeros((4, 2)), scale=0) +class _LevMarLSQFitter2x2(LevMarLSQFitter): + """ Performs fits of 2D vector-models to 2D reference points. """ + def objective_function(self, fps, *args): + model, weights, inputs, meas, *_ = args + fitter_to_model_params(model, fps) + if weights is None: + return np.ravel(np.subtract(model(*inputs), meas)) + else: + return np.ravel(weights * np.subtract(model(*inputs), meas)) + + def test_levmar2x2_multivariate(): inputs = [np.array([10., 10., 20., 20.]), np.array([10., 20., 20., 10.])] outputs = [np.array([8.06101731, 0.98994949, 8.06101731, 15.13208512]), @@ -434,7 +455,7 @@ def test_levmar2x2_multivariate(): rot = Rotation2D() rot.fittable = True model = (Shift() & Shift()) | rot - fitter = linearfit._LevMarLSQFitter2x2() + fitter = _LevMarLSQFitter2x2() finfo = fitter(model, inputs, outputs) assert np.allclose(finfo.parameters, np.array([4.3, -7.1, 45.]), rtol=1e-5, atol=1e-5)