Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make var_flat optional. #1438

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/1438.flatfield.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make var_flat optional.
7 changes: 7 additions & 0 deletions docs/roman/flatfield/main.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,10 @@ and finally the error that is associated with the science data is given by,

The total ERR array in the science exposure is updated as the square root
of the quadratic sum of VAR_POISSON, VAR_RNOISE, and VAR_FLAT.

Note that by default we do not compute VAR_FLAT nor include its
contribution to ERR, unless the "include_var_flat" is specified. This
means that the uncertainties on very bright pixels are
underestimated. However, other effects like charge migration,
saturation, and non-linearity can be important at these flux levels,
and their contributions to the uncertainty are never included.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ dependencies = [
"photutils >=1.13.0",
"pyparsing >=2.4.7",
"requests >=2.26",
# "roman_datamodels>=0.22.0,<0.23.0",
"roman_datamodels>=0.22.0,<0.23.0",
"rad @ git+https://github.com/schlafly/rad.git@remove-var-flat",
"roman_datamodels @ git+https://github.com/spacetelescope/roman_datamodels.git",
"scipy >=1.11",
# "stcal>=1.10.0,<1.11.0",
Expand Down
36 changes: 25 additions & 11 deletions romancal/flatfield/flat_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
MICRONS_100 = 1.0e-4 # 100 microns, in meters


def do_correction(input_model, flat=None):
def do_correction(input_model, flat=None, include_var_flat=False):
"""Flat-field a Roman data model using a flat-field model

Parameters
Expand All @@ -24,19 +24,22 @@ def do_correction(input_model, flat=None):
flat : Roman data model, or None
Data model containing flat-field for all instruments

include_var_flat : bool
compute & store the flat field variance?

Returns
-------
output_model : data model
The data model for the flat-fielded science data.
The data is modified in place.
"""

do_flat_field(input_model, flat)
do_flat_field(input_model, flat, include_var_flat=include_var_flat)

return input_model


def do_flat_field(output_model, flat_model):
def do_flat_field(output_model, flat_model, include_var_flat=False):
"""Apply flat-fielding, and update the output model.

Parameters
Expand All @@ -46,6 +49,9 @@ def do_flat_field(output_model, flat_model):

flat_model : Roman data model
data model containing flat-field

include_var_flat : bool
compute & store the flat field variance?
"""
if flat_model is not None and output_model.data.shape != flat_model.data.shape:
# Check to see if flat data array is smaller than science data
Expand All @@ -61,11 +67,11 @@ def do_flat_field(output_model, flat_model):
log.info("Skipping flat field - no flat reference file.")
output_model.meta.cal_step.flat_field = "SKIPPED"
else:
apply_flat_field(output_model, flat_model)
apply_flat_field(output_model, flat_model, include_var_flat=include_var_flat)
output_model.meta.cal_step.flat_field = "COMPLETE"


def apply_flat_field(science, flat):
def apply_flat_field(science, flat, include_var_flat=False):
"""Flat field the data and error arrays.

Extended summary
Expand All @@ -82,6 +88,9 @@ def apply_flat_field(science, flat):

flat : Roman data model
flat field data model

include_var_flat : bool
compute & store the flat vield variance?
"""
flat_data = flat.data.copy()
flat_dq = flat.dq.copy()
Expand Down Expand Up @@ -111,13 +120,18 @@ def apply_flat_field(science, flat):
flat_data_squared = flat_data**2
science.var_poisson /= flat_data_squared
science.var_rnoise /= flat_data_squared
try:
science.var_flat = science.data**2 / flat_data_squared * flat_err**2
except AttributeError:
science["var_flat"] = np.zeros(shape=science.data.shape, dtype=np.float32)
science.var_flat = science.data**2 / flat_data_squared * flat_err**2

science.err = np.sqrt(science.var_poisson + science.var_rnoise + science.var_flat)
total_var = science.var_poisson + science.var_rnoise
if include_var_flat:
var_flat = science.data**2 / flat_data_squared * flat_err**2
try:
science.var_flat = var_flat
except AttributeError:
science["var_flat"] = np.zeros(shape=science.data.shape, dtype=np.float32)
science.var_flat = var_flat
total_var += science.var_flat

science.err = np.sqrt(total_var)

# Combine the science and flat DQ arrays
science.dq = np.bitwise_or(science.dq, flat_dq)
6 changes: 4 additions & 2 deletions romancal/flatfield/flat_field_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ class FlatFieldStep(RomanStep):
"""Flat-field a science image using a flatfield reference image."""

class_alias = "flat_field"
spec = """
include_var_flat = boolean(default=False) # include flat field variance
""" # noqa: E501

reference_file_types = ["flat"]

Expand All @@ -38,8 +41,7 @@ def process(self, input_model):

# Do the flat-field correction
output_model = flat_field.do_correction(
input_model,
reference_file_model,
input_model, reference_file_model, include_var_flat=self.include_var_flat
)

# Close reference file
Expand Down
15 changes: 15 additions & 0 deletions romancal/flatfield/tests/test_flatfield.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,21 @@ def test_crds_temporal_match(instrument, exptype):
)


def test_skip_var_flat():
"""Test that we don't populate var_flat if requested."""

