"""Flag pixels affected by open MSA shutters in NIRSpec exposures."""
import json
import logging
import warnings
from pathlib import Path
import numpy as np
from gwcs.wcs import WCS
from stdatamodels.jwst import datamodels
from stdatamodels.jwst.transforms.models import Slit
from jwst.assign_wcs.nirspec import (
generate_compound_bbox,
slitlets_wcs,
)
from jwst.assign_wcs.nirspec import (
log as nirspec_log,
)
from jwst.lib.basic_utils import LoggingContext
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# DQ bit value registered under the name "MSA_FAILED_OPEN" in the pixel
# DQ flag table; this is the value OR-ed into the DQ array of affected pixels.
FAILEDOPENFLAG = datamodels.dqflags.pixel["MSA_FAILED_OPEN"]
# Multiplier applied to (y - 1) when create_slitlets folds a shutter's x/y
# position into a single integer shutter_id.
# NOTE(review): presumably the number of shutters along x in one MSA
# quadrant - confirm against the MSA description.
SHUTTERS_PER_ROW = 365
# Keys of an msaoper shutter entry that are inspected; a shutter is treated
# as failed open when at least one of them has the value 'open'.
FLAGGABLE_STATES = ["Internal state", "TA state", "state"]
# Names exported by ``from <module> import *``.
__all__ = [
"do_correction",
"flag",
"boundingbox_to_indices",
"wcs_to_dq",
"get_failed_open_shutters",
"create_slitlets",
]
[docs]
def do_correction(input_model, shutter_refname, wcs_refnames):
    """
    Flag the pixels of an exposure that see light through failed open shutters.

    The failed open shutters are read from the MSAOPER reference file, turned
    into slitlets, and handed to `flag`, which ORs the MSA_FAILED_OPEN bit into
    the DQ array. The step status ``meta.cal_step.msa_flagging`` is then set
    to "COMPLETE".

    Parameters
    ----------
    input_model : `~stdatamodels.jwst.datamodels.ImageModel`
        Science data to be corrected. Updated in-place.
    shutter_refname : str
        Name of MSAOPER reference file.
    wcs_refnames : dict
        Dictionary of WCS reference file names.

    Returns
    -------
    input_model : `~stdatamodels.jwst.datamodels.ImageModel`
        Science data with DQ array modified.
    """
    # One slitlet per failed open shutter listed in the msaoper reference file
    stuck_open = create_slitlets(shutter_refname)
    log.info("%d failed open shutters", len(stuck_open))

    # RuntimeWarnings whose message starts with "Invalid interval" are
    # silenced only for the duration of the flagging call.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=RuntimeWarning, message="Invalid interval")
        flagged_model = flag(input_model, stuck_open, wcs_refnames)

    flagged_model.meta.cal_step.msa_flagging = "COMPLETE"
    return flagged_model
[docs]
def flag(input_datamodel, failed_slitlets, wcs_refnames):
    """
    Set the MSA_FAILED_OPEN DQ bit on pixels covered by failed open shutters.

    A WCS is built for the failed open slitlets with the assign_wcs machinery.
    For every slitlet, the pixels inside its bounding box are run through that
    WCS; pixels whose first world coordinate is not NaN get the
    MSA_FAILED_OPEN flag OR-ed into their DQ value. No other science array
    is written to.

    The input datamodel is modified in-place.

    Parameters
    ----------
    input_datamodel : `~stdatamodels.jwst.datamodels.JwstDataModel`
        Input science data. Updated in-place.
    failed_slitlets : list of `~stdatamodels.jwst.transforms.Slit`
        Failed open slitlets.
    wcs_refnames : dict
        Reference file names used to calculate the WCS. Keys are reference
        file types; values are file paths.

    Returns
    -------
    input_datamodel : `~stdatamodels.jwst.datamodels.JwstDataModel`
        Science data with DQ flags modified.
    """
    # Build the slitlet WCS pipeline via assign_wcs, with the assign_wcs
    # NIRSpec logger held at WARNING level while it runs.
    with LoggingContext(nirspec_log, level=logging.WARNING):
        wcs_pipeline = slitlets_wcs(input_datamodel, wcs_refnames, failed_slitlets)
    shutter_wcs = WCS(wcs_pipeline)

    # A scratch model carries the failed-open WCS so the input model's own
    # WCS is left alone. Its exposure type is forced to NRS_MSASPEC because
    # the slit WCS is needed even when the input EXP_TYPE is NRS_IFU: what is
    # being computed is where stuck open slits land on the data.
    scratch = datamodels.ImageModel()
    scratch.meta.wcs = shutter_wcs
    scratch.meta.wcsinfo = input_datamodel.meta.wcsinfo
    scratch.meta.exposure.type = "NRS_MSASPEC"
    scratch.meta.wcs.bounding_box = generate_compound_bbox(scratch, failed_slitlets)

    dq = input_datamodel.dq
    for slit in failed_slitlets:
        # Bounding box of this slitlet, expressed as index ranges into the data
        slit_bbox = scratch.meta.wcs.bounding_box[slit.name]
        col_lo, col_hi, row_lo, row_hi = boundingbox_to_indices(
            input_datamodel.data.shape, slit_bbox
        )
        row_span = slice(row_lo, row_hi)
        col_span = slice(col_lo, col_hi)

        # Pixel coordinate grids covering the bounding-box window
        rows, cols = np.mgrid[row_span, col_span]

        # World coordinates of every pixel in the window; the fourth output
        # is discarded. Pixels outside the slitlet come back as NaN.
        ra, dec, lam, _ = scratch.meta.wcs(cols, rows, slit.name)

        # FAILEDOPENFLAG where the coordinates are not NaN, 0 elsewhere
        footprint = wcs_to_dq((ra, dec, lam), FAILEDOPENFLAG)

        # OR the footprint into the matching window of the exposure's DQ array
        dq[..., row_span, col_span] |= footprint

    # Store the updated DQ array back on the input model
    input_datamodel.dq = dq
    return input_datamodel
