Skip to content

Commit

Permalink
Fix failure related to astropy PR 16673
Browse files Browse the repository at this point in the history
  • Loading branch information
mcara committed Sep 10, 2024
1 parent 400f15b commit adacf67
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 24 deletions.
24 changes: 3 additions & 21 deletions tweakwcs/linearfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,18 @@
"""
import logging
import numbers
from packaging.version import Version

import numpy as np
import astropy
from astropy.modeling.fitting import LevMarLSQFitter
if Version(astropy.__version__) >= Version('5.1'):
from astropy.modeling.fitting import fitter_to_model_params
else:
from astropy.modeling.fitting import (_fitter_to_model_params as
fitter_to_model_params)

from . linalg import inv
from . import __version__ # noqa: F401

__author__ = 'Mihai Cara, Warren Hack'

__all__ = [
'iter_linear_fit', 'build_fit_matrix', 'SUPPORTED_FITGEOM_MODES',
'_LevMarLSQFitter2x2'
'iter_linear_fit',
'build_fit_matrix',
'SUPPORTED_FITGEOM_MODES',
]

# Supported fitgeom modes and corresponding minobj
Expand Down Expand Up @@ -844,14 +837,3 @@ def build_fit_matrix(rot, scale=1):
[-sx * np.sin(rx), sy * np.cos(ry)]])

return matrix


class _LevMarLSQFitter2x2(LevMarLSQFitter):
""" Performs fits of 2D vector-models to 2D reference points. """
def objective_function(self, fps, *args):
model, weights, inputs, meas = args
fitter_to_model_params(model, fps)
if weights is None:
return np.ravel(np.subtract(model(*inputs), meas))
else:
return np.ravel(weights * np.subtract(model(*inputs), meas))
27 changes: 24 additions & 3 deletions tweakwcs/tests/test_linearfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,21 @@
from itertools import product
import math
import sys
import pytest
import numpy as np

import astropy
from astropy.modeling.fitting import LevMarLSQFitter
from astropy.modeling.models import Shift, Rotation2D
import numpy as np
from packaging.version import Version
import pytest
from tweakwcs import linearfit, linalg

if Version(astropy.__version__) >= Version('5.1'):
from astropy.modeling.fitting import fitter_to_model_params
else:
from astropy.modeling.fitting import (_fitter_to_model_params as
fitter_to_model_params)


_LARGE_SAMPLE_SIZE = 1000

Expand Down Expand Up @@ -427,14 +437,25 @@ def test_iter_rscale_invalid_scale():
linearfit.fit_rscale(np.zeros((4, 2)), np.zeros((4, 2)), scale=0)


class _LevMarLSQFitter2x2(LevMarLSQFitter):
""" Performs fits of 2D vector-models to 2D reference points. """
def objective_function(self, fps, *args):
model, weights, inputs, meas, *_ = args
fitter_to_model_params(model, fps)
if weights is None:
return np.ravel(np.subtract(model(*inputs), meas))
else:
return np.ravel(weights * np.subtract(model(*inputs), meas))


def test_levmar2x2_multivariate():
inputs = [np.array([10., 10., 20., 20.]), np.array([10., 20., 20., 10.])]
outputs = [np.array([8.06101731, 0.98994949, 8.06101731, 15.13208512]),
np.array([12.16223664, 19.23330445, 26.30437226, 19.23330445])]
rot = Rotation2D()
rot.fittable = True
model = (Shift() & Shift()) | rot
fitter = linearfit._LevMarLSQFitter2x2()
fitter = _LevMarLSQFitter2x2()
finfo = fitter(model, inputs, outputs)
assert np.allclose(finfo.parameters, np.array([4.3, -7.1, 45.]),
rtol=1e-5, atol=1e-5)

0 comments on commit adacf67

Please sign in to comment.