Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JP-3603: reorder tweakreg to reduce iterations through models #8424

Merged
merged 15 commits into from
May 18, 2024
Merged
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ tweakreg

- Improved how a image group name is determined. [#8426]
braingram marked this conversation as resolved.
Show resolved Hide resolved

- Refactor step to work towards performance improvements. [#8424]

- Changed default settings for ``abs_separation`` parameter for the ``tweakreg``
step to have a value compatible with the ``abs_tolerance`` parameter. [#8445]

Expand Down
5 changes: 3 additions & 2 deletions jwst/assign_wcs/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def calc_rotation_matrix(roll_ref: float, v3i_yang: float, vparity: int = 1) ->

def wcs_from_footprints(dmodels, refmodel=None, transform=None, bounding_box=None,
pscale_ratio=None, pscale=None, rotation=None,
shape=None, crpix=None, crval=None):
shape=None, crpix=None, crval=None, wcslist=None):
"""
Create a WCS from a list of input data models.

Expand Down Expand Up @@ -259,7 +259,8 @@ def wcs_from_footprints(dmodels, refmodel=None, transform=None, bounding_box=Non

"""
bb = bounding_box
wcslist = [im.meta.wcs for im in dmodels]
if wcslist is None:
wcslist = [im.meta.wcs for im in dmodels]

if not isiterable(wcslist):
raise ValueError("Expected 'wcslist' to be an iterable of WCS objects.")
Expand Down
17 changes: 8 additions & 9 deletions jwst/datamodels/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def __init__(self, init=None, asn_exptypes=None, asn_n_members=None,
self.asn_table = {}
self.asn_table_name = None
self.asn_pool_name = None
self.asn_file_path = None

self._memmap = kwargs.get("memmap", False)
self._return_open = kwargs.get('return_open', True)
Expand Down Expand Up @@ -196,7 +197,8 @@ def __init__(self, init=None, asn_exptypes=None, asn_n_members=None,
self.from_asn(init)
elif isinstance(init, str):
init_from_asn = self.read_asn(init)
self.from_asn(init_from_asn, asn_file_path=init)
self.asn_file_path = init
self.from_asn(init_from_asn)
else:
raise TypeError('Input {0!r} is not a list of JwstDataModels or '
'an ASN file'.format(init))
Expand Down Expand Up @@ -275,17 +277,14 @@ def read_asn(filepath):
raise IOError("Cannot read ASN file.") from e
return asn_data

def from_asn(self, asn_data, asn_file_path=None):
def from_asn(self, asn_data):
"""
Load fits files from a JWST association file.

Parameters
----------
asn_data : ~jwst.associations.Association
An association dictionary

asn_file_path: str
Filepath of the association, if known.
"""
# match the asn_exptypes to the exptype in the association and retain
# only those file that match, as a list, if asn_exptypes is set to none
Expand All @@ -303,8 +302,8 @@ def from_asn(self, asn_data, asn_file_path=None):
infiles = [member for member
in asn_data['products'][0]['members']]

if asn_file_path:
asn_dir = op.dirname(asn_file_path)
if self.asn_file_path:
asn_dir = op.dirname(self.asn_file_path)
else:
asn_dir = ''

Expand Down Expand Up @@ -348,8 +347,8 @@ def from_asn(self, asn_data, asn_file_path=None):
self.meta.asn_table._instance, asn_data
)

if asn_file_path is not None:
self.asn_table_name = op.basename(asn_file_path)
if self.asn_file_path is not None:
self.asn_table_name = op.basename(self.asn_file_path)
self.asn_pool_name = asn_data['asn_pool']
for model in self:
try:
Expand Down
270 changes: 266 additions & 4 deletions jwst/tweakreg/tests/test_tweakreg.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
from copy import deepcopy
import json
import os

import asdf
from astropy.modeling.models import Shift
from astropy.table import Table
import numpy as np
import pytest

from jwst.tweakreg import tweakreg_step
from jwst.tweakreg import tweakreg_catalog
from jwst.tweakreg.utils import _wcsinfo_from_wcs_transform
from stdatamodels.jwst.datamodels import ImageModel
from jwst.datamodels import ModelContainer


BKG_LEVEL = 0.001
N_EXAMPLE_SOURCES = 21
N_CUSTOM_SOURCES = 15


@pytest.fixture
Expand All @@ -21,8 +30,55 @@ def dummy_source_catalog():
return catalog


@pytest.mark.parametrize("inplace", [True, False])
def test_rename_catalog_columns(dummy_source_catalog, inplace):
hbushouse marked this conversation as resolved.
Show resolved Hide resolved
"""
Test that a catalog with 'xcentroid' and 'ycentroid' columns
passed to _renamed_catalog_columns successfully renames those columns
to 'x' and 'y' (and does so "inplace" modifying the input catalog)
"""
renamed_catalog = tweakreg_step._rename_catalog_columns(dummy_source_catalog)

# if testing inplace, check the input catalog
if inplace:
catalog = dummy_source_catalog
else:
catalog = renamed_catalog

assert 'xcentroid' not in catalog.colnames
assert 'ycentroid' not in catalog.colnames
assert 'x' in catalog.colnames
assert 'y' in catalog.colnames
hbushouse marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.parametrize("missing", ["x", "y", "xcentroid", "ycentroid"])
def test_rename_catalog_columns_invalid(dummy_source_catalog, missing):
hbushouse marked this conversation as resolved.
Show resolved Hide resolved
"""
Test that passing a catalog that is missing either "x" or "y"
(or "xcentroid" and "ycentroid" which is renamed to "x" or "y")
results in an exception indicating that a required column is missing
"""
# if the column we want to remove is not in the table, first run
# rename to rename columns this should add the column we want to remove
if missing not in dummy_source_catalog.colnames:
tweakreg_step._rename_catalog_columns(dummy_source_catalog)
dummy_source_catalog.remove_column(missing)
with pytest.raises(ValueError, match="catalogs must contain"):
tweakreg_step._rename_catalog_columns(dummy_source_catalog)


@pytest.mark.parametrize("offset, is_good", [(1 / 3600, True), (11 / 3600, False)])
def test_is_wcs_correction_small(offset, is_good):
hbushouse marked this conversation as resolved.
Show resolved Hide resolved
"""
Test that the _is_wcs_correction_small method returns True for a small
wcs correction and False for a "large" wcs correction. The values in this
test are selected based on the current step default parameters:
- use2dhist
- searchrad
- tolerance
Changes to the defaults for these parameters will likely require updating the
values uses for parametrizing this test.
"""
path = os.path.join(os.path.dirname(__file__), "mosaic_long_i2d_gwcs.asdf")
with asdf.open(path) as af:
wcs = af.tree["wcs"]
Expand All @@ -35,7 +91,18 @@ def test_is_wcs_correction_small(offset, is_good):

step = tweakreg_step.TweakRegStep()

assert step._is_wcs_correction_small(wcs, twcs) == is_good
class FakeCorrector:
def __init__(self, wcs, original_skycoord):
self.wcs = wcs
self._original_skycoord = original_skycoord

@property
def meta(self):
return {'original_skycoord': self._original_skycoord}

correctors = [FakeCorrector(twcs, tweakreg_step._wcs_to_skycoord(wcs))]

assert step._is_wcs_correction_small(correctors) == is_good


def test_expected_failure_bad_starfinder():
Expand All @@ -51,11 +118,206 @@ def test_write_catalog(dummy_source_catalog, tmp_cwd):
'''

OUTDIR = 'outdir'
model = ImageModel()
step = tweakreg_step.TweakRegStep()
os.mkdir(OUTDIR)
step.output_dir = OUTDIR
expected_outfile = os.path.join(OUTDIR, 'catalog.ecsv')
step._write_catalog(model, dummy_source_catalog, 'catalog.ecsv')
step._write_catalog(dummy_source_catalog, 'catalog.ecsv')

assert os.path.exists(expected_outfile)


@pytest.fixture()
def example_wcs():
path = os.path.join(
os.path.dirname(__file__),
"data",
"nrcb1-wcs.asdf")
with asdf.open(path, lazy_load=False) as af:
return af.tree["wcs"]


@pytest.fixture()
def example_input(example_wcs):
m0 = ImageModel((512, 512))

# add a wcs and wcsinfo
m0.meta.wcs = example_wcs
m0.meta.wcsinfo = _wcsinfo_from_wcs_transform(example_wcs)

# and a few 'sources'
m0.data[:] = BKG_LEVEL
n_sources = N_EXAMPLE_SOURCES # a few more than default minobj
rng = np.random.default_rng(26)
xs = rng.choice(50, n_sources, replace=False) * 8 + 10
ys = rng.choice(50, n_sources, replace=False) * 8 + 10
for y, x in zip(ys, xs):
m0.data[y-1:y+2, x-1:x+2] = [
[0.1, 0.6, 0.1],
[0.6, 0.8, 0.6],
[0.1, 0.6, 0.1],
]

m1 = m0.copy()
# give each a unique filename
m0.meta.filename = 'some_file_0.fits'
m1.meta.filename = 'some_file_1.fits'
c = ModelContainer([m0, m1])
return c


@pytest.mark.parametrize("with_shift", [True, False])
def test_tweakreg_step(example_input, with_shift):
"""
A simplified unit test for basic operation of the TweakRegStep
when run with or without a small shift in the input image sources
"""
if with_shift:
# shift 9 pixels so that the sources in one of the 2 images
# appear at different locations (resulting in a correct wcs update)
example_input[1].data[:-9] = example_input[1].data[9:]
example_input[1].data[-9:] = BKG_LEVEL

# assign images to different groups (so they are aligned to each other)
example_input[0].meta.group_id = 'a'
example_input[1].meta.group_id = 'b'

# make the step with default arguments
step = tweakreg_step.TweakRegStep()

# run the step on the example input modified above
result = step(example_input)

# check that step completed
for model in result:
assert model.meta.cal_step.tweakreg == 'COMPLETE'

# and that the wcses differ by a small amount due to the shift above
# by projecting one point through each wcs and comparing the difference
abs_delta = abs(result[1].meta.wcs(0, 0)[0] - result[0].meta.wcs(0, 0)[0])
if with_shift:
assert abs_delta > 1E-5
else:
assert abs_delta < 1E-12


@pytest.fixture()
def custom_catalog_path(tmp_path):
fn = tmp_path / "custom_catalog.ecsv"

# it's important that the sources here don't match
# those added in example_input but conform to the input
# shape, wcs, etc used in example_input
rng = np.random.default_rng(42)
n_sources = N_CUSTOM_SOURCES
xs = rng.choice(50, n_sources, replace=False) * 8 + 10
ys = rng.choice(50, n_sources, replace=False) * 8 + 10
catalog = Table(np.vstack((xs, ys)).T, names=['x', 'y'], dtype=[float, float])
catalog.write(fn)
return fn


@pytest.mark.parametrize(
"catfile",
["no_catfile", "valid_catfile", "invalid_catfile", "empty_catfile_row"],
)
@pytest.mark.parametrize(
"asn",
["no_cat_in_asn", "cat_in_asn", "empty_asn_entry"],
)
@pytest.mark.parametrize(
"meta",
["no_meta", "cat_in_meta", "empty_meta"],
)
@pytest.mark.parametrize("custom", [True, False])
@pytest.mark.slow
def test_custom_catalog(custom_catalog_path, example_input, catfile, asn, meta, custom, monkeypatch):
mairanteodoro marked this conversation as resolved.
Show resolved Hide resolved
"""
Test that TweakRegStep uses a custom catalog provided by the user
when the correct set of options are provided. The combinations here can be confusing
and this test attempts to test all likely combinations of:
- a catalog in a `catfile`
- a catalog in the asn
- a catalog in the metadata
combined with step options:
- `use_custom_catalogs` (True/False)
- a "valid" file passed as `catfile`
"""
example_input[0].meta.group_id = 'a'
example_input[1].meta.group_id = 'b'

# this worked because if use_custom_catalogs was true but
# catfile was blank tweakreg still uses custom catalogs
# which in this case is defined in model.meta.tweakreg_catalog
if meta == "cat_in_meta":
example_input[0].meta.tweakreg_catalog = str(custom_catalog_path)
elif meta == "empty_meta":
example_input[0].meta.tweakreg_catalog = ""

# write out the ModelContainer and association (so the association table will be loaded)
example_input.save(dir_path=str(custom_catalog_path.parent))
asn_data = {
'asn_id': 'foo',
'asn_pool': 'bar',
'products': [
{
'members': [{'expname': m.meta.filename, 'exptype': 'science'} for m in example_input],
},
],
}

if asn == "empty_asn_entry":
asn_data['products'][0]['members'][0]['tweakreg_catalog'] = ''
elif asn == "cat_in_asn":
asn_data['products'][0]['members'][0]['tweakreg_catalog'] = str(custom_catalog_path.name)

asn_path = custom_catalog_path.parent / 'example_input.json'
with open(asn_path, 'w') as f:
json.dump(asn_data, f)

# write out a catfile
if catfile != "no_catfile":
catfile_path = custom_catalog_path.parent / 'catfile.txt'
with open(catfile_path, 'w') as f:
if catfile == "valid_catfile":
f.write(f"{example_input[0].meta.filename} {custom_catalog_path.name}")
elif catfile == "empty_catfile_row":
f.write(f"{example_input[0].meta.filename}")
elif catfile == "invalid_catfile":
pass

# figure out how many sources to expect for the model in group 'a'
n_custom_sources = N_EXAMPLE_SOURCES
if custom:
if catfile == "valid_catfile":
# for a 'valid' catfile, expect the custom number
n_custom_sources = N_CUSTOM_SOURCES
elif catfile == "no_catfile":
# since catfile is not defined, now look at asn_
if asn == "cat_in_asn":
# for a 'valid' asn entry, expect the custom number
n_custom_sources = N_CUSTOM_SOURCES
elif asn == "no_cat_in_asn" and meta == "cat_in_meta":
n_custom_sources = N_CUSTOM_SOURCES

kwargs = {'use_custom_catalogs': custom}
if catfile != "no_catfile":
kwargs["catfile"] = str(catfile_path)
step = tweakreg_step.TweakRegStep(**kwargs)

# patch _construct_wcs_corrector to check the correct catalog was loaded
def patched_construct_wcs_corrector(model, catalog, _seen=[]):
# we don't need to continue
if model.meta.group_id == 'a':
assert len(catalog) == n_custom_sources
elif model.meta.group_id == 'b':
assert len(catalog) == N_EXAMPLE_SOURCES
_seen.append(model)
if len(_seen) == 2:
raise ValueError("done testing")
return None

monkeypatch.setattr(tweakreg_step, "_construct_wcs_corrector", patched_construct_wcs_corrector)

assert os.path.exists(expected_outfile)
with pytest.raises(ValueError, match="done testing"):
step(str(asn_path))
Loading
Loading