diff --git a/CHANGES.rst b/CHANGES.rst index fbec6e44cf..1c8995df2c 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -14,6 +14,12 @@ ami_average - Fix error in step spec that prevents step creation. [#8677] +assign_mtwcs +------------ + +- Step now uses `ModelLibrary` to handle accessing models consistently + whether they are in memory or on disk. [#8683] + assign_wcs ---------- @@ -34,6 +40,12 @@ cube_build - Fixed a bug when ``cube_build`` was called from the ``mrs_imatch`` step. [#8728] +datamodels +---------- + +- Added `ModelLibrary` class to allow passing on-disk models between steps in the + image3 pipeline. [#8683] + documentation ------------- @@ -84,11 +96,27 @@ outlier_detection images. Intermediate files now have suffix ``outlier_s2d`` and are saved to the output directory alongside final products. [#8735] +- For imaging modes, step now uses `ModelLibrary` to handle accessing models consistently + whether they are in memory or on disk. [#8683] + set_telescope_pointing ---------------------- - replace usage of ``copy_arrays=True`` with ``memmap=False`` [#8660] +pipeline +-------- + +- Updated `calwebb_image3` to use `ModelLibrary` instead of `ModelContainer`, added + optional `on_disk` parameter to govern whether models in the library should be stored + in memory or on disk. [#8683] + +resample +-------- + +- Step now uses `ModelLibrary` to handle accessing models consistently + whether they are in memory or on disk. [#8683] + resample_spec ------------- @@ -123,6 +151,12 @@ scripts - Removed many non-working and out-dated scripts. Including many scripts that were replaced by ``strun``. [#8619] +skymatch +-------- + +- Step now uses `ModelLibrary` to handle accessing models consistently + whether they are in memory or on disk. [#8683] + stpipe ------ @@ -150,6 +184,9 @@ tweakreg - Removed direct setting of the ``self.skip`` attribute from within the step itself. [#8600] +- Step now uses `ModelLibrary` to handle accessing models consistently + whether they are in memory or on disk. [#8683] + 1.15.1 (2024-07-08) =================== diff --git a/docs/jwst/outlier_detection/outlier_detection_imaging.rst b/docs/jwst/outlier_detection/outlier_detection_imaging.rst index ced7d9599f..15df877939 100644 --- a/docs/jwst/outlier_detection/outlier_detection_imaging.rst +++ b/docs/jwst/outlier_detection/outlier_detection_imaging.rst @@ -15,11 +15,11 @@ Specifically, this routine performs the following operations: #. Convert input data, as needed, to make sure it is in a format that can be processed. - * A :py:class:`~jwst.datamodels.ModelContainer` serves as the basic format for + * A :py:class:`~jwst.datamodels.ModelLibrary` serves as the basic format for all processing performed by this step, as each entry will be treated as an element of a stack of images to be processed to identify bad-pixels/cosmic-rays and other artifacts. - * If the input data is a :py:class:`~jwst.datamodels.CubeModel`, convert it into a ModelContainer. + * If the input data is a :py:class:`~jwst.datamodels.CubeModel`, convert it into a ModelLibrary. This allows each plane of the cube to be treated as a separate 2D image for resampling (if done) and for combining into a median image. @@ -62,13 +62,13 @@ Specifically, this routine performs the following operations: if the input model container has an , otherwise the suffix will be ``_outlier_i2d.fits`` by default. 
* **If resampling is turned off** through the use of the ``resample_data`` parameter,
-  a copy of the unrectified input images (as a ModelContainer)
+  a copy of the unrectified input images (as a ModelLibrary)
   will be used for subsequent processing.
 
#. Create a median image from all grouped observation mosaics.
 
   * The median image is created by combining all grouped mosaic images or
-    non-resampled input data (as planes in a ModelContainer) pixel-by-pixel.
+    non-resampled input data (as planes in a ModelLibrary) pixel-by-pixel.
 
   * The ``maskpt`` parameter sets the percentage of the weight image values to
     use, and any pixel with a weight below this value gets flagged as "bad" and
     ignored when resampled.
@@ -129,7 +129,7 @@ The outlier detection algorithm can end up using massive amounts of memory
 depending on the number of inputs, the size of each input,
 and the size of the final output product. Specifically,
 
-#. The input :py:class:`~jwst.datamodels.ModelContainer` or
+#. The input :py:class:`~jwst.datamodels.ModelLibrary` or
    :py:class:`~jwst.datamodels.CubeModel` for IFU data, by default,
    all input exposures would have been kept open in memory to make
    processing more efficient.
@@ -152,56 +152,20 @@ memory usage at the expense of file I/O. The control over this memory model hap
 with the use of the ``in_memory`` parameter. The full impact of this parameter
 during processing includes:
 
-#. The ``save_open`` parameter gets set to `False`
-   when opening the input :py:class:`~jwst.datamodels.ModelContainer` object.
-   This forces all input models in the input :py:class:`~jwst.datamodels.ModelContainer` or
-   :py:class:`~jwst.datamodels.CubeModel` to get written out to disk. The ModelContainer
-   then uses the filename of the input model during subsequent processing.
+#. The input :py:class:`~jwst.datamodels.ModelLibrary` object is loaded with `on_disk=True`.
+   This ensures that input models are loaded into memory one at a time,
+   and saved to a temporary file when not in use; these read-write operations are handled by
+   the :py:class:`~jwst.datamodels.ModelLibrary` object.
 
-#. The ``in_memory`` parameter gets passed to the :py:class:`~jwst.resample.ResampleStep`
-   to set whether or not to keep the resampled images in memory or not. By default,
-   the outlier detection processing sets this parameter to `False` so that each resampled
-   image gets written out to disk.
+#. The ``on_disk`` status of the :py:class:`~jwst.datamodels.ModelLibrary` gets passed to the
+   :py:class:`~jwst.resample.ResampleStep` as well, to set whether or not to keep the
+   resampled images in memory.
 
 #. Computing the median image works section-by-section by only keeping 1Mb of each input
    in memory at a time. As a result, only the final output product array for the
    final median image along with a stack of 1Mb image sections are kept in memory.
 
-#. The final resampling step also avoids keeping all inputs in memory by only reading
-   each input into memory 1 at a time as it gets resampled onto the final output product.
-
 These changes result in a minimum amount of memory usage during processing at the
 obvious expense of reading and writing the products from disk.
-
-Outlier Detection for Coronagraphic Data
-----------------------------------------
-Coronagraphic data is processed in a near-identical manner to direct imaging data, but
-no resampling occurs.
-
-
-Outlier Detection for TSO data
-------------------------------
-Normal imaging data benefit from combining all integrations into a
-single image. TSO data's value, however, comes from looking for variations from one
-integration to the next. The outlier detection algorithm, therefore, gets run with
-a few variations to accomodate the nature of these 3D data. See the
-:ref:`TSO outlier detection ` documentation for details.
-
-
-Outlier Detection for IFU data
-------------------------------
-Integral Field Unit (IFU) data is handled as 2D images, similar to direct
-imaging modes. The nature of the detection algorithm, however, is quite
-different and involves measuring the differences between neighboring pixels
-in the spatial (cross-dispersion) direction within the IFU slice images.
-See the :ref:`IFU outlier detection ` documentation for
-all the details.
-
-
-Outlier Detection for Slit data
--------------------------------
-See the :ref:`IFU outlier detection ` documentation for
-details.
-
 .. automodapi:: jwst.outlier_detection.imaging
diff --git a/docs/jwst/pipeline/calwebb_image3.rst b/docs/jwst/pipeline/calwebb_image3.rst
index 710948c3f5..1bad06273b 100644
--- a/docs/jwst/pipeline/calwebb_image3.rst
+++ b/docs/jwst/pipeline/calwebb_image3.rst
@@ -34,7 +34,9 @@ processed using the :ref:`calwebb_tso3 ` pipeline.
 Arguments
 ---------
 
-The ``calwebb_image3`` pipeline does not have any optional arguments.
+``--in_memory``
+  Boolean governing whether to load all models in the input association into memory at once (faster)
+  or to save them to temporary files when not in use (slower, but uses less memory). Default is True.
 
 Inputs
 ------
diff --git a/docs/jwst/skymatch/arguments.rst b/docs/jwst/skymatch/arguments.rst
index f686a90395..a3af4a5918 100644
--- a/docs/jwst/skymatch/arguments.rst
+++ b/docs/jwst/skymatch/arguments.rst
@@ -67,3 +67,9 @@ The ``skymatch`` step uses the following optional arguments:
   Bin width, in sigma, used to sample the distribution of pixel values in
   order to compute the sky background using statistics that require binning,
   such as `mode` and `midpt`.
+
+**Memory management parameters:**
+
+``in_memory`` (boolean, default=True)
+  If False, reduce memory usage by keeping models in temporary files,
+  at the expense of many additional I/O operations.
diff --git a/docs/jwst/tweakreg/README.rst b/docs/jwst/tweakreg/README.rst
index 6c6811a86a..ea8069f0c0 100644
--- a/docs/jwst/tweakreg/README.rst
+++ b/docs/jwst/tweakreg/README.rst
@@ -86,7 +86,7 @@ models to the custom catalog file name, the ``tweakreg_step`` also supports two
 other ways of supplying custom source catalogs to the step:
 
 1. Adding ``tweakreg_catalog`` attribute to the ``members`` of the input ASN
-   table - see `~jwst.datamodels.ModelContainer` for more details.
+   table - see `~jwst.datamodels.ModelLibrary` for more details.
    Catalog file names are relative to ASN file path.
 
 2. Providing a simple two-column text file, specified via step's parameter
@@ -165,17 +165,17 @@ telescope pointing will be identical in all these images and it is assumed
 that the relative positions of (e.g., NIRCam) detectors do not change.
 Identification of images that belong to the same "exposure" and therefore
 can be grouped together is based on several attributes described in
-`~jwst.datamodels.ModelContainer`. This grouping is performed automatically
+`~jwst.datamodels.ModelLibrary`. This grouping is performed automatically
 in the ``tweakreg`` step using the
-`~jwst.datamodels.ModelContainer.models_grouped` property, which assigns
-a group ID to each input image model in ``meta.group_id``.
+`~jwst.datamodels.ModelLibrary.group_names` property.
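The grouping keys surface directly on the new container; a minimal sketch of inspecting
them (the association file name is hypothetical; ``group_names`` and ``group_indices``
are the properties this PR uses throughout)::

    from jwst.datamodels import ModelLibrary

    library = ModelLibrary("mosaic_asn.json")    # hypothetical association
    print(library.group_names)                   # one name per exposure group
    for group_name, indices in library.group_indices.items():
        print(group_name, indices)               # library indices of each group's members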
+
 However, when detector calibrations are not accurate, alignment of groups of
 images may fail (or result in poor alignment). In this case, it may be
 desirable to align each image independently. This can be achieved either by
 setting the ``image_model.meta.group_id`` attribute to a unique string or
 integer value for each image, or by adding the ``group_id`` attribute to the
 ``members`` of the input ASN
-table - see `~jwst.datamodels.ModelContainer` for more details.
+table - see `~jwst.datamodels.ModelLibrary` for more details.
 
 .. note::
     Group ID (``group_id``) is used by both ``tweakreg`` and ``skymatch`` steps
@@ -428,6 +428,14 @@ in the ``assign_wcs`` step.
 
 * ``sip_npoints``: Number of points for the SIP fit. (Default=12).
 
+**stpipe general options:**
+
+* ``output_use_model``: A boolean indicating whether to use `DataModel.meta.filename`
+  when saving the results. (Default=True)
+
+* ``in_memory``: A boolean indicating whether to keep models in memory, or to write them
+  to temporary files on disk when not in use in order to reduce memory usage. (Default=True)
+
 Further Documentation
 ---------------------
 The underlying algorithms as well as formats of source catalogs are described
diff --git a/jwst/assign_mtwcs/assign_mtwcs_step.py b/jwst/assign_mtwcs/assign_mtwcs_step.py
index bd312ccddd..a8e33e46f9 100755
--- a/jwst/assign_mtwcs/assign_mtwcs_step.py
+++ b/jwst/assign_mtwcs/assign_mtwcs_step.py
@@ -1,9 +1,8 @@
 #! /usr/bin/env python
 import logging
 
-from stdatamodels.jwst import datamodels
-
-from jwst.datamodels import ModelContainer
+from jwst.datamodels import ModelLibrary
+from jwst.stpipe.utilities import record_step_status
 
 from ..stpipe import Step
 from .moving_target_wcs import assign_moving_target_wcs
@@ -32,17 +31,14 @@ class AssignMTWcsStep(Step):
     """
 
     def process(self, input):
-        if isinstance(input, str):
-            input = datamodels.open(input)
-
-        # Can't apply the step if we aren't given a ModelContainer as input
-        if not isinstance(input, ModelContainer):
-            log.warning("Input data type is not supported.")
-            # raise ValueError("Expected input to be an association file name or a ModelContainer.")
-            input.meta.cal_step.assign_mtwcs = 'SKIPPED'
-            return input
 
-        # Apply the step
+        if not isinstance(input, ModelLibrary):
+            try:
+                input = ModelLibrary(input)
+            except Exception:
+                log.warning("Input data type is not supported.")
+                record_step_status(input, "assign_mtwcs", False)
+                return input
+
         result = assign_moving_target_wcs(input)
-
         return result
diff --git a/jwst/assign_mtwcs/moving_target_wcs.py b/jwst/assign_mtwcs/moving_target_wcs.py
index 7f61b84e9c..be4a5a1b8b 100644
--- a/jwst/assign_mtwcs/moving_target_wcs.py
+++ b/jwst/assign_mtwcs/moving_target_wcs.py
@@ -16,7 +16,8 @@
 
 from stdatamodels.jwst import datamodels
 
-from jwst.datamodels import ModelContainer
+from jwst.datamodels import ModelLibrary
+from jwst.stpipe.utilities import record_step_status
 
 log = logging.getLogger(__name__)
 log.setLevel(logging.DEBUG)
@@ -24,49 +25,54 @@
 __all__ = ["assign_moving_target_wcs"]
 
 
-def assign_moving_target_wcs(input_model):
+def assign_moving_target_wcs(input_models):
 
-    if not isinstance(input_model, ModelContainer):
-        raise ValueError("Expected a ModelContainer object")
+    if not isinstance(input_models, ModelLibrary):
+        raise ValueError("Expected a ModelLibrary object")
 
-    # get the indices of the science exposures in the ModelContainer
-    ind = input_model.ind_asn_type('science')
-    sci_models = np.asarray(input_model._models)[ind]
-    # Get the MT RA/Dec values from all the input exposures
-    mt_ra = 
np.array([model.meta.wcsinfo.mt_ra for model in sci_models]) - mt_dec = np.array([model.meta.wcsinfo.mt_dec for model in sci_models]) + # loop over only science exposures in the ModelLibrary + ind = input_models.indices_for_exptype("science") + mt_ra = np.empty(len(ind)) + mt_dec = np.empty(len(ind)) + with input_models: + for i in ind: + model = input_models.borrow(i) + mt_ra[i] = model.meta.wcsinfo.mt_ra + mt_dec[i] = model.meta.wcsinfo.mt_dec + input_models.shelve(model, i, modify=False) # Compute the mean MT RA/Dec over all exposures if None in mt_ra or None in mt_dec: log.warning("One or more MT RA/Dec values missing in input images") log.warning("Step will be skipped, resulting in target misalignment") - for model in sci_models: - model.meta.cal_step.assign_mtwcs = 'SKIPPED' - return input_model - else: - mt_avra = mt_ra.mean() - mt_avdec = mt_dec.mean() - - for model in sci_models: - model.meta.wcsinfo.mt_avra = mt_avra - model.meta.wcsinfo.mt_avdec = mt_avdec - if isinstance(model, datamodels.MultiSlitModel): - for ind, slit in enumerate(model.slits): - new_wcs = add_mt_frame(slit.meta.wcs, - mt_avra, mt_avdec, - slit.meta.wcsinfo.mt_ra, slit.meta.wcsinfo.mt_dec) - del model.slits[ind].meta.wcs - model.slits[ind].meta.wcs = new_wcs - else: - - new_wcs = add_mt_frame(model.meta.wcs, mt_avra, mt_avdec, - model.meta.wcsinfo.mt_ra, model.meta.wcsinfo.mt_dec) - del model.meta.wcs - model.meta.wcs = new_wcs - - model.meta.cal_step.assign_mtwcs = 'COMPLETE' - - return input_model + record_step_status(input_models, "assign_mtwcs", False) + return input_models + + mt_avra = mt_ra.mean() + mt_avdec = mt_dec.mean() + + with input_models: + for i in ind: + model = input_models.borrow(i) + model.meta.wcsinfo.mt_avra = mt_avra + model.meta.wcsinfo.mt_avdec = mt_avdec + if isinstance(model, datamodels.MultiSlitModel): + for ind, slit in enumerate(model.slits): + new_wcs = add_mt_frame(slit.meta.wcs, + mt_avra, mt_avdec, + slit.meta.wcsinfo.mt_ra, slit.meta.wcsinfo.mt_dec) + del model.slits[ind].meta.wcs + model.slits[ind].meta.wcs = new_wcs + else: + + new_wcs = add_mt_frame(model.meta.wcs, mt_avra, mt_avdec, + model.meta.wcsinfo.mt_ra, model.meta.wcsinfo.mt_dec) + del model.meta.wcs + model.meta.wcs = new_wcs + record_step_status(model, "assign_mtwcs", True) + input_models.shelve(model, i, modify=True) + + return input_models def add_mt_frame(wcs, ra_average, dec_average, mt_ra, mt_dec): diff --git a/jwst/assign_mtwcs/tests/test_mtwcs.py b/jwst/assign_mtwcs/tests/test_mtwcs.py index a33b848872..934869d4c5 100644 --- a/jwst/assign_mtwcs/tests/test_mtwcs.py +++ b/jwst/assign_mtwcs/tests/test_mtwcs.py @@ -2,7 +2,7 @@ from stdatamodels.jwst import datamodels -from jwst.datamodels import ModelContainer +from jwst.datamodels import ModelLibrary from jwst.assign_mtwcs import AssignMTWcsStep from jwst.assign_mtwcs.tests import data @@ -13,10 +13,17 @@ def test_mt_multislit(): file_path = os.path.join(data.__path__[0], 'test_mt_asn.json') with datamodels.open(file_path) as model: assert model[0].slits[0].meta.wcs.output_frame.name == 'world' - step = AssignMTWcsStep() - result = step.run(model) - assert isinstance(result, ModelContainer) - assert len(result[0].slits) == 1 - assert result[0].slits[0].meta.wcs.output_frame.name == 'moving_target' - assert len(result[1].slits) == 1 - assert result[1].slits[0].meta.wcs.output_frame.name == 'moving_target' + step = AssignMTWcsStep() + result = step.run(file_path) + assert isinstance(result, ModelLibrary) + with result: + zero = result.borrow(0) + one = 
result.borrow(1) + + assert len(zero.slits) == 1 + assert zero.slits[0].meta.wcs.output_frame.name == 'moving_target' + assert len(one.slits) == 1 + assert one.slits[0].meta.wcs.output_frame.name == 'moving_target' + + result.shelve(zero, 0, modify=False) + result.shelve(one, 1, modify=False) diff --git a/jwst/datamodels/__init__.py b/jwst/datamodels/__init__.py index 2bf017f093..1d3e0b23d6 100644 --- a/jwst/datamodels/__init__.py +++ b/jwst/datamodels/__init__.py @@ -8,6 +8,7 @@ from stdatamodels.jwst.datamodels.util import open from .container import ModelContainer +from .library import ModelLibrary from .source_container import SourceModelContainer import stdatamodels.jwst.datamodels @@ -19,14 +20,15 @@ __all__ = [ 'open', 'ModelContainer', 'SourceModelContainer', + 'ModelLibrary', ] + stdatamodels.jwst.datamodels.__all__ # Modules that are not part of stdatamodels -_jwst_modules = ["container", "source_container"] +_jwst_modules = ["container", "source_container", "library"] # Models that are not part of stdatamodels -_jwst_models = ["ModelContainer", "SourceModelContainer"] +_jwst_models = ["ModelContainer", "SourceModelContainer", "ModelLibrary"] # Deprecated modules in stdatamodels _deprecated_modules = ['schema'] diff --git a/jwst/datamodels/library.py b/jwst/datamodels/library.py new file mode 100644 index 0000000000..32e27aa117 --- /dev/null +++ b/jwst/datamodels/library.py @@ -0,0 +1,158 @@ +import io + +import asdf +from astropy.io import fits +from stdatamodels.jwst.datamodels.util import open as datamodels_open +from stpipe.library import AbstractModelLibrary, NoGroupID + +from jwst.associations import AssociationNotValidError, load_asn + +__all__ = ["ModelLibrary"] + + +class ModelLibrary(AbstractModelLibrary): + """ + JWST implementation of the ModelLibrary, a container designed to allow + efficient processing of datamodel instances created from an association. 
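Every step converted in this PR uses the same borrowing protocol: a model taken from the
library must be shelved back before the ``with`` block closes. A minimal sketch of that
pattern (the input model list is hypothetical)::

    from jwst.datamodels import ModelLibrary

    library = ModelLibrary([model_a, model_b], on_disk=False)  # hypothetical models
    with library:
        for index, model in enumerate(library):
            ...  # read or update the model
            library.shelve(model, index, modify=False)  # modify=True persists changes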
+ See the `stpipe library documentation = 1 + + if has_groups: + with input_models: + model = input_models.borrow(0) + is_moving = is_moving_target(model) + input_models.shelve(model, 0, modify=False) + if is_moving: + input_models = self.assign_mtwcs(input_models) + else: + input_models = self.tweakreg(input_models) - # Check if input is single or multiple exposures - try: - has_groups = len(input_models.group_names) >= 1 - except (AttributeError, TypeError, KeyError): - has_groups = False + input_models = self.skymatch(input_models) + input_models = self.outlier_detection(input_models) - if isinstance(input_models, ModelContainer) and has_groups: - if is_moving_target(input_models): - input_models = self.assign_mtwcs(input_models) - else: - input_models = self.tweakreg(input_models) + elif self.skymatch.skymethod == 'match': + self.log.warning("Turning 'skymatch' step off for a single " + "input image when 'skymethod' is 'match'") - input_models = self.skymatch(input_models) - input_models = self.outlier_detection(input_models) + else: + input_models = self.skymatch(input_models) - elif self.skymatch.skymethod == 'match': - self.log.warning("Turning 'skymatch' step off for a single " - "input image when 'skymethod' is 'match'") + result = self.resample(input_models) + del input_models + if isinstance(result, datamodels.ImageModel) and result.meta.cal_step.resample == 'COMPLETE': + self.source_catalog(result) - else: - input_models = self.skymatch(input_models) - result = self.resample(input_models) - if isinstance(result, datamodels.ImageModel) and result.meta.cal_step.resample == 'COMPLETE': - self.source_catalog(result) + def _load_input_as_library(self, input): + """ + Load any valid input type into a ModelLibrary, including + single datamodels, associations, ModelLibrary instances, and + filenames pointing to those types. 
+ """ + + if isinstance(input, ModelLibrary): + return input + + if isinstance(input, (str, dict)): + try: + # Try opening input as an association + return ModelLibrary(input, asn_exptypes=['science'], on_disk=not self.in_memory) + except OSError: + # Try opening input as a single cal file + input = datamodels.open(input) + input = [input,] + return ModelLibrary(input, asn_exptypes=['science'], on_disk=not self.in_memory) + elif isinstance(input, Sequence): + return ModelLibrary(input, asn_exptypes=['science'], on_disk=not self.in_memory) + elif isinstance(input, datamodels.JwstDataModel): + return ModelLibrary([input], asn_exptypes=['science'], on_disk=not self.in_memory) + else: + raise TypeError(f"Input type {type(input)} not supported.") \ No newline at end of file diff --git a/jwst/pipeline/calwebb_spec3.py b/jwst/pipeline/calwebb_spec3.py index 96a5b25d68..7fd69e4373 100644 --- a/jwst/pipeline/calwebb_spec3.py +++ b/jwst/pipeline/calwebb_spec3.py @@ -138,9 +138,10 @@ def process(self, input): for member in product['members']: members_by_type[member['exptype'].lower()].append(member['expname']) - if is_moving_target(input_models): + if is_moving_target(input_models[0]): self.log.info("Assigning WCS to a Moving Target exposure.") - input_models = self.assign_mtwcs(input_models) + # assign_mtwcs modifies input_models in-place + self.assign_mtwcs(input_models) # If background data are present, call the master background step if members_by_type['background']: diff --git a/jwst/pipeline/tests/test_calwebb_image3.py b/jwst/pipeline/tests/test_calwebb_image3.py new file mode 100644 index 0000000000..8531674664 --- /dev/null +++ b/jwst/pipeline/tests/test_calwebb_image3.py @@ -0,0 +1,114 @@ +import pytest +import os +import shutil +from jwst.stpipe import Step +from jwst.assign_wcs import AssignWcsStep +from jwst.datamodels import ImageModel + + +INPUT_FILE = "dummy_cal.fits" +INPUT_FILE_2 = "dummy2_cal.fits" +INPUT_ASN = "dummy_asn.json" +OUTPUT_PRODUCT = "custom_name" +LOGFILE = "run_asn.log" +LOGCFG = "test_logs.cfg" +LOGCFG_CONTENT = f"[*] \n \ + handler = file:{LOGFILE}" + + +@pytest.fixture(scope='module') +def make_dummy_cal_file(tmp_cwd_module): + ''' + Make and save a dummy cal file in the temporary working directory + Partially copied from test_calwebb_image2.py + ''' + + image = ImageModel((2048, 2048)) + image.data[:, :] = 1 + image.meta.instrument.name = 'NIRCAM' + image.meta.instrument.filter = 'F210M' + image.meta.instrument.pupil = 'CLEAR' + image.meta.exposure.type = 'NRC_IMAGE' + image.meta.observation.date = '2024-02-27' + image.meta.observation.time = '13:37:18.548' + image.meta.date = '2024-02-27T13:37:18.548' + image.meta.subarray.xstart = 1 + image.meta.subarray.ystart = 1 + + image.meta.subarray.xsize = image.data.shape[-1] + image.meta.subarray.ysize = image.data.shape[-2] + + image.meta.instrument.channel = 'SHORT' + image.meta.instrument.module = 'A' + image.meta.instrument.detector = 'NRCA1' + + # bare minimum wcs info to get assign_wcs step to pass + image.meta.wcsinfo.crpix1 = 693.5 + image.meta.wcsinfo.crpix2 = 512.5 + image.meta.wcsinfo.v2_ref = -453.37849 + image.meta.wcsinfo.v3_ref = -373.810549 + image.meta.wcsinfo.roll_ref = 272.3237653262276 + image.meta.wcsinfo.ra_ref = 80.54724018120017 + image.meta.wcsinfo.dec_ref = -69.5081101864959 + + image = AssignWcsStep.call(image) + + with image as dm: + dm.save(INPUT_FILE) + + +@pytest.fixture(scope='module') +def make_dummy_association(make_dummy_cal_file): + + shutil.copy(INPUT_FILE, INPUT_FILE_2) + 
os.system(f"asn_from_list -o {INPUT_ASN} --product-name {OUTPUT_PRODUCT} -r DMS_Level3_Base {INPUT_FILE} {INPUT_FILE_2}") + + +@pytest.mark.parametrize("in_memory", [True, False]) +def test_run_image3_pipeline(make_dummy_association, in_memory): + ''' + Two-product association passed in, run pipeline, skipping most steps + ''' + # save warnings to logfile so can be checked later + with open(LOGCFG, 'w') as f: + f.write(LOGCFG_CONTENT) + + args = ["calwebb_image3", INPUT_ASN, + f"--logcfg={LOGCFG}", + "--steps.tweakreg.skip=true", + "--steps.skymatch.skip=true", + "--steps.outlier_detection.skip=true", + "--steps.resample.skip=true", + "--steps.source_catalog.skip=true", + f"--in_memory={str(in_memory)}",] + + Step.from_cmdline(args) + + _is_run_complete(LOGFILE) + + +def test_run_image3_single_file(make_dummy_cal_file): + + with open(LOGCFG, 'w') as f: + f.write(LOGCFG_CONTENT) + + args = ["calwebb_image3", INPUT_FILE, + f"--logcfg={LOGCFG}", + "--steps.tweakreg.skip=true", + "--steps.skymatch.skip=true", + "--steps.outlier_detection.skip=true", + "--steps.resample.skip=true", + "--steps.source_catalog.skip=true",] + + Step.from_cmdline(args) + _is_run_complete(LOGFILE) + + +def _is_run_complete(logfile): + ''' + Check that the pipeline runs to completion + ''' + msg = "Step Image3Pipeline done" + with open(LOGFILE, 'r') as f: + log = f.read() + assert msg in log diff --git a/jwst/regtest/test_niriss_image.py b/jwst/regtest/test_niriss_image.py index 4cce41eda6..7e186986a2 100644 --- a/jwst/regtest/test_niriss_image.py +++ b/jwst/regtest/test_niriss_image.py @@ -71,11 +71,10 @@ def test_niriss_tweakreg_no_sources(rtdata, fitsdiff_default_kwargs): assert model.meta.cal_step.tweakreg != 'SKIPPED' result = TweakRegStep.call(mc) - - for model in result: - assert model.meta.cal_step.tweakreg == 'SKIPPED' - - result.close() + with result: + for model in result: + assert model.meta.cal_step.tweakreg == 'SKIPPED' + result.shelve(model, modify=False) def _assert_is_same(rtdata_module, fitsdiff_default_kwargs, suffix): diff --git a/jwst/resample/resample.py b/jwst/resample/resample.py index 681592f3aa..97b2e3d766 100644 --- a/jwst/resample/resample.py +++ b/jwst/resample/resample.py @@ -1,6 +1,7 @@ import logging import os import warnings +import json import numpy as np import psutil @@ -10,7 +11,8 @@ from stdatamodels.jwst import datamodels from stdatamodels.jwst.library.basic_utils import bytes2human -from jwst.datamodels import ModelContainer +from jwst.datamodels import ModelLibrary +from jwst.associations.asn_from_list import asn_from_list from . import gwcs_drizzle from jwst.resample import resample_utils @@ -49,8 +51,8 @@ def __init__(self, input_models, output=None, single=False, blendheaders=True, """ Parameters ---------- - input_models : list of objects - list of data models, one for each input image + input_models : library of objects + library of data models, one for each input image output : str filename for output @@ -133,7 +135,6 @@ def __init__(self, input_models, output=None, single=False, blendheaders=True, crpix=crpix, crval=crval ) - # Estimate output pixel area in Sr. NOTE: in principle we could # use the same algorithm as for when output_wcs is provided by the # user. 
@@ -180,14 +181,17 @@ def __init__(self, input_models, output=None, single=False, blendheaders=True, self.blank_output = datamodels.ImageModel(tuple(self.output_wcs.array_shape)) # update meta data and wcs - self.blank_output.update(input_models[0]) + with input_models: + example_model = input_models.borrow(0) + self.blank_output.update(example_model) + input_models.shelve(example_model, 0, modify=False) + del example_model self.blank_output.meta.wcs = self.output_wcs self.blank_output.meta.photometry.pixelarea_steradians = output_pix_area self.blank_output.meta.photometry.pixelarea_arcsecsq = ( output_pix_area * np.rad2deg(3600)**2 ) - self.output_models = ModelContainer(open_models=False) def do_drizzle(self, input_models): """Pick the correct drizzling mode based on self.single @@ -275,64 +279,77 @@ def resample_many_to_many(self, input_models): Used for outlier detection """ - for exposure in input_models.models_grouped: + output_models = [] + for group_id, indices in input_models.group_indices.items(): output_model = self.blank_output - # Determine output file type from input exposure filenames - # Use this for defining the output filename - indx = exposure[0].meta.filename.rfind('.') - output_type = exposure[0].meta.filename[indx:] - output_root = '_'.join(exposure[0].meta.filename.replace( - output_type, '').split('_')[:-1]) - if self.asn_id is not None: - output_model.meta.filename = ( - f'{output_root}_{self.asn_id}_' - f'{self.intermediate_suffix}{output_type}') - else: - output_model.meta.filename = ( - f'{output_root}_' - f'{self.intermediate_suffix}{output_type}') - - # Initialize the output with the wcs - driz = gwcs_drizzle.GWCSDrizzle(output_model, pixfrac=self.pixfrac, - kernel=self.kernel, fillval=self.fillval) - - log.info(f"{len(exposure)} exposures to drizzle together") - for img in exposure: - img = datamodels.open(img) - iscale = self._get_intensity_scale(img) - log.debug(f'Using intensity scale iscale={iscale}') - - inwht = resample_utils.build_driz_weight( - img, - weight_type=self.weight_type, - good_bits=self.good_bits - ) - - # apply sky subtraction - blevel = img.meta.background.level - if not img.meta.background.subtracted and blevel is not None: - data = img.data - blevel + copy_asn_info_from_library(input_models, output_model) + + with input_models: + example_image = input_models.borrow(indices[0]) + + # Determine output file type from input exposure filenames + # Use this for defining the output filename + indx = example_image.meta.filename.rfind('.') + output_type = example_image.meta.filename[indx:] + output_root = '_'.join(example_image.meta.filename.replace( + output_type, '').split('_')[:-1]) + if self.asn_id is not None: + output_model.meta.filename = ( + f'{output_root}_{self.asn_id}_' + f'{self.intermediate_suffix}{output_type}') else: - data = img.data + output_model.meta.filename = ( + f'{output_root}_' + f'{self.intermediate_suffix}{output_type}') + input_models.shelve(example_image, indices[0], modify=False) + del example_image + + # Initialize the output with the wcs + driz = gwcs_drizzle.GWCSDrizzle(output_model, pixfrac=self.pixfrac, + kernel=self.kernel, fillval=self.fillval) + + log.info(f"{len(indices)} exposures to drizzle together") + for index in indices: + img = input_models.borrow(index) + if isinstance(img, datamodels.SlitModel): + # must call this explicitly to populate area extension + # although the existence of this extension may not be necessary + img.area = img.area + iscale = self._get_intensity_scale(img) + log.debug(f'Using 
intensity scale iscale={iscale}') + + inwht = resample_utils.build_driz_weight( + img, + weight_type=self.weight_type, + good_bits=self.good_bits + ) - xmin, xmax, ymin, ymax = resample_utils._resample_range( - data.shape, - img.meta.wcs.bounding_box - ) + # apply sky subtraction + blevel = img.meta.background.level + if not img.meta.background.subtracted and blevel is not None: + data = img.data - blevel + else: + data = img.data - driz.add_image( - data, - img.meta.wcs, - iscale=iscale, - inwht=inwht, - xmin=xmin, - xmax=xmax, - ymin=ymin, - ymax=ymax - ) - del data - img.close() + xmin, xmax, ymin, ymax = resample_utils._resample_range( + data.shape, + img.meta.wcs.bounding_box + ) + + driz.add_image( + data, + img.meta.wcs, + iscale=iscale, + inwht=inwht, + xmin=xmin, + xmax=xmax, + ymin=ymin, + ymax=ymax + ) + del data + input_models.shelve(img, index, modify=False) + del img if not self.in_memory: # Write out model to disk, then return filename @@ -341,13 +358,20 @@ def resample_many_to_many(self, input_models): output_name = os.path.join(self.output_dir, output_name) output_model.save(output_name) log.info(f"Saved model in {output_name}") - self.output_models.append(output_name) + output_models.append(output_name) else: - self.output_models.append(output_model.copy()) + output_models.append(output_model.copy()) output_model.data *= 0. output_model.wht *= 0. - return self.output_models + if not self.in_memory: + # build ModelLibrary as an association from the output files + # this saves memory if there are multiple groups + asn = asn_from_list(output_models, product_name='outlier_i2d') + asn_dict = json.loads(asn.dump()[1]) # serializes the asn and converts to dict + return ModelLibrary(asn_dict, on_disk=True) + # otherwise just build it as a list of in-memory models + return ModelLibrary(output_models, on_disk=False) def resample_many_to_one(self, input_models): """Resample and coadd many inputs to a single output. 
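The on-disk return path above re-wraps the intermediate files it just saved as an
association-backed library; isolating that construction (file and product names
hypothetical)::

    import json

    from jwst.associations.asn_from_list import asn_from_list
    from jwst.datamodels import ModelLibrary

    asn = asn_from_list(["a_outlier_i2d.fits"], product_name="outlier_i2d")
    asn_dict = json.loads(asn.dump()[1])  # dump() serializes; element [1] is the JSON text
    library = ModelLibrary(asn_dict, on_disk=True)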
@@ -359,45 +383,71 @@ def resample_many_to_one(self, input_models): output_model.meta.resample.weight_type = self.weight_type output_model.meta.resample.pointings = len(input_models.group_names) + # copy over asn information + copy_asn_info_from_library(input_models, output_model) + if self.blendheaders: - self.blend_output_metadata(output_model, input_models) + # right now this needs a list of input models, all in memory + # for now, just load the models as a list with empty data arrays + # but the blend_meta step itself should eventually be refactored + # to expect a list of metadata objects + # instead of a list of datamodels + input_list = [] + with input_models: + for i, model in enumerate(input_models): + empty_model = type(model)() + empty_model.meta = model.meta + copy_asn_info_from_library(input_models, empty_model) + empty_model.data = np.empty((1, 1)) + empty_model.dq = np.empty((1, 1)) + empty_model.err = np.empty((1, 1)) + empty_model.wht = np.empty((1, 1)) + empty_model.var_rnoise = np.empty((1, 1)) + empty_model.var_poisson = np.empty((1, 1)) + empty_model.var_flat = np.empty((1, 1)) + input_list.append(empty_model) + input_models.shelve(model, i, modify=False) + self.blend_output_metadata(output_model, input_list) + del input_list # Initialize the output with the wcs driz = gwcs_drizzle.GWCSDrizzle(output_model, pixfrac=self.pixfrac, kernel=self.kernel, fillval=self.fillval) log.info("Resampling science data") - for img in input_models: - iscale = self._get_intensity_scale(img) - log.debug(f'Using intensity scale iscale={iscale}') - img.meta.iscale = iscale - - inwht = resample_utils.build_driz_weight(img, - weight_type=self.weight_type, - good_bits=self.good_bits) - # apply sky subtraction - blevel = img.meta.background.level - if not img.meta.background.subtracted and blevel is not None: - data = img.data - blevel - else: - data = img.data.copy() + with input_models: + for img in input_models: + iscale = self._get_intensity_scale(img) + log.debug(f'Using intensity scale iscale={iscale}') + img.meta.iscale = iscale - xmin, xmax, ymin, ymax = resample_utils._resample_range( - data.shape, - img.meta.wcs.bounding_box - ) + inwht = resample_utils.build_driz_weight(img, + weight_type=self.weight_type, + good_bits=self.good_bits) + # apply sky subtraction + blevel = img.meta.background.level + if not img.meta.background.subtracted and blevel is not None: + data = img.data - blevel + else: + data = img.data.copy() - driz.add_image( - data, - img.meta.wcs, - iscale=iscale, - inwht=inwht, - xmin=xmin, - xmax=xmax, - ymin=ymin, - ymax=ymax - ) - del data, inwht + xmin, xmax, ymin, ymax = resample_utils._resample_range( + data.shape, + img.meta.wcs.bounding_box + ) + + driz.add_image( + data, + img.meta.wcs, + iscale=iscale, + inwht=inwht, + xmin=xmin, + xmax=xmax, + ymin=ymin, + ymax=ymax + ) + del data, inwht + input_models.shelve(img) # Resample variance arrays in input_models to output_model self.resample_variance_arrays(output_model, input_models) @@ -414,12 +464,9 @@ def resample_many_to_one(self, input_models): output_model.err[all_nan] = np.nan self.update_exposure_times(output_model, input_models) - self.output_models.append(output_model) - for img in input_models: - del img.meta.iscale + return ModelLibrary([output_model,], on_disk=False) - return self.output_models def resample_variance_arrays(self, output_model, input_models): """Resample variance arrays from input_models to the output_model. 
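Because ``resample_many_to_one`` now returns a one-element ``ModelLibrary`` instead of
appending to ``self.output_models``, callers unwrap the mosaic with the same
borrow/shelve protocol, as ``ResampleSpecStep._process_slit`` does later in this diff::

    with drizzled_library:
        result = drizzled_library.borrow(0)
        drizzled_library.shelve(result, 0, modify=False)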
@@ -438,82 +485,91 @@ def resample_variance_arrays(self, output_model, input_models): total_weight_rn_var = np.zeros_like(output_model.data) total_weight_pn_var = np.zeros_like(output_model.data) total_weight_flat_var = np.zeros_like(output_model.data) - for model in input_models: - # Do the read noise variance first, so it can be - # used for weights if needed - rn_var = self._resample_one_variance_array( - "var_rnoise", model, output_model) - - # Find valid weighting values in the variance - if rn_var is not None: - mask = (rn_var > 0) & np.isfinite(rn_var) - else: - mask = np.full_like(rn_var, False) - - # Set the weight for the image from the weight type - weight = np.ones(output_model.data.shape) - if self.weight_type == "ivm" and rn_var is not None: - weight[mask] = rn_var[mask] ** -1 - elif self.weight_type == "exptime": - if resample_utils.check_for_tmeasure(model): - weight[:] = model.meta.exposure.measurement_time + with input_models: + for i, model in enumerate(input_models): + # Do the read noise variance first, so it can be + # used for weights if needed + rn_var = self._resample_one_variance_array( + "var_rnoise", model, output_model) + + # Find valid weighting values in the variance + if rn_var is not None: + mask = (rn_var > 0) & np.isfinite(rn_var) else: - weight[:] = model.meta.exposure.exposure_time - - # Weight and add the readnoise variance - # Note: floating point overflow is an issue if variance weights - # are used - it can't be squared before multiplication - if rn_var is not None: - mask = (rn_var >= 0) & np.isfinite(rn_var) & (weight > 0) - weighted_rn_var[mask] = np.nansum( - [weighted_rn_var[mask], - rn_var[mask] * weight[mask] * weight[mask]], - axis=0 - ) - total_weight_rn_var[mask] += weight[mask] - - # Now do poisson and flat variance, updating only valid new values - # (zero is a valid value; negative, inf, or NaN are not) - pn_var = self._resample_one_variance_array( - "var_poisson", model, output_model) - if pn_var is not None: - mask = (pn_var >= 0) & np.isfinite(pn_var) & (weight > 0) - weighted_pn_var[mask] = np.nansum( - [weighted_pn_var[mask], - pn_var[mask] * weight[mask] * weight[mask]], - axis=0 - ) - total_weight_pn_var[mask] += weight[mask] - - flat_var = self._resample_one_variance_array( - "var_flat", model, output_model) - if flat_var is not None: - mask = (flat_var >= 0) & np.isfinite(flat_var) & (weight > 0) - weighted_flat_var[mask] = np.nansum( - [weighted_flat_var[mask], - flat_var[mask] * weight[mask] * weight[mask]], - axis=0 - ) - total_weight_flat_var[mask] += weight[mask] + mask = np.full_like(rn_var, False) + + # Set the weight for the image from the weight type + weight = np.ones(output_model.data.shape) + if self.weight_type == "ivm" and rn_var is not None: + weight[mask] = rn_var[mask] ** -1 + elif self.weight_type == "exptime": + if resample_utils.check_for_tmeasure(model): + weight[:] = model.meta.exposure.measurement_time + else: + weight[:] = model.meta.exposure.exposure_time + + # Weight and add the readnoise variance + # Note: floating point overflow is an issue if variance weights + # are used - it can't be squared before multiplication + if rn_var is not None: + mask = (rn_var >= 0) & np.isfinite(rn_var) & (weight > 0) + weighted_rn_var[mask] = np.nansum( + [weighted_rn_var[mask], + rn_var[mask] * weight[mask] * weight[mask]], + axis=0 + ) + total_weight_rn_var[mask] += weight[mask] + + # Now do poisson and flat variance, updating only valid new values + # (zero is a valid value; negative, inf, or NaN are not) + pn_var 
= self._resample_one_variance_array( + "var_poisson", model, output_model) + if pn_var is not None: + mask = (pn_var >= 0) & np.isfinite(pn_var) & (weight > 0) + weighted_pn_var[mask] = np.nansum( + [weighted_pn_var[mask], + pn_var[mask] * weight[mask] * weight[mask]], + axis=0 + ) + total_weight_pn_var[mask] += weight[mask] + + flat_var = self._resample_one_variance_array( + "var_flat", model, output_model) + if flat_var is not None: + mask = (flat_var >= 0) & np.isfinite(flat_var) & (weight > 0) + weighted_flat_var[mask] = np.nansum( + [weighted_flat_var[mask], + flat_var[mask] * weight[mask] * weight[mask]], + axis=0 + ) + total_weight_flat_var[mask] += weight[mask] + + del model.meta.iscale + del weight + input_models.shelve(model, i) + + # We now have a sum of the weighted resampled variances. + # Divide by the total weights, squared, and set in the output model. + # Zero weight and missing values are NaN in the output. + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "invalid value*", RuntimeWarning) + warnings.filterwarnings("ignore", "divide by zero*", RuntimeWarning) + + output_variance = (weighted_rn_var + / total_weight_rn_var / total_weight_rn_var) + setattr(output_model, "var_rnoise", output_variance) - # We now have a sum of the weighted resampled variances. - # Divide by the total weights, squared, and set in the output model. - # Zero weight and missing values are NaN in the output. - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "invalid value*", RuntimeWarning) - warnings.filterwarnings("ignore", "divide by zero*", RuntimeWarning) + output_variance = (weighted_pn_var + / total_weight_pn_var / total_weight_pn_var) + setattr(output_model, "var_poisson", output_variance) - output_variance = (weighted_rn_var - / total_weight_rn_var / total_weight_rn_var) - setattr(output_model, "var_rnoise", output_variance) + output_variance = (weighted_flat_var + / total_weight_flat_var / total_weight_flat_var) + setattr(output_model, "var_flat", output_variance) - output_variance = (weighted_pn_var - / total_weight_pn_var / total_weight_pn_var) - setattr(output_model, "var_poisson", output_variance) + del weighted_rn_var, weighted_pn_var, weighted_flat_var + del total_weight_rn_var, total_weight_pn_var, total_weight_flat_var - output_variance = (weighted_flat_var - / total_weight_flat_var / total_weight_flat_var) - setattr(output_model, "var_flat", output_variance) def _resample_one_variance_array(self, name, input_model, output_model): """Resample one variance image from an input model. 
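In compact form, the accumulation above implements weighted variance propagation: for
per-image variances ``v_i`` and weights ``w_i`` (the inverse read-noise variance for
``ivm``, or the exposure/measurement time for ``exptime``), the output variance is::

    var_out = sum_i(w_i**2 * v_i) / (sum_i(w_i))**2

which reduces to the usual inverse-variance combination when ``w_i = 1 / v_i``.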
@@ -583,16 +639,19 @@ def update_exposure_times(self, output_model, input_models): duration = 0.0 total_measurement_time = 0.0 measurement_time_failures = [] - for exposure in input_models.models_grouped: - total_exposure_time += exposure[0].meta.exposure.exposure_time - if not resample_utils.check_for_tmeasure(exposure[0]): - measurement_time_failures.append(1) - else: - total_measurement_time += exposure[0].meta.exposure.measurement_time - measurement_time_failures.append(0) - exposure_times['start'].append(exposure[0].meta.exposure.start_time) - exposure_times['end'].append(exposure[0].meta.exposure.end_time) - duration += exposure[0].meta.exposure.duration + with input_models: + for _, indices in input_models.group_indices.items(): + model = input_models.borrow(indices[0]) + total_exposure_time += model.meta.exposure.exposure_time + if not resample_utils.check_for_tmeasure(model): + measurement_time_failures.append(1) + else: + total_measurement_time += model.meta.exposure.measurement_time + measurement_time_failures.append(0) + exposure_times['start'].append(model.meta.exposure.start_time) + exposure_times['end'].append(model.meta.exposure.end_time) + duration += model.meta.exposure.duration + input_models.shelve(model, indices[0], modify=False) # Update some basic exposure time values based on output_model output_model.meta.exposure.exposure_time = total_exposure_time @@ -833,7 +892,6 @@ def compute_image_pixel_area(wcs): spatial_idx = np.where(np.array(wcs.output_frame.axes_type) == 'SPATIAL')[0] ny, nx = wcs.array_shape - ((xmin, xmax), (ymin, ymax)) = wcs.bounding_box xmin = max(0, int(xmin + 0.5)) @@ -897,3 +955,28 @@ def compute_image_pixel_area(wcs): pix_area = sky_area / image_area return pix_area + + +def copy_asn_info_from_library(library, output_model): + """ + Transfer association information from the input library to the output model. + + Parameters + ---------- + library : ModelLibrary + The input library of data models. + + output_model : DataModel + The output data model to which the association information will be copied. 
+ """ + if not hasattr(library, "asn"): + # No ASN table, occurs when input comes from ModelContainer in spectroscopic modes + # in this case do nothing; the asn info will be passed along later + # by code inside ResampleSpecStep + return + if (asn_pool := library.asn.get("asn_pool", None)) is not None: + output_model.meta.asn.pool_name = asn_pool + if ( + asn_table_name := library.asn.get("table_name", None) + ) is not None: + output_model.meta.asn.table_name = asn_table_name \ No newline at end of file diff --git a/jwst/resample/resample_spec.py b/jwst/resample/resample_spec.py index e75d1c898d..9316a838c0 100644 --- a/jwst/resample/resample_spec.py +++ b/jwst/resample/resample_spec.py @@ -16,7 +16,6 @@ from stdatamodels.jwst import datamodels from jwst.assign_wcs.util import compute_scale, wrap_ra -from jwst.datamodels import ModelContainer from jwst.resample import resample_utils from jwst.resample.resample import ResampleData @@ -192,7 +191,6 @@ def __init__(self, input_models, output=None, single=False, blendheaders=False, self.blank_output.meta.photometry.pixelarea_arcsecsq = ( output_pix_area * np.rad2deg(3600)**2) - self.output_models = ModelContainer() def build_nirspec_output_wcs(self, input_models, refmodel=None): """ @@ -946,4 +944,4 @@ def compute_spectral_pixel_scale(wcs, fiducial=None, disp_axis=1): fiducial = wcs(center_x, center_y) pixel_scale = compute_scale(wcs, fiducial, disp_axis=disp_axis) - return float(pixel_scale) + return float(pixel_scale) \ No newline at end of file diff --git a/jwst/resample/resample_spec_step.py b/jwst/resample/resample_spec_step.py index b529f5690b..bcc3d1a95c 100755 --- a/jwst/resample/resample_spec_step.py +++ b/jwst/resample/resample_spec_step.py @@ -3,11 +3,11 @@ from stdatamodels.jwst import datamodels from stdatamodels.jwst.datamodels import MultiSlitModel, ImageModel +from jwst.datamodels import ModelContainer, ModelLibrary from . 
import resample_spec, ResampleStep from ..exp_to_source import multislit_to_container from ..assign_wcs.util import update_s_region_spectral from ..stpipe import Step -from jwst.datamodels import ModelContainer from jwst.lib.wcs_utils import get_wavelengths @@ -95,6 +95,7 @@ def process(self, input): result = self._process_slit(input_models) # Update ASNTABLE in output + result.meta.cal_step.resample = "COMPLETE" result.meta.asn.table_name = input_models[0].meta.asn.table_name result.meta.asn.pool_name = input_models[0].meta.asn.pool_name @@ -133,20 +134,20 @@ def _process_multislit(self, input_models): for container in containers.values(): resamp = resample_spec.ResampleSpecData(container, **self.drizpars) - drizzled_models = resamp.do_drizzle(container) - - for model in drizzled_models: - self.update_slit_metadata(model) - update_s_region_spectral(model) - result.slits.append(model) + library = ModelLibrary(container, on_disk=False) + drizzled_library = resamp.do_drizzle(library) + with drizzled_library: + for i, model in enumerate(drizzled_library): + self.update_slit_metadata(model) + update_s_region_spectral(model) + result.slits.append(model) + drizzled_library.shelve(model, i, modify=False) + del library, drizzled_library # Keep the first computed pixel scale ratio for storage if self.pixel_scale is not None and pscale_ratio is None: pscale_ratio = resamp.pscale_ratio - result.meta.cal_step.resample = "COMPLETE" - result.meta.asn.pool_name = input_models.asn_pool_name - result.meta.asn.table_name = input_models.asn_table_name if self.pixel_scale is None or pscale_ratio is None: result.meta.resample.pixel_scale_ratio = self.pixel_scale_ratio else: @@ -206,13 +207,14 @@ def _process_slit(self, input_models): resamp = resample_spec.ResampleSpecData(input_models, **self.drizpars) - drizzled_models = resamp.do_drizzle(input_models) + library = ModelLibrary(input_models, on_disk=False) + drizzled_library = resamp.do_drizzle(library) + with drizzled_library: + result = drizzled_library.borrow(0) + drizzled_library.shelve(result, 0, modify=False) + del library, drizzled_library - result = drizzled_models[0] - result.meta.cal_step.resample = "COMPLETE" - result.meta.asn.pool_name = input_models.asn_pool_name - result.meta.asn.table_name = input_models.asn_table_name - result.meta.bunit_data = drizzled_models[0].meta.bunit_data + result.meta.bunit_data = input_models[0].meta.bunit_data if self.pixel_scale is None: result.meta.resample.pixel_scale_ratio = self.pixel_scale_ratio else: @@ -241,4 +243,4 @@ def update_slit_metadata(self, model): pass else: if val is not None: - setattr(model, attr, val) + setattr(model, attr, val) \ No newline at end of file diff --git a/jwst/resample/resample_step.py b/jwst/resample/resample_step.py index 4c2d647985..a2fd35db96 100755 --- a/jwst/resample/resample_step.py +++ b/jwst/resample/resample_step.py @@ -4,9 +4,7 @@ import asdf -from stdatamodels.jwst import datamodels - -from jwst.datamodels import ModelContainer +from jwst.datamodels import ModelLibrary, ImageModel from . import resample from ..stpipe import Step @@ -61,28 +59,34 @@ class ResampleStep(Step): def process(self, input): - input = datamodels.open(input) - - if isinstance(input, ModelContainer): + if isinstance(input, ModelLibrary): input_models = input - try: - output = input_models.meta.asn_table.products[0].name - except AttributeError: - # coron data goes through this path by the time it gets to - # resampling. 
- # TODO: figure out why and make sure asn_table is carried along - output = None - else: - input_models = ModelContainer([input]) - input_models.asn_pool_name = input.meta.asn.pool_name - input_models.asn_table_name = input.meta.asn.table_name + elif isinstance(input, (str, dict, list)): + input_models = ModelLibrary(input, on_disk=not self.in_memory) + elif isinstance(input, ImageModel): + input_models = ModelLibrary([input], on_disk=not self.in_memory) output = input.meta.filename self.blendheaders = False + else: + raise RuntimeError(f"Input {input} is not a 2D image.") + + try: + output = input_models.asn["products"][0]["name"] + except KeyError: + # coron data goes through this path by the time it gets to + # resampling. + # TODO: figure out why and make sure asn_table is carried along + output = None # Check that input models are 2D images - if len(input_models[0].data.shape) != 2: - # resample can only handle 2D images, not 3D cubes, etc - raise RuntimeError("Input {} is not a 2D image.".format(input_models[0])) + with input_models: + example_model = input_models.borrow(0) + data_shape = example_model.data.shape + input_models.shelve(example_model, 0, modify=False) + if len(data_shape) != 2: + # resample can only handle 2D images, not 3D cubes, etc + raise RuntimeError(f"Input {example_model} is not a 2D image.") + del example_model # Setup drizzle-related parameters kwargs = self.get_drizpars() @@ -91,26 +95,27 @@ def process(self, input): resamp = resample.ResampleData(input_models, output=output, **kwargs) result = resamp.do_drizzle(input_models) - for model in result: - model.meta.cal_step.resample = 'COMPLETE' - self.update_fits_wcs(model) - util.update_s_region_imaging(model) - model.meta.asn.pool_name = input_models.asn_pool_name - model.meta.asn.table_name = input_models.asn_table_name - - # if pixel_scale exists, it will override pixel_scale_ratio. - # calculate the actual value of pixel_scale_ratio based on pixel_scale - # because source_catalog uses this value from the header. - if self.pixel_scale is None: - model.meta.resample.pixel_scale_ratio = self.pixel_scale_ratio - else: - model.meta.resample.pixel_scale_ratio = resamp.pscale_ratio - model.meta.resample.pixfrac = kwargs['pixfrac'] - - if len(result) == 1: - result = result[0] - - input_models.close() + with result: + for model in result: + model.meta.cal_step.resample = 'COMPLETE' + self.update_fits_wcs(model) + util.update_s_region_imaging(model) + + # if pixel_scale exists, it will override pixel_scale_ratio. + # calculate the actual value of pixel_scale_ratio based on pixel_scale + # because source_catalog uses this value from the header. + if self.pixel_scale is None: + model.meta.resample.pixel_scale_ratio = self.pixel_scale_ratio + else: + model.meta.resample.pixel_scale_ratio = resamp.pscale_ratio + model.meta.resample.pixfrac = kwargs['pixfrac'] + result.shelve(model) + + if len(result) == 1: + model = result.borrow(0) + result.shelve(model, 0, modify=False) + return model + return result @staticmethod diff --git a/jwst/resample/resample_utils.py b/jwst/resample/resample_utils.py index 0ba20c62c9..c527b450f1 100644 --- a/jwst/resample/resample_utils.py +++ b/jwst/resample/resample_utils.py @@ -27,7 +27,7 @@ def make_output_wcs(input_models, ref_wcs=None, Parameters ---------- - input_models : list of `~jwst.datamodel.JwstDataModel` + input_models : `~jwst.datamodel.ModelLibrary` Each datamodel must have a ~gwcs.WCS object. 
pscale_ratio : float, optional @@ -67,10 +67,16 @@ def make_output_wcs(input_models, ref_wcs=None, WCS object, with defined domain, covering entire set of input frames """ if ref_wcs is None: - wcslist = [i.meta.wcs for i in input_models] - for w, i in zip(wcslist, input_models): - if w.bounding_box is None: - w.bounding_box = wcs_bbox_from_shape(i.data.shape) + with input_models: + wcslist = [] + for i, model in enumerate(input_models): + w = model.meta.wcs + if w.bounding_box is None: + w.bounding_box = wcs_bbox_from_shape(model.data.shape) + wcslist.append(w) + if i == 0: + example_model = model + input_models.shelve(model) naxes = wcslist[0].output_frame.naxes if naxes != 2: @@ -81,7 +87,7 @@ def make_output_wcs(input_models, ref_wcs=None, output_wcs = util.wcs_from_footprints( wcslist, ref_wcs=wcslist[0], - ref_wcsinfo=input_models[0].meta.wcsinfo.instance, + ref_wcsinfo=example_model.meta.wcsinfo.instance, pscale_ratio=pscale_ratio, pscale=pscale, rotation=rotation, @@ -89,6 +95,7 @@ def make_output_wcs(input_models, ref_wcs=None, crpix=crpix, crval=crval ) + del example_model else: naxes = ref_wcs.output_frame.naxes diff --git a/jwst/resample/tests/test_resample_step.py b/jwst/resample/tests/test_resample_step.py index fcd380a4b8..75eddb8530 100644 --- a/jwst/resample/tests/test_resample_step.py +++ b/jwst/resample/tests/test_resample_step.py @@ -7,7 +7,7 @@ from stdatamodels.jwst.datamodels import ImageModel -from jwst.datamodels import ModelContainer +from jwst.datamodels import ModelContainer, ModelLibrary from jwst.assign_wcs import AssignWcsStep from jwst.assign_wcs.util import compute_fiducial, compute_scale from jwst.exp_to_source import multislit_to_container @@ -637,7 +637,7 @@ def test_weight_type(nircam_rate, tmp_cwd): im2.meta.observation.sequence_id = "2" im3.meta.observation.sequence_id = "3" - c = ModelContainer([im1, im2, im3]) + c = ModelLibrary([im1, im2, im3]) assert len(c.group_names) == 3 result1 = ResampleStep.call(c, weight_type="ivm", blendheaders=False, save_results=True) @@ -656,8 +656,10 @@ def test_weight_type(nircam_rate, tmp_cwd): # remove measurement time to force use of exposure time # this also implicitly shows that measurement time was indeed used above expected_ratio = im1.meta.exposure.exposure_time / im1.meta.exposure.measurement_time - for im in c: - del im.meta.exposure.measurement_time + with c: + for j, im in enumerate(c): + del im.meta.exposure.measurement_time + c.shelve(im, j) result3 = ResampleStep.call(c, weight_type="exptime", blendheaders=False) assert_allclose(result3.data[100:105, 100:105], 6.667, rtol=1e-2) @@ -791,9 +793,7 @@ def test_resample_variance(nircam_rate, n_images, weight_type): im.err += err im.meta.filename = "foo.fits" - c = ModelContainer() - for n in range(n_images): - c.append(im.copy()) + c = ModelLibrary([im.copy() for _ in range(n_images)]) result = ResampleStep.call(c, blendheaders=False, weight_type=weight_type) @@ -814,7 +814,7 @@ def test_resample_undefined_variance(nircam_rate, shape): im.var_poisson = np.ones(shape, dtype=im.var_poisson.dtype.type) im.var_flat = np.ones(shape, dtype=im.var_flat.dtype.type) im.meta.filename = "foo.fits" - c = ModelContainer([im]) + c = ModelLibrary([im]) with pytest.warns(RuntimeWarning, match="var_rnoise array not available"): result = ResampleStep.call(c, blendheaders=False) diff --git a/jwst/skymatch/skymatch_step.py b/jwst/skymatch/skymatch_step.py index 9723e0ef4d..c64b63c00d 100644 --- a/jwst/skymatch/skymatch_step.py +++ b/jwst/skymatch/skymatch_step.py @@ -18,12 
diff --git a/jwst/skymatch/skymatch_step.py b/jwst/skymatch/skymatch_step.py
index 9723e0ef4d..c64b63c00d 100644
--- a/jwst/skymatch/skymatch_step.py
+++ b/jwst/skymatch/skymatch_step.py
@@ -18,12 +18,8 @@
 )
 
 from stdatamodels.jwst.datamodels.dqflags import pixel
-from stdatamodels.jwst.datamodels.util import (
-    open as datamodel_open,
-    is_association
-)
-from jwst.datamodels import ModelContainer
+from jwst.datamodels import ModelLibrary
 
 from ..stpipe import Step
 
@@ -62,33 +58,23 @@ class SkyMatchStep(Step):
         lsigma = float(min=0.0, default=4.0) # Lower clipping limit, in sigma
         usigma = float(min=0.0, default=4.0) # Upper clipping limit, in sigma
         binwidth = float(min=0.0, default=0.1) # Bin width for 'mode' and 'midpt' `skystat`, in sigma
+
+        # Memory management:
+        in_memory = boolean(default=True)  # If False, preserve memory using temporary files
     """  # noqa: E501
 
     reference_file_types = []
 
     def __init__(self, *args, **kwargs):
-        minimize_memory = kwargs.pop('minimize_memory', False)
         super().__init__(*args, **kwargs)
-        self.minimize_memory = minimize_memory
 
     def process(self, input):
         self.log.setLevel(logging.DEBUG)
 
-        # for now turn off memory optimization until we have better machinery
-        # to handle outputs in a consistent way.
-
-        if hasattr(self, 'minimize_memory') and self.minimize_memory:
-            self._is_asn = (
-                is_association(input) or isinstance(input, str)
-            )
+        if isinstance(input, ModelLibrary):
+            library = input
         else:
-            self._is_asn = False
-
-        img = ModelContainer(
-            input,
-            save_open=not self._is_asn,
-            return_open=not self._is_asn
-        )
+            library = ModelLibrary(input, on_disk=not self.in_memory)
 
         self._dqbits = interpret_bit_flags(self.dqbits, flag_name_map=pixel)
 
@@ -103,51 +89,45 @@ def process(self, input):
             binwidth=self.binwidth
         )
 
-        # group images by their "group id":
-        grp_img = img.models_grouped
-
-        # create a list of "Sky" Images and/or Groups:
         images = []
-        grp_id = 1
-
-        for g in grp_img:
-            if len(g) > 1:
-                images.append(
-                    SkyGroup(
-                        list(map(self._imodel2skyim, g)),
-                        id=grp_id
-                    )
-                )
-                grp_id += 1
-            elif len(g) == 1:
-                images.append(self._imodel2skyim(g[0]))
-            else:
-                raise AssertionError("Logical error in the pipeline code.")
+        with library:
+            for group_index, (group_id, group_inds) in enumerate(library.group_indices.items()):
+                sky_images = []
+                for index in group_inds:
+                    model = library.borrow(index)
+                    try:
+                        sky_images.append(self._imodel2skyim(model, index))
+                    finally:
+                        library.shelve(model, index, modify=False)
+                if len(sky_images) == 1:
+                    images.extend(sky_images)
+                else:
+                    images.append(SkyGroup(sky_images, id=group_index))
 
         # match/compute sky values:
         match(images, skymethod=self.skymethod, match_down=self.match_down,
               subtract=self.subtract)
 
         # set sky background value in each image's meta:
-        for im in images:
-            if isinstance(im, SkyImage):
-                self._set_sky_background(
-                    im,
-                    "COMPLETE" if im.is_sky_valid else "SKIPPED"
-                )
-            else:
-                for gim in im:
+        with library:
+            for im in images:
+                if isinstance(im, SkyImage):
                     self._set_sky_background(
-                        gim,
-                        "COMPLETE" if gim.is_sky_valid else "SKIPPED"
+                        im,
+                        library,
+                        "COMPLETE" if im.is_sky_valid else "SKIPPED"
                     )
+                else:
+                    for gim in im:
+                        self._set_sky_background(
+                            gim,
+                            library,
+                            "COMPLETE" if gim.is_sky_valid else "SKIPPED"
+                        )
 
-        return input if self._is_asn else img
+        return library
 
-    def _imodel2skyim(self, image_model):
-        input_image_model = image_model
-        if self._is_asn:
-            image_model = datamodel_open(image_model)
+    def _imodel2skyim(self, image_model, index):
 
         if self._dqbits is None:
             dqmask = np.isfinite(image_model.data).astype(dtype=np.uint8)
@@ -163,9 +143,6 @@ def _imodel2skyim(self, image_model):
 
         # if 'subtract' mode has changed compared to the previous pass:
         if image_model.meta.background.subtracted is None:
             if image_model.meta.background.level is not None:
-                if self._is_asn:
-                    image_model.close()
-
                 # report inconsistency:
                 raise ValueError("Background level was set but the "
                                  "'subtracted' property is undefined (None).")
@@ -179,9 +156,6 @@ def _imodel2skyim(self, image_model):
                 # at this moment I think it is saver to quit and...
                 #
                 # report inconsistency:
-                if self._is_asn:
-                    image_model.close()
-
                 raise ValueError("Background level was subtracted but the "
                                  "'level' property is undefined (None).")
 
@@ -189,9 +163,6 @@ def _imodel2skyim(self, image_model):
             # cannot run 'skymatch' step on already "skymatched" images
             # when 'subtract' spec is inconsistent with
             # meta.background.subtracted:
-            if self._is_asn:
-                image_model.close()
-
             raise ValueError("'subtract' step's specification is "
                              "inconsistent with background info already "
                              "present in image '{:s}' meta."
@@ -206,30 +177,36 @@ def _imodel2skyim(self, image_model):
             pix_area=1.0,  # TODO: pixel area
             convf=1.0,  # TODO: conv. factor to brightness
             mask=dqmask,
-            id=image_model.meta.filename,  # file name?
+            id=image_model.meta.filename,
             skystat=self._skystat,
             stepsize=self.stepsize,
-            reduce_memory_usage=self._is_asn,
-            meta={'image_model': input_image_model}
+            reduce_memory_usage=False,  # setting this to True overwrote input files
+            meta={'index': index}
         )
 
-        if self._is_asn:
-            image_model.close()
-
         if self.subtract:
             sky_im.sky = level
 
         return sky_im
 
-    def _set_sky_background(self, sky_image, step_status):
-        image = sky_image.meta['image_model']
+    def _set_sky_background(self, sky_image, library, step_status):
+        """
+        Parameters
+        ----------
+        sky_image : SkyImage
+            SkyImage object containing sky image data and metadata.
+
+        library : ModelLibrary
+            Library of input data models; must be open when this is called.
+
+        step_status : str
+            Status of the sky subtraction step. Must be one of the following:
+            'COMPLETE', 'SKIPPED'.
+        """
+        index = sky_image.meta['index']
+        dm = library.borrow(index)
         sky = sky_image.sky
 
-        if self._is_asn:
-            dm = datamodel_open(image)
-        else:
-            dm = image
-
         if step_status == "COMPLETE":
             dm.meta.background.method = str(self.skymethod)
             dm.meta.background.level = sky
@@ -238,7 +215,4 @@ def _set_sky_background(self, sky_image, step_status):
                 dm.data[...] = sky_image.image[...]
 
         dm.meta.cal_step.skymatch = step_status
-
-        if self._is_asn:
-            dm.save(image)
-            dm.close()
+        library.shelve(dm, index)
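
Aside: `SkyMatchStep` now walks `ModelLibrary.group_indices` and keeps only integer
indices in each `SkyImage`'s meta, re-borrowing the matching model by index when sky
values are written back in `_set_sky_background`. A sketch of that bookkeeping (the
group ids are invented; the library calls are the ones used above)::

    from stdatamodels.jwst.datamodels import ImageModel
    from jwst.datamodels import ModelLibrary

    models = []
    for i in range(4):
        m = ImageModel()
        m.meta.filename = f"exposure_{i}.fits"   # hypothetical
        m.meta.group_id = "a" if i < 2 else "b"  # two invented groups
        models.append(m)
    library = ModelLibrary(models)

    # first pass: remember only the index of each member, not the model itself
    indices = []
    with library:
        for group_id, group_inds in library.group_indices.items():
            for index in group_inds:
                model = library.borrow(index)
                indices.append(index)  # stand-in for SkyImage(meta={'index': index})
                library.shelve(model, index, modify=False)

    # second pass: write results back through the recorded indices
    with library:
        for index in indices:
            model = library.borrow(index)
            model.meta.cal_step.skymatch = 'COMPLETE'
            library.shelve(model, index)  # marked modified so the result persists
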
diff --git a/jwst/skymatch/tests/test_skymatch.py b/jwst/skymatch/tests/test_skymatch.py
index 774b04e2b7..2442819ffc 100644
--- a/jwst/skymatch/tests/test_skymatch.py
+++ b/jwst/skymatch/tests/test_skymatch.py
@@ -220,8 +220,10 @@ def test_skymatch(nircam_rate, skymethod, subtract, skystat, match_down,
         assert im.meta.background.subtracted is None
 
         # test that output models have original sky levels on failure:
-        for im, lev in zip(result, levels):
-            assert abs(np.mean(im.data[dq_mask]) - lev) < 0.01
+        with result:
+            for im, lev in zip(result, levels):
+                assert abs(np.mean(im.data[dq_mask]) - lev) < 0.01
+                result.shelve(im, modify=False)
 
         return
 
@@ -243,19 +245,21 @@ def test_skymatch(nircam_rate, skymethod, subtract, skystat, match_down,
 
     sub_levels = np.subtract(levels, ref_levels)
 
-    for im, lev, rlev, slev in zip(result, levels, ref_levels, sub_levels):
-        # check that meta was set correctly:
-        assert im.meta.background.method == skymethod
-        assert im.meta.background.subtracted == subtract
+    with result:
+        for im, lev, rlev, slev in zip(result, levels, ref_levels, sub_levels):
+            # check that meta was set correctly:
+            assert im.meta.background.method == skymethod
+            assert im.meta.background.subtracted == subtract
 
-        # test computed/measured sky values:
-        assert abs(im.meta.background.level - rlev) < 0.01
+            # test computed/measured sky values:
+            assert abs(im.meta.background.level - rlev) < 0.01
 
-        # test
-        if subtract:
-            assert abs(np.mean(im.data[dq_mask]) - slev) < 0.01
-        else:
-            assert abs(np.mean(im.data[dq_mask]) - lev) < 0.01
+            # test
+            if subtract:
+                assert abs(np.mean(im.data[dq_mask]) - slev) < 0.01
+            else:
+                assert abs(np.mean(im.data[dq_mask]) - lev) < 0.01
+            result.shelve(im, modify=False)
 
 
 @pytest.mark.parametrize(
@@ -334,33 +338,35 @@ def test_skymatch_overlap(nircam_rate, skymethod, subtract, skystat):
 
     sub_levels = np.subtract(levels, ref_levels)
 
-    for im, lev, rlev, slev in zip(result, levels, ref_levels, sub_levels):
-        # check that meta was set correctly:
-        assert im.meta.background.method == skymethod
-        assert im.meta.background.subtracted == subtract
-
-        if skymethod in ['local', 'global']:
-            # These two sky methods must fail because they do not take
-            # into account (do not compute) overlap regions and use
-            # entire images:
-
-            # test computed/measured sky values:
-            assert abs(im.meta.background.level - rlev) > 1000  # FAIL
-
-            # test
-            if subtract:
-                assert abs(np.mean(im.data[dq_mask]) - slev) > 1000  # FAIL
+    with result:
+        for im, lev, rlev, slev in zip(result, levels, ref_levels, sub_levels):
+            # check that meta was set correctly:
+            assert im.meta.background.method == skymethod
+            assert im.meta.background.subtracted == subtract
+
+            if skymethod in ['local', 'global']:
+                # These two sky methods must fail because they do not take
+                # into account (do not compute) overlap regions and use
+                # entire images:
+
+                # test computed/measured sky values:
+                assert abs(im.meta.background.level - rlev) > 1000  # FAIL
+
+                # test
+                if subtract:
+                    assert abs(np.mean(im.data[dq_mask]) - slev) > 1000  # FAIL
+                else:
+                    assert abs(np.mean(im.data[dq_mask]) - lev) < 0.01  # PASS
             else:
-                assert abs(np.mean(im.data[dq_mask]) - lev) < 0.01  # PASS
-        else:
-            # test computed/measured sky values:
-            assert abs(im.meta.background.level - rlev) < 0.01
+                # test computed/measured sky values:
+                assert abs(im.meta.background.level - rlev) < 0.01
 
-            # test
-            if subtract:
-                assert abs(np.mean(im.data[dq_mask]) - slev) < 0.01
-            else:
-                assert abs(np.mean(im.data[dq_mask]) - lev) < 0.01
+                # test
+                if subtract:
+                    assert abs(np.mean(im.data[dq_mask]) - slev) < 0.01
+                else:
+                    assert abs(np.mean(im.data[dq_mask]) - lev) < 0.01
+            result.shelve(im, modify=False)
 
 
 def test_asn_input(tmp_cwd, nircam_rate, tmp_path):
@@ -415,7 +421,6 @@ def test_asn_input(tmp_cwd, nircam_rate, tmp_path):
     # images are rotated and SATURATED pixels in the corners are not in the
     # common intersection of all input images. This is the purpose of this test
     step = SkyMatchStep(
-        minimize_memory=True,
        skymethod='match',
         match_down=True,
         subtract=True,
@@ -426,23 +431,21 @@ def test_asn_input(tmp_cwd, nircam_rate, tmp_path):
 
     result = step.run(asn_out_fname)
 
-    assert isinstance(result, str)
-
     ref_levels = np.subtract(levels, min(levels))
 
     sub_levels = np.subtract(levels, ref_levels)
 
-    result = ModelContainer(result)
+    with result:
+        for im, lev, rlev, slev in zip(result, levels, ref_levels, sub_levels):
+            # check that meta was set correctly:
+            assert im.meta.background.method == 'match'
+            assert im.meta.background.subtracted is True
 
-    for im, lev, rlev, slev in zip(result, levels, ref_levels, sub_levels):
-        # check that meta was set correctly:
-        assert im.meta.background.method == 'match'
-        assert im.meta.background.subtracted is True
-
-        # test computed/measured sky values:
-        assert abs(im.meta.background.level - rlev) < 0.01
+            # test computed/measured sky values:
+            assert abs(im.meta.background.level - rlev) < 0.01
 
-        # test
-        assert abs(np.mean(im.data[dq_mask]) - slev) < 0.01
+            # test
+            assert abs(np.mean(im.data[dq_mask]) - slev) < 0.01
+            result.shelve(im, modify=False)
 
 
 @pytest.mark.parametrize(
@@ -498,7 +501,6 @@ def test_skymatch_2x(tmp_cwd, nircam_rate, tmp_path, skymethod, subtract):
     # images are rotated and SATURATED pixels in the corners are not in the
     # common intersection of all input images. This is the purpose of this test
     step = SkyMatchStep(
-        minimize_memory=True,
         skymethod=skymethod,
         match_down=True,
         subtract=subtract,
@@ -529,19 +531,19 @@ def test_skymatch_2x(tmp_cwd, nircam_rate, tmp_path, skymethod, subtract):
 
     sub_levels = np.subtract(levels, ref_levels)
 
-    result2 = ModelContainer(result2)
-
     # compare results
-    for im2, lev, rlev, slev in zip(result2, levels, ref_levels, sub_levels):
-        # check that meta was set correctly:
-        assert im2.meta.background.method == skymethod
-        assert im2.meta.background.subtracted == subtract
+    with result2:
+        for im2, lev, rlev, slev in zip(result2, levels, ref_levels, sub_levels):
+            # check that meta was set correctly:
+            assert im2.meta.background.method == skymethod
+            assert im2.meta.background.subtracted == subtract
 
-        # test computed/measured sky values:
-        assert abs(im2.meta.background.level - rlev) < 0.01
+            # test computed/measured sky values:
+            assert abs(im2.meta.background.level - rlev) < 0.01
 
-        # test
-        if subtract:
-            assert abs(np.mean(im2.data[dq_mask]) - slev) < 0.01
-        else:
-            assert abs(np.mean(im2.data[dq_mask]) - lev) < 0.01
+            # test
+            if subtract:
+                assert abs(np.mean(im2.data[dq_mask]) - slev) < 0.01
+            else:
+                assert abs(np.mean(im2.data[dq_mask]) - lev) < 0.01
+            result2.shelve(im2)
diff --git a/jwst/stpipe/core.py b/jwst/stpipe/core.py
index 2c86df37c8..2a30139190 100644
--- a/jwst/stpipe/core.py
+++ b/jwst/stpipe/core.py
@@ -1,16 +1,17 @@
 """
 JWST-specific Step and Pipeline base classes.
 """
+import logging
+
 from stdatamodels.jwst.datamodels import JwstDataModel
 from stdatamodels.jwst import datamodels
-
-from .. import __version_commit__, __version__
-
 from stpipe import crds_client
 from stpipe import Step
 from stpipe import Pipeline
+
+from .. import __version_commit__, __version__
 from ..lib.suffix import remove_suffix
-import logging
+
 
 log = logging.getLogger(__name__)
 log.setLevel(logging.DEBUG)
diff --git a/jwst/stpipe/utilities.py b/jwst/stpipe/utilities.py
index 67a4769113..ad62ba1468 100644
--- a/jwst/stpipe/utilities.py
+++ b/jwst/stpipe/utilities.py
@@ -36,6 +36,7 @@
 import os
 import re
 from collections.abc import Sequence
+from jwst import datamodels
 
 # Configure logging
 logger = logging.getLogger(__name__)
@@ -155,7 +156,7 @@ def record_step_status(datamodel, cal_step, success=True):
 
     Parameters
     ----------
-    datamodel : `~jwst.datamodels.JwstDataModel` instance
+    datamodel : `~jwst.datamodels.JwstDataModel`, `~jwst.datamodels.ModelContainer`, `~jwst.datamodels.ModelLibrary`, str, or Path instance
        This is the datamodel or container of datamodels to modify in place
 
     cal_step : str
@@ -172,6 +173,11 @@ def record_step_status(datamodel, cal_step, success=True):
     if isinstance(datamodel, Sequence):
         for model in datamodel:
             model.meta.cal_step._instance[cal_step] = status
+    elif isinstance(datamodel, datamodels.ModelLibrary):
+        with datamodel:
+            for model in datamodel:
+                model.meta.cal_step._instance[cal_step] = status
+                datamodel.shelve(model)
     else:
         datamodel.meta.cal_step._instance[cal_step] = status
diff --git a/jwst/tweakreg/tests/test_multichip_jwst.py b/jwst/tweakreg/tests/test_multichip_jwst.py
index 952e61d81e..4402f7b897 100644
--- a/jwst/tweakreg/tests/test_multichip_jwst.py
+++ b/jwst/tweakreg/tests/test_multichip_jwst.py
@@ -291,7 +291,7 @@ def test_multichip_jwst_alignment(monkeypatch):
     assert rmse_dec < _REF_RMSE_DEC
 
 
-def test_multichip_alignment_step(monkeypatch):
+def test_multichip_alignment_step_rel(monkeypatch):
     monkeypatch.setattr(tweakreg_step.twk, 'align_wcs', _align_wcs)
     monkeypatch.setattr(tweakreg_step, 'make_tweakreg_catalog', _make_tweakreg_catalog)
 
@@ -402,24 +402,31 @@ def test_multichip_alignment_step(monkeypatch):
     # Alternatively, disable this '_is_wcs_correction_small' test:
     # step._is_wcs_correction_small = lambda x, y: True
 
-    mr, m1, m2 = step.process(mc)
-    for im in [mr, m1, m2]:
-        assert im.meta.cal_step.tweakreg == 'COMPLETE'
-
-    wc1 = m1.meta.wcs
-    wc2 = m2.meta.wcs
-
-    ra1, dec1 = wc1(imcat1['x'], imcat1['y'])
-    ra2, dec2 = wc2(imcat2['x'], imcat2['y'])
-    ra = np.concatenate([ra1, ra2])
-    dec = np.concatenate([dec1, dec2])
-    rra = refcat['RA']
-    rdec = refcat['DEC']
-    rmse_ra = np.sqrt(np.mean((ra - rra)**2))
-    rmse_dec = np.sqrt(np.mean((dec - rdec)**2))
-
-    assert rmse_ra < _REF_RMSE_RA
-    assert rmse_dec < _REF_RMSE_DEC
+    result = step.process(mc)
+    with result:
+        for im in result:
+            assert im.meta.cal_step.tweakreg == 'COMPLETE'
+            result.shelve(im, modify=False)
+
+    with result:
+        m1 = result.borrow(1)
+        m2 = result.borrow(2)
+        wc1 = m1.meta.wcs
+        wc2 = m2.meta.wcs
+
+        ra1, dec1 = wc1(imcat1['x'], imcat1['y'])
+        ra2, dec2 = wc2(imcat2['x'], imcat2['y'])
+        ra = np.concatenate([ra1, ra2])
+        dec = np.concatenate([dec1, dec2])
+        rra = refcat['RA']
+        rdec = refcat['DEC']
+        rmse_ra = np.sqrt(np.mean((ra - rra)**2))
+        rmse_dec = np.sqrt(np.mean((dec - rdec)**2))
+
+        assert rmse_ra < _REF_RMSE_RA
+        assert rmse_dec < _REF_RMSE_DEC
+        result.shelve(m1, 1, modify=False)
+        result.shelve(m2, 2, modify=False)
 
 
 def test_multichip_alignment_step_abs(monkeypatch):
diff --git a/jwst/tweakreg/tests/test_tweakreg.py b/jwst/tweakreg/tests/test_tweakreg.py
index f17b378def..a038410e69 100644
--- a/jwst/tweakreg/tests/test_tweakreg.py
+++ b/jwst/tweakreg/tests/test_tweakreg.py
@@ -196,12 +196,18 @@ def test_tweakreg_step(example_input, with_shift):
     result = step(example_input)
 
     # check that step completed
-    for model in result:
-        assert model.meta.cal_step.tweakreg == 'COMPLETE'
-
-    # and that the wcses differ by a small amount due to the shift above
-    # by projecting one point through each wcs and comparing the difference
-    abs_delta = abs(result[1].meta.wcs(0, 0)[0] - result[0].meta.wcs(0, 0)[0])
+    with result:
+        for model in result:
+            assert model.meta.cal_step.tweakreg == 'COMPLETE'
+            result.shelve(model, modify=False)
+
+        # and that the wcses differ by a small amount due to the shift above
+        # by projecting one point through each wcs and comparing the difference
+        r0 = result.borrow(0)
+        r1 = result.borrow(1)
+        abs_delta = abs(r1.meta.wcs(0, 0)[0] - r0.meta.wcs(0, 0)[0])
+        result.shelve(r0, 0, modify=False)
+        result.shelve(r1, 1, modify=False)
     if with_shift:
         assert abs_delta > 1E-5
     else:
@@ -224,8 +230,10 @@ def test_src_confusion_pars(example_input, alignment_type):
     result = step(example_input)
 
     # check that step was skipped
-    for model in result:
-        assert model.meta.cal_step.tweakreg == 'SKIPPED'
+    with result:
+        for model in result:
+            assert model.meta.cal_step.tweakreg == 'SKIPPED'
+            result.shelve(model)
 
 
 @pytest.fixture()
@@ -330,6 +338,7 @@ def test_custom_catalog(custom_catalog_path, example_input, catfile, asn, meta,
     kwargs = {'use_custom_catalogs': custom}
     if catfile != "no_catfile":
         kwargs["catfile"] = str(catfile_path)
+
     step = tweakreg_step.TweakRegStep(**kwargs)
 
     # patch _construct_wcs_corrector to check the correct catalog was loaded
@@ -349,6 +358,7 @@ def patched_construct_wcs_corrector(wcs, wcsinfo, catalog, group_id, _seen=[]):
     with pytest.raises(ValueError, match="done testing"):
         step(str(asn_path))
 
+
 @pytest.mark.parametrize("with_shift", [True, False])
 def test_sip_approx(example_input, with_shift):
     """
@@ -376,31 +386,36 @@ def test_sip_approx(example_input, with_shift):
     # run the step on the example input modified above
     result = step(example_input)
 
-    # output wcs differs by a small amount due to the shift above:
-    # project one point through each wcs and compare the difference
-    abs_delta = abs(result[1].meta.wcs(0, 0)[0] - result[0].meta.wcs(0, 0)[0])
-    if with_shift:
-        assert abs_delta > 1E-5
-    else:
-        assert abs_delta < 1E-12
-
-    # the first wcs is identical to the input and
-    # does not have SIP approximation keywords --
-    # they are normally set by assign_wcs
-    assert np.allclose(result[0].meta.wcs(0, 0)[0], example_input[0].meta.wcs(0, 0)[0])
-    for key in ['ap_order', 'bp_order']:
-        assert key not in result[0].meta.wcsinfo.instance
-
-    # for the second, SIP approximation should be present
-    for key in ['ap_order', 'bp_order']:
-        assert result[1].meta.wcsinfo.instance[key] == 3
-
-    # evaluate fits wcs and gwcs for the approximation, make sure they agree
-    wcs_info = result[1].meta.wcsinfo.instance
-    grid = grid_from_bounding_box(result[1].meta.wcs.bounding_box)
-    gwcs_ra, gwcs_dec = result[1].meta.wcs(*grid)
-    fits_wcs = WCS(wcs_info)
-    fitswcs_res = fits_wcs.pixel_to_world(*grid)
+    with result:
+        r0 = result.borrow(0)
+        r1 = result.borrow(1)
+        # output wcs differs by a small amount due to the shift above:
+        # project one point through each wcs and compare the difference
+        abs_delta = abs(r1.meta.wcs(0, 0)[0] - r0.meta.wcs(0, 0)[0])
+        if with_shift:
+            assert abs_delta > 1E-5
+        else:
+            assert abs_delta < 1E-12
+
+        # the first wcs is identical to the input and
+        # does not have SIP approximation keywords --
+        # they are normally set by assign_wcs
+        assert np.allclose(r0.meta.wcs(0, 0)[0], example_input[0].meta.wcs(0, 0)[0])
+        for key in ['ap_order', 'bp_order']:
+            assert key not in r0.meta.wcsinfo.instance
+
+        # for the second, SIP approximation should be present
+        for key in ['ap_order', 'bp_order']:
+            assert r1.meta.wcsinfo.instance[key] == 3
+
+        # evaluate fits wcs and gwcs for the approximation, make sure they agree
+        wcs_info = r1.meta.wcsinfo.instance
+        grid = grid_from_bounding_box(r1.meta.wcs.bounding_box)
+        gwcs_ra, gwcs_dec = r1.meta.wcs(*grid)
+        fits_wcs = WCS(wcs_info)
+        fitswcs_res = fits_wcs.pixel_to_world(*grid)
+        result.shelve(r0, 0, modify=False)
+        result.shelve(r1, 1, modify=False)
 
     assert np.allclose(fitswcs_res.ra.deg, gwcs_ra)
     assert np.allclose(fitswcs_res.dec.deg, gwcs_dec)
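
Aside: the test changes above also show the pattern user code needs when consuming a
step's `ModelLibrary` output — open the library, borrow and inspect each model, and
shelve with ``modify=False`` for read-only access. A sketch (the association filename
is a placeholder)::

    from jwst.tweakreg import TweakRegStep

    # 'asn.json' stands in for a real association file
    result = TweakRegStep.call("asn.json", in_memory=False)

    with result:
        for index, model in enumerate(result):
            print(model.meta.filename, model.meta.cal_step.tweakreg)
            result.shelve(model, index, modify=False)  # read-only access
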
diff --git a/jwst/tweakreg/tweakreg_step.py b/jwst/tweakreg/tweakreg_step.py
index 5255769fa2..0d8363771b 100644
--- a/jwst/tweakreg/tweakreg_step.py
+++ b/jwst/tweakreg/tweakreg_step.py
@@ -12,9 +12,9 @@
 
 import stcal.tweakreg.tweakreg as twk
 
-from jwst.datamodels import ModelContainer
 from jwst.stpipe import record_step_status
 from jwst.assign_wcs.util import update_fits_wcsinfo, update_s_region_imaging
+from jwst.datamodels import ModelLibrary
 
 # LOCAL
 from ..stpipe import Step
@@ -123,12 +123,16 @@ class TweakRegStep(Step):
 
         # stpipe general options
         output_use_model = boolean(default=True)  # When saving use `DataModel.meta.filename`
+        in_memory = boolean(default=True)  # If False, preserve memory using temporary files at expense of runtime
     """
 
     reference_file_types = []
 
     def process(self, input):
-        images = ModelContainer(input)
+        if isinstance(input, ModelLibrary):
+            images = input
+        else:
+            images = ModelLibrary(input, on_disk=not self.in_memory)
 
         if len(images) == 0:
             raise ValueError("Input must contain at least one image model.")
@@ -153,15 +157,15 @@ def process(self, input):
             )
             use_custom_catalogs = False
         # else, load from association
-        elif hasattr(images.meta, "asn_table") and getattr(images, "asn_file_path", None) is not None:
+        elif images._asn_dir is not None:
             catdict = {}
-            asn_dir = path.dirname(images.asn_file_path)
-            for member in images.meta.asn_table.products[0].members:
-                if hasattr(member, "tweakreg_catalog"):
-                    if member.tweakreg_catalog is None or not member.tweakreg_catalog.strip():
-                        catdict[member.expname] = None
+            for member in images.asn["products"][0]["members"]:
+                if "tweakreg_catalog" in member:
+                    tweakreg_catalog = member["tweakreg_catalog"]
+                    if tweakreg_catalog is None or not tweakreg_catalog.strip():
+                        catdict[member["expname"]] = None
                     else:
-                        catdict[member.expname] = path.join(asn_dir, member.tweakreg_catalog)
+                        catdict[member["expname"]] = path.join(images._asn_dir, tweakreg_catalog)
 
     if self.abs_refcat is not None and self.abs_refcat.strip():
             align_to_abs_refcat = True
@@ -186,62 +190,64 @@ def process(self, input):
 
         # pre-allocate collectors (same length and order as images)
         correctors = [None] * len(images)
 
-        # Build the catalog for each input image
-        for (model_index, image_model) in enumerate(images):
-            # now that the model is open, check it's metadata for a custom catalog
-            # only if it's not listed in the catdict
-            if use_custom_catalogs and image_model.meta.filename not in catdict:
-                if (image_model.meta.tweakreg_catalog is not None and image_model.meta.tweakreg_catalog.strip()):
-                    catdict[image_model.meta.filename] = image_model.meta.tweakreg_catalog
-            if use_custom_catalogs and catdict.get(image_model.meta.filename, None) is not None:
-                # FIXME this modifies the input_model
-                image_model.meta.tweakreg_catalog = catdict[image_model.meta.filename]
-                # use user-supplied catalog:
-                self.log.info("Using user-provided input catalog "
-                              f"'{image_model.meta.tweakreg_catalog}'")
-                catalog = Table.read(
-                    image_model.meta.tweakreg_catalog,
-                )
-                save_catalog = False
-            else:
-                # source finding
-                catalog = self._find_sources(image_model)
-
-                # only save if catalog was computed from _find_sources and
-                # the user requested save_catalogs
-                save_catalog = self.save_catalogs
-
-            # if needed rename xcentroid to x, ycentroid to y
-            catalog = _rename_catalog_columns(catalog)
-
-            # filter all sources outside the wcs bounding box
-            catalog = twk.filter_catalog_by_bounding_box(
-                catalog,
-                image_model.meta.wcs.bounding_box)
-
-            # setting 'name' is important for tweakwcs logging
-            if catalog.meta.get('name') is None:
-                catalog.meta['name'] = path.splitext(image_model.meta.filename)[0].strip('_- ')
-
-            # log results of source finding (or user catalog)
-            filename = image_model.meta.filename
-            nsources = len(catalog)
-            if nsources == 0:
-                self.log.warning('No sources found in {}.'.format(filename))
-            else:
-                self.log.info('Detected {} sources in {}.'
-                              .format(len(catalog), filename))
-
-            # save catalog (if requested)
-            if save_catalog:
-                # FIXME this modifies the input_model
-                image_model.meta.tweakreg_catalog = self._write_catalog(catalog, filename)
-
-            # construct the corrector since the model is open (and already has a group_id)
-            correctors[model_index] = twk.construct_wcs_corrector(image_model.meta.wcs,
-                                                                  image_model.meta.wcsinfo.instance,
-                                                                  catalog,
-                                                                  image_model.meta.group_id,)
+        # Build the catalog and corrector for each input image
+        with images:
+            for (model_index, image_model) in enumerate(images):
+                # now that the model is open, check its metadata for a custom catalog
+                # only if it's not listed in the catdict
+                if use_custom_catalogs and image_model.meta.filename not in catdict:
+                    if (image_model.meta.tweakreg_catalog is not None and image_model.meta.tweakreg_catalog.strip()):
+                        catdict[image_model.meta.filename] = image_model.meta.tweakreg_catalog
+                if use_custom_catalogs and catdict.get(image_model.meta.filename, None) is not None:
+                    image_model.meta.tweakreg_catalog = catdict[image_model.meta.filename]
+                    # use user-supplied catalog:
+                    self.log.info("Using user-provided input catalog "
+                                  f"'{image_model.meta.tweakreg_catalog}'")
+                    catalog = Table.read(
+                        image_model.meta.tweakreg_catalog,
+                    )
+                    save_catalog = False
+                else:
+                    # source finding
+                    catalog = self._find_sources(image_model)
+
+                    # only save if catalog was computed from _find_sources and
+                    # the user requested save_catalogs
+                    save_catalog = self.save_catalogs
+
+                # if needed rename xcentroid to x, ycentroid to y
+                catalog = _rename_catalog_columns(catalog)
+
+                # filter all sources outside the wcs bounding box
+                catalog = twk.filter_catalog_by_bounding_box(
+                    catalog,
+                    image_model.meta.wcs.bounding_box)
+
+                # setting 'name' is important for tweakwcs logging
+                if catalog.meta.get('name') is None:
+                    catalog.meta['name'] = path.splitext(image_model.meta.filename)[0].strip('_- ')
+
+                # log results of source finding (or user catalog)
+                filename = image_model.meta.filename
+                nsources = len(catalog)
+                if nsources == 0:
+                    self.log.warning('No sources found in {}.'.format(filename))
+                else:
+                    self.log.info('Detected {} sources in {}.'
+                                  .format(len(catalog), filename))
+
+                # save catalog (if requested)
+                if save_catalog:
+                    # FIXME this modifies the input_model
+                    image_model.meta.tweakreg_catalog = self._write_catalog(catalog, filename)
+
+                # construct the corrector since the model is open (and already has a group_id)
+                correctors[model_index] = \
+                    twk.construct_wcs_corrector(image_model.meta.wcs,
+                                                image_model.meta.wcsinfo.instance,
+                                                catalog,
+                                                image_model.meta.group_id,)
+                images.shelve(image_model, model_index)
 
         self.log.info('')
         self.log.info("Number of image groups to be aligned: {:d}."
@@ -277,34 +283,36 @@ def process(self, input):
         # absolute alignment to the reference catalog
         # can (and does) occur after alignment between groups
         if align_to_abs_refcat:
-            try:
-                ref_image = images[0]
-                correctors = \
-                    twk.absolute_align(correctors, self.abs_refcat,
-                                       ref_wcs=ref_image.meta.wcs,
-                                       ref_wcsinfo=ref_image.meta.wcsinfo.instance,
-                                       epoch=Time(ref_image.meta.observation.date).decimalyear,
-                                       abs_minobj=self.abs_minobj,
-                                       abs_fitgeometry=self.abs_fitgeometry,
-                                       abs_nclip=self.abs_nclip,
-                                       abs_sigma=self.abs_sigma,
-                                       abs_searchrad=self.abs_searchrad,
-                                       abs_use2dhist=self.abs_use2dhist,
-                                       abs_separation=self.abs_separation,
-                                       abs_tolerance=self.abs_tolerance,
-                                       save_abs_catalog=self.save_abs_catalog,
-                                       abs_catalog_output_dir=self.output_dir,
-                                       )
-
-            except twk.TweakregError as e:
-                self.log.warning(str(e))
-                for model in images:
-                    model.meta.cal_step.tweakreg = "SKIPPED"
-                return images
-
-        if local_align_failed and not align_to_abs_refcat:
-            for model in images:
-                record_step_status(model, "tweakreg", success=False)
+            with images:
+                ref_image = images.borrow(0)
+                try:
+                    correctors = \
+                        twk.absolute_align(correctors, self.abs_refcat,
+                                           ref_wcs=ref_image.meta.wcs,
+                                           ref_wcsinfo=ref_image.meta.wcsinfo.instance,
+                                           epoch=Time(ref_image.meta.observation.date).decimalyear,
+                                           abs_minobj=self.abs_minobj,
+                                           abs_fitgeometry=self.abs_fitgeometry,
+                                           abs_nclip=self.abs_nclip,
+                                           abs_sigma=self.abs_sigma,
+                                           abs_searchrad=self.abs_searchrad,
+                                           abs_use2dhist=self.abs_use2dhist,
+                                           abs_separation=self.abs_separation,
+                                           abs_tolerance=self.abs_tolerance,
+                                           save_abs_catalog=self.save_abs_catalog,
+                                           abs_catalog_output_dir=self.output_dir,
+                                           )
+                    images.shelve(ref_image, 0, modify=False)
+                except twk.TweakregError as e:
+                    self.log.warning(str(e))
+                    images.shelve(ref_image, 0, modify=False)
+                    record_step_status(images, "tweakreg", success=False)
+                    return images
+                finally:
+                    del ref_image
+
+            if local_align_failed and not align_to_abs_refcat:
+                record_step_status(images, "tweakreg", success=False)
             return images
 
         # one final pass through all the models to update them based
@@ -315,53 +323,53 @@ def process(self, input):
 
     def _apply_tweakreg_solution(self,
-                                 images: ModelContainer,
+                                 images: ModelLibrary,
                                  correctors: list[JWSTWCSCorrector],
                                  align_to_abs_refcat: bool = False,
-                                 ) -> ModelContainer:
-
-        for (image_model, corrector) in zip(images, correctors):
-
-            # retrieve fit status and update wcs if fit is successful:
-            if ("fit_info" in corrector.meta and
-                    "SUCCESS" in corrector.meta["fit_info"]["status"]):
-
-                # Update/create the WCS .name attribute with information
-                # on this astrometric fit as the only record that it was
-                # successful:
-                if align_to_abs_refcat:
-                    # NOTE: This .name attrib agreed upon by the JWST Cal
-                    #       Working Group.
-                    #       Current value is merely a place-holder based
-                    #       on HST conventions. This value should also be
-                    #       translated to the FITS WCSNAME keyword
-                    #       IF that is what gets recorded in the archive
-                    #       for end-user searches.
-                    corrector.wcs.name = f"FIT-LVL3-{self.abs_refcat}"
-
-            image_model.meta.wcs = corrector.wcs
-            update_s_region_imaging(image_model)
-
-            # Also update FITS representation in input exposures for
-            # subsequent reprocessing by the end-user.
-            if self.sip_approx:
-                try:
-                    update_fits_wcsinfo(
-                        image_model,
-                        max_pix_error=self.sip_max_pix_error,
-                        degree=self.sip_degree,
-                        max_inv_pix_error=self.sip_max_inv_pix_error,
-                        inv_degree=self.sip_inv_degree,
-                        npoints=self.sip_npoints,
-                        crpix=None
-                    )
-                except (ValueError, RuntimeError) as e:
-                    self.log.warning("Failed to update 'meta.wcsinfo' with FITS SIP "
                                     "approximation. Reported error is:")
-                    self.log.warning(f'"{e.args[0]}"')
-            record_step_status(image_model, "tweakreg", success=True)
-
-        return image_model
+                                 ) -> ModelLibrary:
+        with images:
+            for (image_model, corrector) in zip(images, correctors):
+
+                # retrieve fit status and update wcs if fit is successful:
+                if ("fit_info" in corrector.meta and
+                        "SUCCESS" in corrector.meta["fit_info"]["status"]):
+
+                    # Update/create the WCS .name attribute with information
+                    # on this astrometric fit as the only record that it was
+                    # successful:
+                    if align_to_abs_refcat:
+                        # NOTE: This .name attrib agreed upon by the JWST Cal
+                        #       Working Group.
+                        #       Current value is merely a place-holder based
+                        #       on HST conventions. This value should also be
+                        #       translated to the FITS WCSNAME keyword
+                        #       IF that is what gets recorded in the archive
+                        #       for end-user searches.
+                        corrector.wcs.name = f"FIT-LVL3-{self.abs_refcat}"
+
+                    image_model.meta.wcs = corrector.wcs
+                    update_s_region_imaging(image_model)
+
+                    # Also update FITS representation in input exposures for
+                    # subsequent reprocessing by the end-user.
+                    if self.sip_approx:
+                        try:
+                            update_fits_wcsinfo(
+                                image_model,
+                                max_pix_error=self.sip_max_pix_error,
+                                degree=self.sip_degree,
+                                max_inv_pix_error=self.sip_max_inv_pix_error,
+                                inv_degree=self.sip_inv_degree,
+                                npoints=self.sip_npoints,
+                                crpix=None
+                            )
+                        except (ValueError, RuntimeError) as e:
+                            self.log.warning("Failed to update 'meta.wcsinfo' with FITS SIP "
+                                             "approximation. Reported error is:")
+                            self.log.warning(f'"{e.args[0]}"')
+                record_step_status(image_model, "tweakreg", success=True)
+                images.shelve(image_model)
+        return images
 
     def _write_catalog(self, catalog, filename):
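
Aside: with the `record_step_status` change above, a step can stamp a status on every
model in a library in one call — the helper borrows each model, sets ``meta.cal_step``,
and shelves it back. A sketch (models and metadata invented as before)::

    from stdatamodels.jwst.datamodels import ImageModel
    from jwst.datamodels import ModelLibrary
    from jwst.stpipe import record_step_status

    models = []
    for i in range(2):
        m = ImageModel()
        m.meta.filename = f"exposure_{i}.fits"  # hypothetical
        m.meta.group_id = str(i)
        models.append(m)
    library = ModelLibrary(models)

    # marks meta.cal_step.tweakreg as 'SKIPPED' on every model in the library
    record_step_status(library, "tweakreg", success=False)

    with library:
        for j, model in enumerate(library):
            assert model.meta.cal_step.tweakreg == 'SKIPPED'
            library.shelve(model, j, modify=False)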