wfi_image1 = maker_utils.mk_level2_image()
wfi_image2 = maker_utils.mk_level2_image()
del wfi_image1["var_flat"]
del wfi_image2["var_flat"]
wfi_image_model1 = ImageModel(wfi_image1)
wfi_image_model2 = ImageModel(wfi_image2)
result1 = FlatFieldStep.call(wfi_image_model1, include_var_flat=False)
result2 = FlatFieldStep.call(wfi_image_model2, include_var_flat=True)
assert not hasattr(result1, "var_flat")
assert hasattr(result2, "var_flat")


@pytest.mark.parametrize(
"instrument",
[
Expand Down
4 changes: 3 additions & 1 deletion romancal/flux/flux_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ def apply_flux_correction(model):
"""
# Define the various arrays to be converted.
DATA = ("data", "err")
VARIANCES = ("var_rnoise", "var_poisson", "var_flat")
VARIANCES = ("var_rnoise", "var_poisson")
if hasattr(model, "var_flat"):
VARIANCES = VARIANCES + ("var_flat",)

if model.meta.cal_step["flux"] == "COMPLETE":
message = (
Expand Down
8 changes: 3 additions & 5 deletions romancal/regtest/test_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,12 @@ def test_resample_single_file(rtdata, ignore_asdf_paths):
"err",
"var_poisson",
"var_rnoise",
"var_flat",
]
)
}"""
)
assert all(
hasattr(resample_out, x)
for x in ["data", "err", "var_poisson", "var_rnoise", "var_flat"]
hasattr(resample_out, x) for x in ["data", "err", "var_poisson", "var_rnoise"]
)

step.log.info(
Expand Down Expand Up @@ -94,14 +92,14 @@ def test_resample_single_file(rtdata, ignore_asdf_paths):
np.isnan(getattr(resample_out, x)),
np.equal(getattr(resample_out, x), 0)
)
) > 0 for x in ["var_poisson", "var_rnoise", "var_flat"]
) > 0 for x in ["var_poisson", "var_rnoise"]
)

}"""
)
assert all(
np.sum(np.isnan(getattr(resample_out, x)))
for x in ["var_poisson", "var_rnoise", "var_flat"]
for x in ["var_poisson", "var_rnoise"]
)

step.log.info(
Expand Down
28 changes: 16 additions & 12 deletions romancal/resample/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ def __init__(

with self.input_models:
models = list(self.input_models)
self.all_have_var_flat = np.all(
[hasattr(model, "var_flat") for model in models]
)

# update meta.basic
populate_mosaic_basic(self.blank_output, models)
Expand Down Expand Up @@ -201,6 +204,9 @@ def __init__(
for i, m in enumerate(models):
self.input_models.shelve(m, i, modify=False)

if not self.all_have_var_flat:
del self.blank_output._instance["var_flat"]

def do_drizzle(self):
"""Pick the correct drizzling mode based on ``self.single``."""
if self.single:
Expand Down Expand Up @@ -355,6 +361,9 @@ def resample_many_to_one(self):
)

log.info("Resampling science data")

all_have_var_flat = True

with self.input_models:
for i, img in enumerate(self.input_models):
inwht = resample_utils.build_driz_weight(
Expand Down Expand Up @@ -396,22 +405,17 @@ def resample_many_to_one(self):
# Resample variances array in self.input_models to output_model
self.resample_variance_array("var_rnoise", output_model)
self.resample_variance_array("var_poisson", output_model)
self.resample_variance_array("var_flat", output_model)
if self.all_have_var_flat:
self.resample_variance_array("var_flat", output_model)

# Make exposure time image
exptime_tot = self.resample_exposure_time(output_model)

# TODO: fix unit here
output_model.err = np.sqrt(
np.nansum(
[
output_model.var_rnoise,
output_model.var_poisson,
output_model.var_flat,
],
axis=0,
)
)
all_vars = [output_model.var_rnoise, output_model.var_poisson]
if self.all_have_var_flat:
all_vars = all_vars + [output_model.var_flat]

output_model.err = np.sqrt(np.nansum(all_vars, axis=0))

self.update_exposure_times(output_model, exptime_tot)

Expand Down
22 changes: 22 additions & 0 deletions romancal/resample/tests/test_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,28 @@ def test_update_exposure_times_same_sca_different_exposures(exposure_1, exposure
output_models.shelve(output_model, 0, modify=False)


@pytest.mark.parametrize("include_var_flat", [False, True])
def test_var_flat_presence(exposure_1, include_var_flat):
"""Test that var_flat is included or excluded depending on its presence in the underlying exposures."""
if not include_var_flat:
exposure_1 = [e.copy() for e in exposure_1]
for e in exposure_1:
del e._instance["var_flat"]
input_models = ModelLibrary(exposure_1)
resample_data = ResampleData(input_models)

output_models = resample_data.resample_many_to_one()
with output_models:
output_model = output_models.borrow(0)

if not include_var_flat:
assert not hasattr(output_model, "var_flat")
else:
assert hasattr(output_model, "var_flat")

output_models.shelve(output_model, 0, modify=False)


@pytest.mark.parametrize(
"name",
["var_rnoise", "var_poisson", "var_flat"],
Expand Down
Loading