[docs]
def boundingbox_to_indices(data_shape, bounding_box):
    """
    Translate a bounding box to image indices.

    Takes a ``bounding_box`` (tuple of tuples: ``((x1, x2), (y1, y2))``) and
    the shape of a data array, and calculates the half-open index ranges in
    the X and Y dimensions of the overlap between the bounding box and the
    array. Only the last two entries of ``data_shape`` are used, as
    ``(nrows, ncols)``.

    The returned ranges always satisfy ``0 <= xmin <= xmax <= ncols`` and
    ``0 <= ymin <= ymax <= nrows``. A bounding box that lies entirely off the
    array therefore gives an empty range (``min == max``) instead of a
    negative index, which slicing would interpret as counting from the end
    of the axis.

    Parameters
    ----------
    data_shape : tuple
        The data shape for the input science datamodel.
    bounding_box : tuple of tuple
        Bounding box returned from WCS object. The two limits of each axis
        may be given in either order.

    Returns
    -------
    xmin, xmax, ymin, ymax : int
        Range of indices of overlap between science data array and bounding box.
    """
    nrows, ncols = data_shape[-2:]
    x1, x2 = bounding_box[0]
    y1, y2 = bounding_box[1]
    xmin, xmax = _clip_interval_to_axis(x1, x2, ncols)
    ymin, ymax = _clip_interval_to_axis(y1, y2, nrows)
    return xmin, xmax, ymin, ymax


def _clip_interval_to_axis(edge_a, edge_b, size):
    """Return the half-open index range of an interval, clipped to ``[0, size]``."""
    # Lower index: truncate the smaller limit, then keep it inside the axis
    lower = min(max(int(min(edge_a, edge_b)), 0), size)
    # Upper index is exclusive (+1); never let it fall below ``lower`` so an
    # off-array interval collapses to an empty range.
    upper = max(min(int(max(edge_a, edge_b)) + 1, size), lower)
    return lower, upper
[docs]
def wcs_to_dq(wcs_array, flag):
    """
    Build the DQ subarray for one failed open slitlet.

    Only the first array in ``wcs_array`` is examined. The result has the
    same shape as that array, holding ``flag`` wherever it is not NaN and 0
    wherever it is NaN.

    Parameters
    ----------
    wcs_array : tuple of ndarray
        Image coordinates for the failed open region.
    flag : int
        DQ flag to set.

    Returns
    -------
    dq : ndarray of int
        Output DQ array, of dtype uint32.
    """
    first_coord = wcs_array[0]
    flagged = np.zeros(first_coord.shape, dtype=np.uint32)
    nan_mask = np.isnan(first_coord)
    flagged[~nan_mask] = flag
    return flagged
[docs]
def get_failed_open_shutters(shutter_refname):
    """
    Read the failed open shutters from an MSAOPER reference file.

    The file is parsed as JSON and the entries under its ``"msaoper"`` key
    are scanned. An entry is kept when at least one of its
    FLAGGABLE_STATES values equals ``'open'``.

    Parameters
    ----------
    shutter_refname : str
        File name for the MSAOPER reference file.

    Returns
    -------
    failedopen : list
        The shutter entries, in file order, for which at least one of the
        states in FLAGGABLE_STATES is set to 'open'.
    """
    # Parse the bad shutter reference file
    with Path(shutter_refname).open("r") as ref_file:
        reference = json.load(ref_file)

    # any() stops at the first state found 'open', so later states of that
    # entry are not looked up.
    return [
        entry
        for entry in reference["msaoper"]
        if any(entry[state_key] == "open" for state_key in FLAGGABLE_STATES)
    ]
[docs]
def create_slitlets(shutter_refname):
    """
    Build one slitlet per failed open shutter.

    Each slitlet is named with a running number starting at 1, which is also
    passed as ``slit_id``. Its "shutter_id" is an integer computed from the
    x and y center of the shutter as ``x + (y - 1) * SHUTTERS_PER_ROW``.

    In the Slit tuple, the only values that matter are "name" (must be unique),
    xcen, ycen, quadrant (from msaoper file), ymin, ymax (set to -0.5, 0.5), and
    shutter state (set to 'x', for one open shutter). The dither position and
    source id are set to 0; every field not passed here keeps the default
    given by the Slit definition.

    Parameters
    ----------
    shutter_refname : str
        File name for the MSAOPER reference file.

    Returns
    -------
    slitlets : list of `~stdatamodels.jwst.transforms.Slit`
        A list of slitlets, one for each failed open shutter, in the order
        the shutters appear in the reference file.
    """
    slitlets = []
    for number, shutter in enumerate(get_failed_open_shutters(shutter_refname), start=1):
        xcen = shutter["x"]
        ycen = shutter["y"]
        shutter_id = xcen + (ycen - 1) * SHUTTERS_PER_ROW
        slitlet = Slit(
            number, shutter_id, 0, xcen, ycen, -0.5, 0.5, shutter["Q"], 0, "x", slit_id=number
        )
        slitlets.append(slitlet)
    return slitlets