diff --git a/CHANGES.rst b/CHANGES.rst
index 5203efd1..f0e59c85 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -9,7 +9,8 @@ General
 Changes to API
 --------------
 
--
+- Add ``outlier_detection`` submodule with ``utils`` included
+  from jwst. [#270]
 
 Bug Fixes
 ---------
diff --git a/docs/stcal/outlier_detection/description.rst b/docs/stcal/outlier_detection/description.rst
new file mode 100644
index 00000000..5299257b
--- /dev/null
+++ b/docs/stcal/outlier_detection/description.rst
@@ -0,0 +1,4 @@
+Description
+============
+
+This sub-package contains functions useful for outlier detection.
diff --git a/docs/stcal/outlier_detection/index.rst b/docs/stcal/outlier_detection/index.rst
new file mode 100644
index 00000000..7c3bac07
--- /dev/null
+++ b/docs/stcal/outlier_detection/index.rst
@@ -0,0 +1,12 @@
+.. _outlier_detection:
+
+=======================
+Outlier Detection Utils
+=======================
+
+.. toctree::
+   :maxdepth: 2
+
+   description.rst
+
+.. automodapi:: stcal.outlier_detection.utils
diff --git a/docs/stcal/package_index.rst b/docs/stcal/package_index.rst
index 5af815e2..e1db02b0 100644
--- a/docs/stcal/package_index.rst
+++ b/docs/stcal/package_index.rst
@@ -8,3 +8,4 @@ Package Index
    ramp_fitting/index.rst
    alignment/index.rst
    tweakreg/index.rst
+   outlier_detection/index.rst
diff --git a/pyproject.toml b/pyproject.toml
index 7a16bc84..978e98fb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,7 +14,9 @@ classifiers = [
 ]
 dependencies = [
     "astropy >=5.0.4",
+    "drizzle>=1.15.0",
     "scipy >=1.7.2",
+    "scikit-image>=0.19",
     "numpy >=1.21.2",
     "opencv-python-headless >=4.6.0.66",
     "asdf >=2.15.0",
@@ -209,6 +211,7 @@ module = [
     "stdatamodels.*",
     "asdf.*",
     "scipy.*",
+    "drizzle.*",
     # don't complain about the installed c parts of this library
     "stcal.ramp_fitting.ols_cas22._fit",
     "stcal.ramp_fitting.ols_cas22._jump",
diff --git a/src/stcal/outlier_detection/__init__.py b/src/stcal/outlier_detection/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/stcal/outlier_detection/utils.py b/src/stcal/outlier_detection/utils.py
new file mode 100644
index 00000000..964de5fe
--- /dev/null
+++ b/src/stcal/outlier_detection/utils.py
@@ -0,0 +1,339 @@
+"""
+Utility functions for outlier detection routines
+"""
+import warnings
+
+import numpy as np
+from astropy.stats import sigma_clip
+from drizzle.cdrizzle import tblot
+from scipy import ndimage
+from skimage.util import view_as_windows
+import gwcs
+
+from stcal.alignment.util import wcs_bbox_from_shape
+
+import logging
+log = logging.getLogger(__name__)
+log.setLevel(logging.DEBUG)
+
+
+__all__ = [
+    "medfilt",
+    "compute_weight_threshold",
+    "flag_crs",
+    "flag_resampled_crs",
+    "gwcs_blot",
+    "calc_gwcs_pixmap",
+    "reproject",
+]
+
+
+def medfilt(arr, kern_size):
+    """
+    Median filter an array while ignoring nan values.
+
+    scipy.signal.medfilt (and many other median filters) have undefined behavior
+    for nan inputs. See: https://github.com/scipy/scipy/issues/4800
+
+    Parameters
+    ----------
+    arr : numpy.ndarray
+        The input array
+
+    kern_size : list of int
+        List of kernel dimensions, length must be equal to arr.ndim.
+
+    Returns
+    -------
+    filtered_arr : numpy.ndarray
+        Input array median filtered with a kernel of size kern_size
+    """
+    # np.pad defaults to constant (zero) padding, so edge windows include zeros
+    padded = np.pad(arr, [[k // 2] for k in kern_size])
+    windows = view_as_windows(padded, kern_size, np.ones(len(kern_size), dtype='int'))
+    return np.nanmedian(windows, axis=tuple(range(-len(kern_size), 0)))
+
+
+def compute_weight_threshold(weight, maskpt):
+    """
+    Compute the weight threshold for a single image or cube.
+
+    Parameters
+    ----------
+    weight : numpy.ndarray
+        The weight array
+
+    maskpt : float
+        The fraction of the sigma-clipped mean weight to use as a
+        threshold for masking (the mean is multiplied by this value).
+
+    Returns
+    -------
+    float
+        The weight threshold for this integration.
+    """
+    # necessary in order to assure that mask gets applied correctly
+    if hasattr(weight, '_mask'):
+        del weight._mask
+    mask_zero_weight = np.equal(weight, 0.)
+    mask_nans = np.isnan(weight)
+    # Combine the masks
+    weight_masked = np.ma.array(weight, mask=np.logical_or(
+        mask_zero_weight, mask_nans))
+    # Sigma-clip the unmasked data
+    weight_masked = sigma_clip(weight_masked, sigma=3, maxiters=5)
+    mean_weight = np.mean(weight_masked)
+    # Scale the clipped mean by maskpt to get the threshold
+    weight_threshold = mean_weight * maskpt
+    return weight_threshold
+
+
+def _abs_deriv(array):
+    """
+    Do not use this function.
+
+    Take the absolute derivative of a numpy array.
+
+    This function assumes off-edge pixel values are 0
+    and leads to erroneous derivative values and should
+    likely not be used.
+    """
+    tmp = np.zeros(array.shape, dtype=np.float64)
+    out = np.zeros(array.shape, dtype=np.float64)
+
+    tmp[1:, :] = array[:-1, :]
+    tmp, out = _absolute_subtract(array, tmp, out)
+    tmp[:-1, :] = array[1:, :]
+    tmp, out = _absolute_subtract(array, tmp, out)
+
+    tmp[:, 1:] = array[:, :-1]
+    tmp, out = _absolute_subtract(array, tmp, out)
+    tmp[:, :-1] = array[:, 1:]
+    tmp, out = _absolute_subtract(array, tmp, out)
+
+    return out
+
+
+def _absolute_subtract(array, tmp, out):
+    """
+    Do not use this function.
+
+    A helper function for _abs_deriv.
+    """
+    tmp = np.abs(array - tmp)
+    out = np.maximum(tmp, out)
+    tmp = tmp * 0.
+    return tmp, out
+
+
+def flag_crs(
+    sci_data,
+    sci_err,
+    blot_data,
+    snr,
+):
+    """
+    Straightforward detection of outliers for non-dithered data since
+    sci_err includes all noise sources (photon, read, and flat for baseline).
+
+    Parameters
+    ----------
+    sci_data : numpy.ndarray
+        "Science" data possibly containing outliers.
+
+    sci_err : numpy.ndarray
+        Error estimates for sci_data.
+
+    blot_data : numpy.ndarray
+        Reference data used to detect outliers.
+
+    snr : float
+        Signal-to-noise ratio used during detection.
+
+    Returns
+    -------
+    cr_mask : numpy.ndarray
+        Boolean array where outliers (CRs) are true.
+    """
+    return np.greater(np.abs(sci_data - blot_data), snr * np.nan_to_num(sci_err))
+
+
+def flag_resampled_crs(
+    sci_data,
+    sci_err,
+    blot_data,
+    snr1,
+    snr2,
+    scale1,
+    scale2,
+    backg,
+):
+    """
+    Detect outliers (CRs) using resampled reference data.
+
+    Parameters
+    ----------
+    sci_data : numpy.ndarray
+        "Science" data possibly containing outliers
+
+    sci_err : numpy.ndarray
+        Error estimates for sci_data
+
+    blot_data : numpy.ndarray
+        Reference data used to detect outliers.
+
+    snr1 : float
+        Signal-to-noise ratio threshold used prior to smoothing.
+
+    snr2 : float
+        Signal-to-noise ratio threshold used after smoothing.
+
+    scale1 : float
+        Scale used prior to smoothing.
+
+    scale2 : float
+        Scale used after smoothing.
+
+    backg : float
+        Scalar background to subtract from the difference.
+
+    Returns
+    -------
+    cr_mask : numpy.ndarray
+        boolean array where outliers (CRs) are true
+    """
+    err_data = np.nan_to_num(sci_err)
+
+    blot_deriv = _abs_deriv(blot_data)
+    diff_noise = np.abs(sci_data - blot_data - backg)
+
+    # Create a boolean mask based on a scaled version of
+    # the derivative image (dealing with interpolating issues?)
+    # and the standard n*sigma above the noise
+    threshold1 = scale1 * blot_deriv + snr1 * err_data
+    mask1 = np.greater(diff_noise, threshold1)
+
+    # Smooth the boolean mask with a 3x3 boxcar kernel
+    kernel = np.ones((3, 3), dtype=int)
+    mask1_smoothed = ndimage.convolve(mask1, kernel, mode='nearest')
+
+    # Create a 2nd boolean mask based on the 2nd set of
+    # scale and threshold values
+    threshold2 = scale2 * blot_deriv + snr2 * err_data
+    mask2 = np.greater(diff_noise, threshold2)
+
+    # Final boolean mask
+    return mask1_smoothed & mask2
+
+
+def gwcs_blot(median_data, median_wcs, blot_shape, blot_wcs, pix_ratio):
+    """
+    Resample the median data to recreate an input image based on
+    the blot wcs.
+
+    Parameters
+    ----------
+    median_data : numpy.ndarray
+        The data to blot.
+
+    median_wcs : gwcs.wcs.WCS
+        The wcs for the median data.
+
+    blot_shape : list of int
+        The target blot data shape.
+
+    blot_wcs : gwcs.wcs.WCS
+        The target/blotted wcs.
+
+    pix_ratio : float
+        Pixel ratio.
+
+    Returns
+    -------
+    blotted : numpy.ndarray
+        The blotted median data.
+    """
+    # Compute the mapping between the input and output pixel coordinates
+    pixmap = calc_gwcs_pixmap(blot_wcs, median_wcs, blot_shape)
+    log.debug("Pixmap shape: %s", pixmap[:, :, 0].shape)
+    log.debug("Sci shape: %s", blot_shape)
+    log.info('Blotting %s <-- %s', blot_shape, median_data.shape)
+
+    outsci = np.zeros(blot_shape, dtype=np.float32)
+
+    # Currently tblot cannot handle nans in the pixmap, so we need to give some
+    # other value. -1 is not optimal and may have side effects. But this is
+    # what we've been doing up until now, so more investigation is needed
+    # before a change is made. Preferably, fix tblot in drizzle.
+    pixmap[np.isnan(pixmap)] = -1
+    tblot(median_data, pixmap, outsci, scale=pix_ratio, kscale=1.0,
+          interp='linear', exptime=1.0, misval=0.0, sinscl=1.0)
+
+    return outsci
+
+
+def calc_gwcs_pixmap(in_wcs, out_wcs, in_shape):
+    """
+    Return a pixel grid map from input frame to output frame.
+
+    Parameters
+    ----------
+    in_wcs : gwcs.wcs.WCS
+        Input/source wcs.
+
+    out_wcs : gwcs.wcs.WCS
+        Output/projected wcs.
+
+    in_shape : list of int
+        Input shape used to compute the input bounding box.
+
+    Returns
+    -------
+    pixmap : numpy.ndarray
+        Computed pixmap.
+    """
+    bb = wcs_bbox_from_shape(in_shape)
+    log.debug("Bounding box from data shape: %s", bb)
+
+    grid = gwcs.wcstools.grid_from_bounding_box(bb)
+    return np.dstack(reproject(in_wcs, out_wcs)(grid[0], grid[1]))
+
+
+def reproject(wcs1, wcs2):
+    """
+    Given two WCSs return a function which takes pixel
+    coordinates in wcs1 and computes them in wcs2.
+
+    It performs the forward transformation of ``wcs1`` followed by the
+    inverse of ``wcs2``.
+
+    Parameters
+    ----------
+    wcs1, wcs2 : gwcs.wcs.WCS
+        WCS objects that have `pixel_to_world_values` and `world_to_pixel_values`
+        methods.
+
+    Returns
+    -------
+    _reproject :
+        Function to compute the transformations. It takes x, y
+        positions in ``wcs1`` and returns x, y positions in ``wcs2``.
+
+    Raises
+    ------
+    TypeError
+        If ``wcs1`` or ``wcs2`` lacks the required transform method.
+    """
+
+    try:
+        forward_transform = wcs1.pixel_to_world_values
+        backward_transform = wcs2.world_to_pixel_values
+    except AttributeError as err:
+        raise TypeError("Input should be a WCS") from err
+
+    def _reproject(x, y):
+        sky = forward_transform(x, y)
+        flat_sky = [axis.flatten() for axis in sky]
+        det = backward_transform(*flat_sky)
+        return tuple(axis.reshape(x.shape) for axis in det)
+    return _reproject
diff --git a/tests/outlier_detection/__init__.py b/tests/outlier_detection/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/outlier_detection/test_utils.py b/tests/outlier_detection/test_utils.py
new file mode 100644
index 00000000..c602a89c
--- /dev/null
+++ b/tests/outlier_detection/test_utils.py
@@ -0,0 +1,183 @@
+import warnings
+
+import gwcs
+import pytest
+import numpy as np
+import scipy.signal
+from astropy.modeling import models
+
+from stcal.outlier_detection.utils import (
+    _abs_deriv,
+    compute_weight_threshold,
+    flag_crs,
+    flag_resampled_crs,
+    gwcs_blot,
+    calc_gwcs_pixmap,
+    reproject,
+    medfilt,
+)
+
+
+@pytest.mark.parametrize("shape,diff", [
+    ([5, 7], 100),
+    ([17, 13], -200),
+])
+def test_abs_deriv_single_value(shape, diff):
+    arr = np.zeros(shape)
+    # put diff at the center
+    np.put(arr, arr.size // 2, diff)
+    # since abs_deriv with a single non-zero value is the same as a
+    # convolution with a 3x3 cross kernel use it to test the result
+    expected = scipy.signal.convolve2d(np.abs(arr), [[0, 1, 0], [1, 1, 1], [0, 1, 0]], mode='same')
+    result = _abs_deriv(arr)
+    np.testing.assert_allclose(result, expected)
+
+
+@pytest.mark.skip(reason="_abs_deriv has edge effects due to treating off-edge pixels as 0: see JP-3683")
+@pytest.mark.parametrize("nrows,ncols", [(5, 5), (7, 11), (17, 13)])
+def test_abs_deriv_range(nrows, ncols):
+    arr = np.arange(nrows * ncols).reshape(nrows, ncols)
+    result = _abs_deriv(arr)
+    np.testing.assert_allclose(result, ncols)
+
+
+@pytest.mark.parametrize("shape,mean,maskpt,expected", [
+    ([5, 5], 11, 0.5, 5.5),
+    ([5, 5], 11, 0.25, 2.75),
+    ([3, 3, 3], 17, 0.5, 8.5),
+])
+def test_compute_weight_threshold(shape, mean, maskpt, expected):
+    arr = np.ones(shape, dtype=np.float32) * mean
+    result = compute_weight_threshold(arr, maskpt)
+    np.testing.assert_allclose(result, expected)
+
+
+def test_compute_weight_threshold_outlier():
+    """
+    Test that a large outlier doesn't bias the threshold
+    """
+    arr = np.ones([7, 7, 7], dtype=np.float32) * 42
+    arr[3, 3] = 9000
+    result = compute_weight_threshold(arr, 0.5)
+    np.testing.assert_allclose(result, 21)
+
+
+def test_compute_weight_threshold_zeros():
+    """
+    Test that zeros are ignored
+    """
+    arr = np.zeros([10, 10], dtype=np.float32)
+    arr[:5, :5] = 42
+    result = compute_weight_threshold(arr, 0.5)
+    np.testing.assert_allclose(result, 21)
+
+
+def test_flag_crs():
+    sci = np.zeros((10, 10), dtype=np.float32)
+    err = np.ones_like(sci)
+    blot = np.zeros_like(sci)
+    # add a cr
+    sci[2, 3] = 10
+    crs = flag_crs(sci, err, blot, 1)
+    ys, xs = np.where(crs)
+    np.testing.assert_equal(ys, 2)
+    np.testing.assert_equal(xs, 3)
+
+
+def test_flag_resampled_crs():
+    sci = np.zeros((10, 10), dtype=np.float32)
+    err = np.ones_like(sci)
+    blot = np.zeros_like(sci)
+    # add a cr
+    sci[2, 3] = 10
+
+    snr1, snr2 = 5, 4
+    scale1, scale2 = 1.2, 0.7
+    backg = 0.0
+    crs = flag_resampled_crs(sci, err, blot, snr1, snr2, scale1, scale2, backg)
+    ys, xs = np.where(crs)
+    np.testing.assert_equal(ys, 2)
+    np.testing.assert_equal(xs, 3)
+
+
+def test_gwcs_blot():
+    # set up a very simple wcs that scales by 1x
+    output_frame = gwcs.Frame2D(name="world")
+    forward_transform = models.Scale(1) & models.Scale(1)
+
+    median_data = np.arange(100, dtype=np.float32).reshape((10, 10))
+    median_wcs = gwcs.WCS(forward_transform, output_frame=output_frame)
+    blot_shape = (5, 5)
+    blot_wcs = gwcs.WCS(forward_transform, output_frame=output_frame)
+    pix_ratio = 1.0
+
+    blotted = gwcs_blot(median_data, median_wcs, blot_shape, blot_wcs, pix_ratio)
+    # since the median data is larger and the wcs are equivalent the blot
+    # will window the data to the shape of the blot data
+    assert blotted.shape == blot_shape
+    np.testing.assert_equal(blotted, median_data[:blot_shape[0], :blot_shape[1]])
+
+
+def test_calc_gwcs_pixmap():
+    # generate 2 wcses with different scales
+    output_frame = gwcs.Frame2D(name="world")
+    in_transform = models.Scale(1) & models.Scale(1)
+    out_transform = models.Scale(2) & models.Scale(2)
+    in_wcs = gwcs.WCS(in_transform, output_frame=output_frame)
+    out_wcs = gwcs.WCS(out_transform, output_frame=output_frame)
+    in_shape = (3, 4)
+    pixmap = calc_gwcs_pixmap(in_wcs, out_wcs, in_shape)
+    # we expect given the 2x scale difference to have a pixmap
+    # with pixel coordinates / 2
+    # use mgrid to generate these coordinates (and reshuffle to match the pixmap)
+    expected = np.swapaxes(np.mgrid[:4, :3] / 2., 0, 2)
+    np.testing.assert_equal(pixmap, expected)
+
+
+def test_reproject():
+    # generate 2 wcses with different scales
+    output_frame = gwcs.Frame2D(name="world")
+    wcs1 = gwcs.WCS(models.Scale(1) & models.Scale(1), output_frame=output_frame)
+    wcs2 = gwcs.WCS(models.Scale(2) & models.Scale(2), output_frame=output_frame)
+    project = reproject(wcs1, wcs2)
+    xs, ys = project(np.array([3]), np.array([1]))
+    np.testing.assert_equal(xs, 1.5)
+    np.testing.assert_equal(ys, 0.5)
+
+
+@pytest.mark.parametrize("shape,kern_size", [
+    ([7, 7], [3, 3]),
+    ([7, 7], [3, 1]),
+    ([7, 7], [1, 3]),
+    ([7, 5], [3, 3]),
+    ([5, 7], [3, 3]),
+    ([42, 42], [7, 7]),
+    ([42, 42], [7, 5]),
+    ([42, 42], [5, 7]),
+    ([42, 7, 5], [3, 3, 3]),
+    ([5, 7, 42], [5, 5, 5]),
+])
+def test_medfilt_against_scipy(shape, kern_size):
+    arr = np.arange(np.prod(shape), dtype='uint32').reshape(shape)
+    result = medfilt(arr, kern_size)
+
+    # The use of scipy.signal.medfilt is ok here ONLY because the
+    # input has no nans. See the medfilt docstring
+    expected = scipy.signal.medfilt(arr, kern_size)
+
+    np.testing.assert_allclose(result, expected)
+
+
+@pytest.mark.parametrize("arr,kern_size,expected", [
+    ([2, np.nan, 0], [3], [1, 1, 0]),
+    ([np.nan, np.nan, np.nan], [3], [0, np.nan, 0]),
+])
+def test_medfilt_nan(arr, kern_size, expected):
+    with warnings.catch_warnings():
+        warnings.filterwarnings(
+            "ignore",
+            message="All-NaN slice",
+            category=RuntimeWarning
+        )
+        result = medfilt(arr, kern_size)
+    np.testing.assert_allclose(result, expected)
diff --git a/tox.ini b/tox.ini
index cc514980..da995287 100644
--- a/tox.ini
+++ b/tox.ini
@@ -60,10 +60,14 @@ deps =
     oldestdeps: minimum_dependencies
     devdeps: numpy>=0.0.dev0
     devdeps: scipy>=0.0.dev0
+    devdeps: scikit-image>=0.0.dev0
     devdeps: pyerfa>=0.0.dev0
     devdeps: astropy>=0.0.dev0
     devdeps: requests>=0.0.dev0
     devdeps: tweakwcs @ git+https://github.com/spacetelescope/tweakwcs.git
+    devdeps: asdf @ git+https://github.com/asdf-format/asdf.git
+    devdeps: drizzle @ git+https://github.com/spacetelescope/drizzle.git
+    devdeps: gwcs @ git+https://github.com/spacetelescope/gwcs.git
 use_develop = true
 pass_env =
     CI