# Source code for jwst.assign_wcs.util

"""Utility functions for assign_wcs."""

import logging
import warnings

import numpy as np
from astropy.constants import c
from astropy.coordinates import SkyCoord
from astropy.modeling import models as astmodels
from astropy.table import QTable
from gwcs import WCS
from gwcs import utils as gwutils
from gwcs.wcstools import grid_from_bounding_box
from stcal.alignment.util import compute_s_region_imaging, compute_s_region_keyword
from stdatamodels.jwst.datamodels import MiriLRSSpecwcsModel, WavelengthrangeModel
from stdatamodels.jwst.transforms.models import GrismObject
from stpipe.exceptions import StpipeExitException

from jwst.lib.catalog_utils import SkyObject

# Module-level logger; every function below reports through it.
log = logging.getLogger(__name__)


# By its name, an upper bound on the SIP polynomial degree.
# NOTE(review): not referenced in the visible part of this module; presumably
# used by SIP-fitting code elsewhere in the file -- confirm.
_MAX_SIP_DEGREE = 6


# Names exported by ``from jwst.assign_wcs.util import *``. Other helpers
# defined in this module (e.g. ``subarray_transform``, ``create_grism_bbox``)
# remain importable by name but are not part of this list.
__all__ = [
    "velocity_correction",
    "MSAFileError",
    "NoDataOnDetectorError",
    "compute_scale",
    "calc_rotation_matrix",
    "wrap_ra",
    "update_fits_wcsinfo",
]


class MSAFileError(Exception):
    """
    Signal a missing or invalid MSA shutter configuration file.

    Parameters
    ----------
    message : str
        Human-readable description of the problem; stored as the
        exception's single positional argument.
    """

    def __init__(self, message):
        # Zero-argument super() resolves to the same Exception.__init__.
        super().__init__(message)


class NoDataOnDetectorError(StpipeExitException):
    """
    Signal that the WCS solution places no data on the detector.

    Raised when an available WCS solution shows that no data will be
    present. The canonical case is NIRSpec NRS2: for some MSA
    configurations no dispersed spectra reach that detector. This is
    not a calibration failure, but the calling architecture must be
    told about it.

    Parameters
    ----------
    message : str, optional
        Explanation of the condition. A generic message is used when
        omitted.
    """

    def __init__(self, message=None):
        # stpipe CLI tools exit with this status when the exception is raised.
        exit_status = 64
        default_message = "WCS solution indicate that no science is in the data."
        super().__init__(exit_status, default_message if message is None else message)


def compute_scale(
    wcs: WCS,
    fiducial: tuple | np.ndarray,
    disp_axis: int | None = None,
    pscale_ratio: float | None = None,
) -> float:
    """
    Compute scaling transform.

    The fiducial is inverted to a pixel position. Pixel positions offset by
    one unit from it are mapped back to the sky. The angular separations
    from the fiducial position give the per-axis scales.

    Parameters
    ----------
    wcs : `~gwcs.wcs.WCS`
        Reference WCS object from which to compute a scaling factor.
    fiducial : tuple or numpy.ndarray
        Input fiducial of (RA, DEC) or (RA, DEC, Wavelength) used in
        calculating reference points.
    disp_axis : int, optional
        Dispersion axis integer. Assumes the same convention as
        ``wcsinfo.dispersion_direction``. Required when the output frame
        of ``wcs`` has a ``SPECTRAL`` axis.
    pscale_ratio : float, optional
        Ratio of output pixel scale to input pixel scale. When given,
        both axis scales are multiplied by it.

    Returns
    -------
    scale : float
        For a spectral WCS, the scale along the cross-dispersion axis
        (the y scale if ``disp_axis == 1``, otherwise the x scale).
        Otherwise, the geometric mean ``sqrt(xscale * yscale)``.
        Scales are taken from ``SkyCoord.separation(...).value`` of
        coordinates built with ``unit="deg"``.

    Raises
    ------
    ValueError
        If the WCS is spectral and ``disp_axis`` is None.
    """
    spectral = "SPECTRAL" in wcs.output_frame.axes_type

    if spectral and disp_axis is None:
        raise ValueError("If input WCS is spectral, a disp_axis must be given")

    # The inverse can produce NaNs (e.g. off the distortion solution);
    # silence the resulting "invalid value" RuntimeWarnings.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", "invalid value", RuntimeWarning)
        crpix = np.array(wcs.invert(*fiducial, with_bounding_box=False))

    # Unit offset placed at the index of the first SPATIAL output axis; the
    # rolled copy moves it one position over to give the second offset.
    delta = np.zeros_like(crpix)
    spatial_idx = np.where(np.array(wcs.output_frame.axes_type) == "SPATIAL")[0]
    delta[spatial_idx[0]] = 1

    # Columns: reference pixel, pixel + delta, pixel + rolled delta.
    crpix_with_offsets = np.vstack((crpix, crpix + delta, crpix + np.roll(delta, 1))).T
    crval_with_offsets = wcs(*crpix_with_offsets, with_bounding_box=False)

    coords = SkyCoord(
        ra=crval_with_offsets[spatial_idx[0]], dec=crval_with_offsets[spatial_idx[1]], unit="deg"
    )
    xscale: float = np.abs(coords[0].separation(coords[1]).value)
    yscale: float = np.abs(coords[0].separation(coords[2]).value)

    if pscale_ratio is not None:
        xscale *= pscale_ratio
        yscale *= pscale_ratio

    if spectral:
        # Assuming scale doesn't change with wavelength
        # Assuming disp_axis is consistent with DataModel.meta.wcsinfo.dispersion.direction
        return yscale if disp_axis == 1 else xscale

    scale: float = np.sqrt(xscale * yscale)
    return scale


def calc_rotation_matrix(roll_ref: float, v3i_yang: float, vparity: int = 1) -> list[float]:
    """
    Build the PC rotation matrix elements for an aperture.

    Parameters
    ----------
    roll_ref : float
        Telescope roll angle of V3 North over East at the ref. point in radians
    v3i_yang : float
        The angle between ideal Y-axis and V3 in radians.
    vparity : int
        The x-axis parity, usually taken from the JWST SIAF parameter VIdlParity.
        Must be 1 or -1.

    Returns
    -------
    matrix : list
        The rotation matrix, flattened as [pc1_1, pc1_2, pc2_1, pc2_2]

    Raises
    ------
    ValueError
        If ``vparity`` is not 1 or -1.

    Notes
    -----
    With ``theta = roll_ref - vparity * v3i_yang`` the returned elements are::

       pc1_1 =  vparity * cos(theta)    pc1_2 = sin(theta)
       pc2_1 = -vparity * sin(theta)    pc2_2 = cos(theta)

    and they are laid out as

       ----------------
       | pc1_1  pc2_1 |
       | pc1_2  pc2_2 |
       ----------------
    """
    if vparity not in (1, -1):
        raise ValueError(f"vparity should be 1 or -1. Input was: {vparity}")

    theta = roll_ref - (vparity * v3i_yang)
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)

    return [vparity * cos_t, sin_t, vparity * -sin_t, cos_t]


def subarray_transform(input_model):
    """
    Return an offset model if the observation uses a subarray.

    Parameters
    ----------
    input_model : JwstDataModel
        The input data model; ``meta.subarray.xstart`` / ``ystart``
        (1-based) are read.

    Returns
    -------
    subarray2full : `~astropy.modeling.core.Model` or ``None``
        ``Shift(xstart - 1) & Shift(ystart - 1)`` when a subarray is used.
        An axis whose start is unset or equal to 1 uses ``Identity(1)``.
        ``None`` for a full-frame observation, i.e. when neither axis is offset.
    """
    subarray = input_model.meta.subarray
    xstart = subarray.xstart
    ystart = subarray.ystart

    # Starts are 1-based, so a value of 1 (or no value) means no offset.
    shift_x = xstart is not None and xstart != 1
    shift_y = ystart is not None and ystart != 1

    if not (shift_x or shift_y):
        # the case of a full frame observation
        return None

    x_model = astmodels.Shift(xstart - 1) if shift_x else astmodels.Identity(1)
    y_model = astmodels.Shift(ystart - 1) if shift_y else astmodels.Identity(1)
    return x_model & y_model


def not_implemented_mode(input_model, ref, slit_y_range=None):  # noqa: ARG001
    """
    Log a critical message for an EXP_TYPE that assign_wcs does not implement.

    The signature mirrors the per-mode WCS constructors so this function can
    stand in for them; ``ref`` and ``slit_y_range`` are accepted but unused.

    Parameters
    ----------
    input_model : JwstDataModel
        The input data model; only ``meta.exposure.type`` is read.
    ref : dict
        Mapping between reftype (keys) and reference file name (vals).
    slit_y_range : tuple
        The slit Y-range for Nirspec slits, relative to (0, 0) in the center.

    Returns
    -------
    None
        Nothing is returned (implicitly ``None``).
    """
    log.critical(f"WCS for EXP_TYPE of {input_model.meta.exposure.type} is not implemented.")


def get_object_info(catalog_name=None):
    """
    Return a list of SkyObjects from the direct image.

    The source_catalog step catalog items are read into a list
    of SkyObjects which can be referenced by catalog id. Only
    the columns needed by the WFSS code are saved.

    Parameters
    ----------
    catalog_name : str or astropy.table.QTable
        The file name of the photutils catalog (read with
        ``format="ascii.ecsv"``), or an already-loaded quantities table.

    Returns
    -------
    objects : list[jwst.lib.catalog_utils.SkyObject]
        One SkyObject per catalog row.

    Raises
    ------
    ValueError
        If ``catalog_name`` is an empty string.
    FileNotFoundError
        If the catalog file cannot be found.
    TypeError
        If ``catalog_name`` is neither a string nor a QTable.
    KeyError
        If the catalog is missing any column required by SkyObject.
    AttributeError
        If the catalog's column names cannot be inspected.
    """
    if isinstance(catalog_name, str):
        if not catalog_name:
            err_text = "Empty catalog filename"
            log.error(err_text)
            raise ValueError(err_text)
        try:
            catalog = QTable.read(catalog_name, format="ascii.ecsv")
        except FileNotFoundError as e:
            log.error(f"Could not find catalog file: {e}")
            raise FileNotFoundError(f"Could not find catalog: {e}") from None
    elif isinstance(catalog_name, QTable):
        catalog = catalog_name
    else:
        err_text = "Need to input string name of catalog or astropy.table.table.QTable instance"
        log.error(err_text)
        raise TypeError(err_text)

    # validate that the expected columns are there; SkyObject() is built only
    # to read its field names.
    required_fields = set(SkyObject()._fields)
    try:
        missing = required_fields.difference(set(catalog.colnames))
    except AttributeError as e:
        err_text = f"Problem validating object catalog columns: {e}"
        log.error(err_text)
        raise AttributeError(err_text) from None

    if missing:
        err_text = f"Missing required columns in source catalog: {missing}"
        log.error(err_text)
        raise KeyError(err_text)

    # The columns are named sky_bbox_ll, sky_bbox_ul, sky_bbox_lr,
    # and sky_bbox_ur, each of which is a SkyCoord (i.e. RA & Dec & frame) at
    # one corner of the minimal bounding box. There will also be a sky_bbox
    # property as a 4-tuple of SkyCoord, but that is not serializable
    # (hence, the four separate columns).
    return [
        SkyObject(
            label=row["label"],
            xcentroid=row["xcentroid"],
            ycentroid=row["ycentroid"],
            sky_centroid=row["sky_centroid"],
            isophotal_abmag=row["isophotal_abmag"],
            isophotal_abmag_err=row["isophotal_abmag_err"],
            sky_bbox_ll=row["sky_bbox_ll"],
            sky_bbox_lr=row["sky_bbox_lr"],
            sky_bbox_ul=row["sky_bbox_ul"],
            sky_bbox_ur=row["sky_bbox_ur"],
            is_extended=row["is_extended"],
        )
        for row in catalog
    ]


def create_grism_bbox(
    input_model,
    reference_files=None,
    mmag_extract=None,
    extract_orders=None,
    wfss_extract_half_height=None,
    wavelength_range=None,
    nbright=None,
):
    """
    Create bounding boxes for each object in the catalog.

    The sky coordinates in the catalog image are first related
    to the grism image. They need to go through the WCS object
    in order to find the "direct image" pixel location, which is
    also in a detector pixel coordinate frame. This "direct image"
    location can then be sent through the trace polynomials to find
    the spectral location on the grism image for that wavelength and order.

    Parameters
    ----------
    input_model : ImageModel
        Data model which holds the grism image
    reference_files : dict, optional
        Dictionary of reference file names.
        If ``None``, ``wavelength_range`` must be supplied to specify
        the orders and corresponding wavelength ranges to be used in extraction.
    mmag_extract : float, optional
        The faintest magnitude to extract from the catalog.
    extract_orders : list, optional
        The list of orders to extract, if specified this will
        override the orders listed in the wavelengthrange reference file.
        If ``None``, the default one in the wavelengthrange reference file is used.
    wfss_extract_half_height : int, optional
        Cross-dispersion extraction half height in pixels, WFSS mode.
        Overwrites the computed extraction height in ``GrismObject.order_bounding.``
        If ``None``, it's computed from the segmentation map,
        using the min and max wavelength for each of the orders that
        are available.
    wavelength_range : dict, optional
        Pairs of {spectral_order: (wave_min, wave_max)} for each order.
        If ``None``, the default one in the wavelengthrange reference file is used.
    nbright : int, optional
        The number of brightest objects to extract from the catalog.

    Returns
    -------
    grism_objects : list
        A list of GrismObject(s) for every source in the catalog.
        Each grism object contains information about its
        spectral extent.

    Raises
    ------
    ValueError
        If the instrument is not NIRCAM or NIRISS; if the wavelengthrange
        reference file is not for WFSS or has no ``extract_orders`` entry
        for the filter; or if no source catalog is listed in the model.
    TypeError
        If neither ``reference_files`` nor ``wavelength_range`` is given,
        or if ``mmag_extract`` is not a number.

    Notes
    -----
    The wavelengthrange reference file is used to govern
    the extent of the bounding box for each object. The name of the
    catalog has been stored in the input models meta information under
    the source_catalog key.

    It's left to the calling routine to cut the bounding boxes at the
    extent of the detector (for example, extract 2d would only extract
    the on-detector portion of the bounding box)

    Bounding box dispersion direction is dependent on the filter and
    module for NIRCAM and changes for GRISMR, but is consistent for GRISMC,
    see https://jwst-docs.stsci.edu/display/JTI/NIRCam+Wide+Field+Slitless+Spectroscopy

    NIRISS has one detector.  GRISMC disperses along rows and
    GRISMR disperses along columns.

    If ``wfss_extract_half_height`` is specified it is used to compute the extent in
    the cross-dispersion direction, which becomes ``2 * wfss_extract_half_height + 1``.
    ``wfss_extract_half_height`` can only be applied to point source objects.
    """
    instr_name = input_model.meta.instrument.name
    if instr_name == "NIRCAM":
        filter_name = input_model.meta.instrument.filter
    elif instr_name == "NIRISS":
        # NIRISS records the blocking filter in the pupil wheel.
        filter_name = input_model.meta.instrument.pupil
    else:
        raise ValueError("create_grism_bbox works with NIRCAM and NIRISS WFSS exposures only.")

    if reference_files is None:
        # Get the list of extract_orders and lmin, lmax from wavelength_range.
        if wavelength_range is None:
            message = "If reference files are not supplied, ``wavelength_range`` must be provided."
            raise TypeError(message)
    else:
        # Get the list of extract_orders and lmin, lmax from the ``wavelengthrange`` reference file.
        with WavelengthrangeModel(reference_files["wavelengthrange"]) as f:
            if "WFSS" not in f.meta.exposure.type:
                err_text = "Wavelengthrange reference file not for WFSS"
                log.error(err_text)
                raise ValueError(err_text)
            if extract_orders is None:
                # Entries are matched as (filter_name, orders); the last
                # match wins, as with the original list ``pop()``.
                matches = [x[1] for x in f.extract_orders if x[0] == filter_name]
                if not matches:
                    err_text = (
                        f"No extract_orders for filter {filter_name} "
                        "in wavelengthrange reference file"
                    )
                    log.error(err_text)
                    raise ValueError(err_text)
                extract_orders = matches[-1]

            wavelength_range = f.get_wfss_wavelength_range(filter_name, extract_orders)

    if mmag_extract is None:
        mmag_extract = 999.0  # extract all objects, regardless of magnitude
    elif not isinstance(mmag_extract, (int, float, np.integer, np.floating)):
        # Validate before logging so a bad value is never reported as a cut.
        raise TypeError(f"Expected mmag_extract to be a number, got {mmag_extract}")
    else:
        log.info(f"Extracting objects < abmag = {mmag_extract}")

    # extract the catalog objects
    if input_model.meta.source_catalog is None:
        err_text = "No source catalog listed in datamodel."
        log.error(err_text)
        raise ValueError(err_text)

    log.info(f"Getting objects from {input_model.meta.source_catalog}")

    return _create_grism_bbox(
        input_model, mmag_extract, wfss_extract_half_height, wavelength_range, nbright
    )


def _create_grism_bbox(
    input_model,
    mmag_extract=None,
    wfss_extract_half_height=None,
    wavelength_range=None,
    nbright=None,
):
    """
    Build GrismObjects with per-order bounding boxes (worker for ``create_grism_bbox``).

    Parameters
    ----------
    input_model : ImageModel
        Grism data model with an assigned WCS and ``meta.source_catalog`` set.
    mmag_extract : float, optional
        Objects with ``isophotal_abmag >= mmag_extract`` are skipped.
        If ``None``, no magnitude cut is applied.
    wfss_extract_half_height : int, optional
        If given, point sources (``is_extended`` false) get a cross-dispersion
        extent of ``center +/- wfss_extract_half_height``. Without it, the
        extent comes from the traced catalog bounding box.
    wavelength_range : dict
        Pairs of {spectral_order: (wave_min, wave_max)}. Needed whenever any
        catalog object passes the magnitude cut.
    nbright : int, optional
        If given, keep only the ``nbright`` objects with the smallest
        ``isophotal_abmag``.

    Returns
    -------
    list[GrismObject]
        Objects with at least one order whose bounding box is not entirely
        off the subarray.

    Raises
    ------
    ValueError
        If ``wfss_extract_half_height`` is applied to a point source and
        ``meta.wcsinfo.dispersion_direction`` is neither 1 nor 2.
    """
    log.debug(f"Extracting with wavelength_range {wavelength_range}")

    # this contains the pure information from the catalog with no translations
    skyobject_list = get_object_info(input_model.meta.source_catalog)

    # get the imaging transform to record the center of the object in the image
    # here, image is in the imaging reference frame, before going through the
    # dispersion coefficients
    sky_to_detector = input_model.meta.wcs.get_transform("world", "detector")
    sky_to_grism = input_model.meta.wcs.backward_transform

    # Exposure-level quantities, identical for every object and order.
    dispersion_direction = input_model.meta.wcsinfo.dispersion_direction
    # ``pts`` and ``subarr_extent`` below are in (y, x) column order.
    if dispersion_direction == 1:
        # X-axis is dispersion direction
        disp_col, xdisp_col = 1, 0
    else:
        # Y-axis is dispersion direction
        disp_col, xdisp_col = 0, 1

    grism_objects = []  # the return list of GrismObjects
    for obj in skyobject_list:
        if obj.isophotal_abmag is None:
            continue
        if mmag_extract is not None and obj.isophotal_abmag >= mmag_extract:
            continue
        # could add logic to ignore object if too far off image,

        # save the image frame center of the object
        # takes in ra, dec, wavelength, order but wave and order
        # don't get used until the detector->grism_detector transform
        xcenter, ycenter, _, _ = sky_to_detector(
            obj.sky_centroid.icrs.ra.value, obj.sky_centroid.icrs.dec.value, 1, 1
        )

        # Corners of the object's sky bounding box; the same for every order.
        ra = np.array(
            [
                obj.sky_bbox_ll.ra.value,
                obj.sky_bbox_lr.ra.value,
                obj.sky_bbox_ul.ra.value,
                obj.sky_bbox_ur.ra.value,
            ]
        )
        dec = np.array(
            [
                obj.sky_bbox_ll.dec.value,
                obj.sky_bbox_lr.dec.value,
                obj.sky_bbox_ul.dec.value,
                obj.sky_bbox_ur.dec.value,
            ]
        )

        order_bounding = {}
        waverange = {}
        partial_order = {}
        for order in wavelength_range:
            # The orders of the bounding box in the non-dispersed image
            # drive the extraction extent. The location of the min and
            # max wavelengths for each order are used to get the
            # location of the +/- sides of the bounding box in the
            # grism image
            lmin, lmax = wavelength_range[order]
            x1, y1, _, _, _ = sky_to_grism(ra, dec, [lmin] * 4, [order] * 4)
            x2, y2, _, _, _ = sky_to_grism(ra, dec, [lmax] * 4, [order] * 4)

            xstack = np.hstack([x1, x2])
            ystack = np.hstack([y1, y2])

            # Subarrays are only allowed in nircam tsgrism mode. The polynomial transforms
            # only work with the full frame coordinates.
            # The code here is called during extract_2d,
            # and is creating bounding boxes which should be in the full frame coordinates,
            # it just uses the input catalog and the magnitude
            # to limit the objects that need bounding boxes.

            # Tsgrism is always supposed to have the source object at the same pixel, and that is
            # hardcoded into the transforms.
            # At least a while ago, the 2d extraction for tsgrism mode
            # didn't call this bounding box code. So I think it's safe to leave the subarray
            # subtraction out, i.e. do not subtract x/ystart.

            xmin = np.nanmin(xstack)
            xmax = np.nanmax(xstack)
            ymin = np.nanmin(ystack)
            ymax = np.nanmax(ystack)

            if wfss_extract_half_height is not None and not obj.is_extended:
                if dispersion_direction not in (1, 2):
                    raise ValueError("Cannot determine dispersion direction.")
                # Trace the centroid at the mid-band wavelength; only the
                # cross-dispersion coordinate of the result is used.
                center_x, center_y, _, _, _ = sky_to_grism(
                    obj.sky_centroid.ra.value,
                    obj.sky_centroid.dec.value,
                    (lmin + lmax) / 2,
                    order,
                )
                if dispersion_direction == 2:
                    xmin = center_x - wfss_extract_half_height
                    xmax = center_x + wfss_extract_half_height
                else:
                    ymin = center_y - wfss_extract_half_height
                    ymax = center_y + wfss_extract_half_height

            # Convert floating-point corner values to whole pixel indexes
            xmin = gwutils._toindex(xmin)  # noqa: SLF001
            xmax = gwutils._toindex(xmax)  # noqa: SLF001
            ymin = gwutils._toindex(ymin)  # noqa: SLF001
            ymax = gwutils._toindex(ymax)  # noqa: SLF001

            # Don't add objects and orders that are entirely off the detector.
            # "partial_order" marks objects that are near enough to the detector
            # edge to have some spectrum on the detector.
            # This is useful because the catalog often is created from a resampled direct
            # image that is bigger than the detector FOV for a single grism exposure.

            # Here we check to ensure that the extraction region `pts`
            # has at least two pixels of width in the dispersion
            # direction, and one in the cross-dispersed direction when
            # placed into the subarray extent.
            pts = np.array([[ymin, xmin], [ymax, xmax]])
            subarr_extent = np.array(
                [
                    [0, 0],
                    [
                        input_model.meta.subarray.ysize - 1,
                        input_model.meta.subarray.xsize - 1,
                    ],
                ]
            )

            dispaxis_check = (pts[1, disp_col] - subarr_extent[0, disp_col] > 0) and (
                subarr_extent[1, disp_col] - pts[0, disp_col] > 0
            )
            xdispaxis_check = (pts[1, xdisp_col] - subarr_extent[0, xdisp_col] >= 0) and (
                subarr_extent[1, xdisp_col] - pts[0, xdisp_col] >= 0
            )

            if not (dispaxis_check and xdispaxis_check):
                log.info(f"Excluding off-image object: {obj.label}, order {order}")
                continue

            # Any box corner outside the subarray makes the order partial.
            inidx = np.all(np.logical_and(subarr_extent[0] <= pts, pts <= subarr_extent[1]), axis=1)
            ispartial = not np.all(inidx)
            if ispartial:
                log.info(f"Partial order on detector for obj: {obj.label} order: {order}")

            order_bounding[order] = ((ymin, ymax), (xmin, xmax))
            waverange[order] = (lmin, lmax)
            partial_order[order] = ispartial

        if len(order_bounding) > 0:
            grism_objects.append(
                GrismObject(
                    sid=obj.label,
                    order_bounding=order_bounding,
                    sky_centroid=obj.sky_centroid,
                    partial_order=partial_order,
                    waverange=waverange,
                    sky_bbox_ll=obj.sky_bbox_ll,
                    sky_bbox_lr=obj.sky_bbox_lr,
                    sky_bbox_ul=obj.sky_bbox_ul,
                    sky_bbox_ur=obj.sky_bbox_ur,
                    xcentroid=xcenter,
                    ycentroid=ycenter,
                    is_extended=obj.is_extended,
                    isophotal_abmag=obj.isophotal_abmag,
                )
            )

    # At this point we have a list of grism objects limited to
    # isophotal_abmag < mmag_extract. We now need to further restrict
    # the list to the N brightest objects, as given by nbright.
    if nbright is None:
        # Include all objects, regardless of brightness
        final_objects = grism_objects
    else:
        # Smaller magnitude means brighter: sort indices by magnitude and
        # keep the first nbright objects.
        indxs = np.argsort([obj.isophotal_abmag for obj in grism_objects])
        final_objects = [grism_objects[i] for i in indxs[:nbright]]

    log.info(f"Total of {len(final_objects)} grism objects defined")
    if len(final_objects) == 0:
        log.warning("No grism objects saved; check catalog or step params")

    return final_objects


def transform_bbox_from_shape(shape, order="C"):
    """
    Create a bounding box from the shape of the data.

    The result is suitable for attaching to a transform. Pixel centers are
    integers, so each axis spans from -0.5 to ``size - 0.5``.

    Parameters
    ----------
    shape : tuple
        The shape attribute from a `numpy.ndarray` array; only the last
        two entries are used.
    order : str
        The order of the array.  Either "C" or "F".

    Returns
    -------
    bbox : tuple
        Bounding box in y, x order if order is "C" (default).
        Bounding box in x, y order otherwise (e.g. "F").
    """
    n_rows = shape[-2]
    n_cols = shape[-1]
    y_interval = (-0.5, n_rows - 0.5)
    x_interval = (-0.5, n_cols - 0.5)

    if order == "C":
        return (y_interval, x_interval)
    return (x_interval, y_interval)


def wcs_bbox_from_shape(shape):
    """
    Create a bounding box from the shape of the data.

    The result is suitable for attaching to a WCS object, which expects
    axes in x, y order. Each axis spans from -0.5 to ``size - 0.5``.

    Parameters
    ----------
    shape : tuple
        The shape attribute from a `numpy.ndarray` array; only the last
        two entries are used.

    Returns
    -------
    bbox : tuple
        Bounding box in x, y order.
    """
    n_rows, n_cols = shape[-2], shape[-1]
    return ((-0.5, n_cols - 0.5), (-0.5, n_rows - 0.5))


def bounding_box_from_subarray(input_model, order="C"):
    """
    Create a bounding box from the subarray size.

    The box is expressed in the subarray's own 0-based pixel coordinates:
    ``((-0.5, ysize - 0.5), (-0.5, xsize - 0.5))``. ``xstart``/``ystart``
    are not used. If ``xsize`` or ``ysize`` is unset, that axis collapses
    to ``(-0.5, -0.5)``.

    Parameters
    ----------
    input_model : JwstDataModel
        The input data model; ``meta.subarray.xsize`` / ``ysize`` are read.
    order : str
        The order of the array.  Either "C" or "F".

    Returns
    -------
    bbox : tuple
        Bounding box in y, x order if order is "C" (default).
        Bounding box in x, y order otherwise (e.g. "F").
    """
    subarray = input_model.meta.subarray
    bb_xend = -0.5 if subarray.xsize is None else subarray.xsize - 0.5
    bb_yend = -0.5 if subarray.ysize is None else subarray.ysize - 0.5

    bbox = ((-0.5, bb_yend), (-0.5, bb_xend))
    return bbox if order == "C" else bbox[::-1]


def update_s_region_imaging(model):
    """
    Update the ``S_REGION`` keyword using ``WCS.footprint``.

    The region comes from stcal's ``compute_s_region_imaging`` applied to
    the model's WCS and data shape (with ``center=False``).
    ``model.meta.wcsinfo.s_region`` is modified in place. It is left
    untouched when no region can be computed (``None``).

    Parameters
    ----------
    model : DataModel
        Imaging model with an assigned WCS.
    """
    region = compute_s_region_imaging(model.meta.wcs, shape=model.data.shape, center=False)
    if region is None:
        return
    model.meta.wcsinfo.s_region = region


def update_s_region_lrs(model, reference_files):
    """
    Update ``S_REGION`` using V2,V3 of the slit corners from reference file.

    s_region for model is updated in place. If any slit-corner coordinate in
    the ``specwcs`` reference file is missing (``None``) or non-finite, a
    message is logged and the keyword is left unchanged.

    Parameters
    ----------
    model : DataModel
        Input model; its WCS must provide a ``"v2v3"`` -> ``"world"`` transform.
    reference_files : dict
        Mapping of reference file type to file name; the ``"specwcs"``
        entry is opened.
    """
    # Context manager guarantees the reference file is closed even on error.
    with MiriLRSSpecwcsModel(reference_files["specwcs"]) as refmodel:
        meta = refmodel.meta
        v2 = [meta.v2_vert1, meta.v2_vert2, meta.v2_vert3, meta.v2_vert4]
        v3 = [meta.v3_vert1, meta.v3_vert2, meta.v3_vert3, meta.v3_vert4]

    if any(elem is None or not np.isfinite(elem) for elem in v2 + v3):
        log.info("The V2,V3 coordinates of the MIRI LRS-Fixed slit contains NaN values.")
        log.info("The s_region will not be updated")
        # Previously execution continued and fed invalid vertices to the WCS.
        return

    lam = 7.0  # wavelength does not matter for s region so just assign a value in range of LRS
    ra, dec = model.meta.wcs.transform("v2v3", "world", v2, v3, lam)[:2]
    # One (ra, dec) row per slit corner.
    footprint = np.column_stack((ra, dec))

    update_s_region_keyword(model, footprint)


def compute_footprint_spectral(model):
    """
    Determine spatial footprint for spectral observations using the instrument model.

    The WCS is evaluated on every pixel of its bounding box (or of the
    full data array when the WCS has no bounding box). The RA/Dec extrema
    form a rectangular footprint.

    Parameters
    ----------
    model : DataModel
        The output of assign_wcs.

    Returns
    -------
    footprint : ndarray
        Four (ra, dec) corners, shape (4, 2), in the order
        (min_ra, min_dec), (max_ra, min_dec), (max_ra, max_dec), (min_ra, max_dec).
    spectral_region : tuple
        The (min, max) wavelength over the evaluated pixels.
    """
    wcs_obj = model.meta.wcs
    bbox = wcs_obj.bounding_box
    if bbox is None:
        # Fall back to the whole data array.
        bbox = wcs_bbox_from_shape(model.data.shape)

    xgrid, ygrid = grid_from_bounding_box(bbox)
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", "invalid value", RuntimeWarning)
        ra, dec, lam = wcs_obj(xgrid, ygrid)

    # Put every RA on one side of the 0/360 border so min/max are meaningful,
    # then bring the extrema back into [0, 360).
    wrapped = wrap_ra(ra)
    ra_lo = np.nanmin(wrapped)
    ra_hi = np.nanmax(wrapped)
    if ra_lo < 0:
        ra_lo = ra_lo + 360.0
    if ra_hi >= 360.0:
        ra_hi = ra_hi - 360.0

    dec_lo = np.nanmin(dec)
    dec_hi = np.nanmax(dec)
    corners = [
        [ra_lo, dec_lo],
        [ra_hi, dec_lo],
        [ra_hi, dec_hi],
        [ra_lo, dec_hi],
    ]
    return np.array(corners), (np.nanmin(lam), np.nanmax(lam))


def update_s_region_spectral(model):
    """
    Update the S_REGION keyword.

    Also stores the observed wavelength range in
    ``model.meta.wcsinfo.spectral_region``. Both values come from
    ``compute_footprint_spectral``.

    Parameters
    ----------
    model : DataModel
        The output of assign_wcs; modified in place.
    """
    result = compute_footprint_spectral(model)
    update_s_region_keyword(model, result[0])
    model.meta.wcsinfo.spectral_region = result[1]


def compute_footprint_nrs_slit(slit):
    """
    Compute the footprint of a NIRSpec slit using the instrument model.

    The four corners of a virtual slit are mapped from ``slit_frame`` to
    ``world``. The corners are x = +/-0.5 and y = ``slit_ymin`` / ``slit_ymax``,
    with (0, 0) at the slit center.

    Parameters
    ----------
    slit : `~jwst.datamodels.SlitModel`
        The slit model.

    Returns
    -------
    footprint : ndarray
        The spatial footprint, one (ra, dec) row per corner.
    spectral_region : tuple
        The (min, max) wavelength returned by the transform.
    """
    to_world = slit.meta.wcs.get_transform("slit_frame", "world")

    y_lo = slit.slit_ymin
    y_hi = slit.slit_ymax
    corner_x = [-0.5, -0.5, 0.5, 0.5]
    corner_y = [y_lo, y_hi, y_hi, y_lo]
    # Use a default wavelength of 2 microns as input to the transform.
    wavelengths = [2e-6] * 4

    ra, dec, lam = to_world(corner_x, corner_y, wavelengths)
    return np.column_stack((ra, dec)), (np.nanmin(lam), np.nanmax(lam))


def update_s_region_nrs_slit(slit):
    """
    Update S_REGION and the spectral region of a NIRSpec slit in place.

    Parameters
    ----------
    slit : `~jwst.datamodels.SlitModel`
        Slit with an assigned WCS. ``meta.wcsinfo.spectral_region`` is
        always set. ``meta.wcsinfo.s_region`` is set when a region string
        can be computed from the footprint.
    """
    result = compute_footprint_nrs_slit(slit)
    update_s_region_keyword(slit, result[0])
    slit.meta.wcsinfo.spectral_region = result[1]


def update_s_region_keyword(model, footprint):
    """
    Update the S_REGION keyword.

    The region string is built by stcal's ``compute_s_region_keyword``.
    ``model.meta.wcsinfo.s_region`` is left unchanged when that returns ``None``.

    Parameters
    ----------
    model : DataModel
        Model whose ``meta.wcsinfo.s_region`` is updated in place.
    footprint : ndarray
        Sky footprint passed through to ``compute_s_region_keyword``.
    """
    region = compute_s_region_keyword(footprint)
    if region is None:
        return
    model.meta.wcsinfo.s_region = region


def compute_footprint_nrs_ifu(dmodel):
    """
    Determine NIRSPEC IFU footprint using the instrument model.

    The WCS is evaluated on every pixel of each slice's bounding box
    (slices 0-29). The combined RA/Dec/wavelength extrema define the
    footprint.

    Parameters
    ----------
    dmodel : `~jwst.datamodels.IFUImageModel`
        The output of assign_wcs.

    Returns
    -------
    footprint : ndarray
        Flat array
        ``[ra_min, dec_min, ra_max, dec_min, ra_max, dec_max, ra_min, dec_max]``.
    spectral_region : tuple
        The (min, max) wavelength over all slices.
    """
    ra_parts = []
    dec_parts = []
    lam_parts = []

    # NOTE(review): 30 is presumably the number of NIRSpec IFU slices; the
    # WCS bounding box and call are both indexed by slice number.
    for slit in range(30):
        x, y = grid_from_bounding_box(dmodel.meta.wcs.bounding_box[slit])
        ra, dec, lam, _ = dmodel.meta.wcs(x, y, slit)
        # Keep whole arrays; extending Python lists element by element would
        # box every pixel value into a separate Python object.
        ra_parts.append(np.ravel(ra))
        dec_parts.append(np.ravel(dec))
        lam_parts.append(np.ravel(lam))

    ra_total = np.concatenate(ra_parts)
    dec_total = np.concatenate(dec_parts)
    lam_total = np.concatenate(lam_parts)

    # the wrapped ra values are forced to be on one side of ra-border
    # the wrapped ra are used to determine the correct  min and max ra
    ra_total = wrap_ra(ra_total)
    ra_max = np.nanmax(ra_total)
    ra_min = np.nanmin(ra_total)
    # for the footprint we want ra to be between 0 to 360
    if ra_min < 0:
        ra_min = ra_min + 360.0
    if ra_max >= 360.0:
        ra_max = ra_max - 360.0

    dec_max = np.nanmax(dec_total)
    dec_min = np.nanmin(dec_total)
    lam_max = np.nanmax(lam_total)
    lam_min = np.nanmin(lam_total)
    footprint = np.array([ra_min, dec_min, ra_max, dec_min, ra_max, dec_max, ra_min, dec_max])
    return footprint, (lam_min, lam_max)


def update_s_region_nrs_ifu(output_model):
    """
    Update S_REGION for NRS_IFU observations using calculated footprint.

    The footprint and wavelength range come from
    ``compute_footprint_nrs_ifu``. Both ``meta.wcsinfo.s_region`` (via
    ``update_s_region_keyword``) and ``meta.wcsinfo.spectral_region`` are
    written in place.

    Parameters
    ----------
    output_model : `~jwst.datamodels.IFUImageModel`
        The output of assign_wcs.
    """
    footprint, wave_range = compute_footprint_nrs_ifu(output_model)
    update_s_region_keyword(output_model, footprint)
    output_model.meta.wcsinfo.spectral_region = wave_range


def update_s_region_mrs(output_model):
    """
    Update S_REGION for MIRI_MRS observations using the WCS transforms.

    The footprint and wavelength range come from
    ``compute_footprint_spectral``. Both ``meta.wcsinfo.s_region`` (via
    ``update_s_region_keyword``) and ``meta.wcsinfo.spectral_region`` are
    written in place.

    Parameters
    ----------
    output_model : `~jwst.datamodels.IFUImageModel`
        The output of assign_wcs.
    """
    footprint, wave_range = compute_footprint_spectral(output_model)
    update_s_region_keyword(output_model, footprint)
    output_model.meta.wcsinfo.spectral_region = wave_range


def velocity_correction(velosys):
    """
    Compute wavelength correction to Barycentric reference frame.

    The returned model multiplies its single input by
    ``1 / (1 + velosys / c)``, where ``c`` is the speed of light in m/s.
    Its inverse divides by the same factor.

    Parameters
    ----------
    velosys : float
        Radial velocity wrt Barycenter [m / s].

    Returns
    -------
    model : `astropy.modeling.Model`
        The velocity correction model, with its inverse attached.
    """
    factor = 1 / (1 + velosys / c.value)
    forward = astmodels.Identity(1) * astmodels.Const1D(factor, name="velocity_correction")
    backward = astmodels.Identity(1) / astmodels.Const1D(factor, name="inv_vel_correction")
    forward.inverse = backward
    return forward


def wrap_ra(ravalues):
    """
    Test for 0/360 wrapping in RA values.

    Wrapping across the 0/360 border makes it difficult to determine the RA
    range of a region on the sky. This is solved by putting all values on
    "one side" of the border:

    - The median of the finite values is computed.
    - Values differing from the median by more than 180 degrees are shifted
      by 360 degrees: down when the median is below 180, up when it is
      above 180.

    Non-finite values (NaN, inf) are dropped, so the output may be shorter
    than the input. The input itself is never modified.

    Parameters
    ----------
    ravalues : numpy.ndarray or list
        The input RA values, in degrees.

    Returns
    -------
    numpy.ndarray
        The finite input RA values, all on the "same side" of the 0/360
        border. An ndarray is returned even when the input is a list.
    """
    ra_array = np.asarray(ravalues)
    # Boolean-mask indexing always returns a copy, so the caller's data
    # is never altered by the in-place shifts below.
    finite_ra = ra_array[np.isfinite(ra_array)]
    if finite_ra.size == 0:
        # Nothing to wrap; returning early also avoids numpy's
        # empty-slice RuntimeWarning from the median.
        return finite_ra

    # The median identifies which "side" the bulk of the values lie on.
    median_ra = np.median(finite_ra)
    far = np.fabs(finite_ra - median_ra) > 180.0

    if far.any():
        if median_ra < 180:
            finite_ra[far] = finite_ra[far] - 360.0
        elif median_ra > 180:
            finite_ra[far] = finite_ra[far] + 360.0

    return finite_ra


def in_ifu_slice(slice_wcs, ra, dec, lam):
    """
    Determine which sky positions fall on a given IFU slice.

    The procedure is:

    - The world coordinates are transformed back to the "slicer" frame.
    - The resulting slicer X coordinate is compared with the slicer X
      coordinate of the slice center. The center is taken as the slit-frame
      point ``(0, 0)``, evaluated at wavelength ``2e-6``.

    Parameters
    ----------
    slice_wcs : `~gwcs.wcs.WCS`
        Slice WCS object. It must provide "slicer" -> "world" (with an inverse)
        and "slit_frame" -> "slicer" transforms.
    ra, dec, lam : float, ndarray
        Physical Coordinates.

    Returns
    -------
    onslice_ind : bool or ndarray of bool
        True where the slicer X coordinate of ``(ra, dec, lam)`` is close to
        the slice center, as judged by ``np.isclose`` with ``atol=5e-4``
        (and numpy's default ``rtol``).
    """
    slicer2world = slice_wcs.get_transform("slicer", "world")
    # Only the slicer X coordinate is needed; Y and wavelength are discarded.
    slx, _, _ = slicer2world.inverse(ra, dec, lam)

    # Slicer X coordinate of the slit center: slit-frame origin evaluated at
    # lam=2e-6 (presumably meters, i.e. 2 microns -- TODO confirm units).
    slx_center, _, _ = slice_wcs.get_transform("slit_frame", "slicer")(0, 0, 2e-6)

    return np.isclose(slx, slx_center, atol=5e-4)


def update_fits_wcsinfo(
    datamodel,
    max_pix_error=0.01,
    degree=None,
    max_inv_pix_error=0.01,
    inv_degree=None,
    npoints=12,
    crpix=None,
    projection="TAN",
    imwcs=None,
    **kwargs,
):
    """
    Update ``datamodel.meta.wcsinfo`` based on a FITS WCS + SIP approximation of a GWCS object.

    By default, this function will approximate the datamodel's GWCS object
    stored in ``datamodel.meta.wcs`` but it can also approximate a user-supplied
    GWCS object when provided via the ``imwcs`` parameter.

    The default mode in using this attempts to achieve roughly 0.01 pixel
    accuracy over the entire image.

    This function uses the :py:meth:`~gwcs.wcs.WCS.to_fits_sip` to create FITS
    WCS representations of GWCS objects. Only the most important
    :py:meth:`~gwcs.wcs.WCS.to_fits_sip` parameters are exposed here. Other
    arguments to :py:meth:`~gwcs.wcs.WCS.to_fits_sip` can be passed via
    ``kwargs`` - see the description of ``**kwargs`` below. Please refer to the
    documentation of :py:meth:`~gwcs.wcs.WCS.to_fits_sip` for more details.

    .. warning::
        This function modifies input data model's ``datamodel.meta.wcsinfo``
        members. Existing ``naxis``, ``cdelt1/2``, ``pc1_1 ... pc2_2`` and SIP
        keywords (``a_order``, ``b_order``, ``ap_order``, ``bp_order`` and
        ``a_i_j``, ``b_i_j``, ``ap_i_j``, ``bp_i_j`` for ``i, j < 10``) are
        removed. The keywords of the newly computed FITS WCS are then
        written (with lowercase names).

    Parameters
    ----------
    datamodel : `~jwst.datamodels.ImageModel`
        The input data model for imaging or WFSS mode whose ``meta.wcsinfo``
        field should be updated from GWCS. By default, ``datamodel.meta.wcs``
        is used to compute FITS WCS + SIP approximation. When ``imwcs`` is
        not `None` then computed FITS WCS will be an approximation of the WCS
        provided through the ``imwcs`` parameter.

    max_pix_error : float, optional
        Maximum allowed error over the domain of the pixel array. This
        error is the equivalent pixel error that corresponds to the maximum
        error in the output coordinate resulting from the fit based on
        a nominal plate scale.

    degree : int, iterable, None, optional
        Degree of the SIP polynomial. Default value `None` indicates that
        degrees ``1`` through ``_MAX_SIP_DEGREE - 1`` (i.e., ``[1...5]``) will
        be considered and the lowest degree that meets accuracy requirements
        set by ``max_pix_error`` will be used.

        Alternatively, ``degree`` can be an iterable containing allowed values
        for the SIP polynomial degree. This option is similar to default `None`
        but it allows caller to restrict the range of allowed SIP degrees used
        for fitting. Finally, ``degree`` can be an integer indicating the exact
        SIP degree to be fit to the WCS transformation. In this case
        ``max_pix_error`` is ignored.

    max_inv_pix_error : float, None, optional
        Maximum allowed inverse error over the domain of the pixel array
        in pixel units. When set to `None`, no inverse is generated.

    inv_degree : int, iterable, None, optional
        Degree of the inverse SIP polynomial. Default value `None` indicates
        that degrees ``1`` through ``_MAX_SIP_DEGREE - 1`` (i.e., ``[1...5]``)
        will be considered and the lowest degree that meets accuracy
        requirements set by ``max_inv_pix_error`` will be used.

        Alternatively, ``inv_degree`` can be an iterable containing allowed
        values for the inverse SIP polynomial degree. This option is similar
        to default `None` but it allows caller to restrict the range of
        allowed SIP degrees used for fitting. Finally, ``inv_degree`` can be an
        integer indicating the exact inverse SIP degree to be fit. In this
        case ``max_inv_pix_error`` is ignored.

    npoints : int, optional
        The number of points in each dimension to sample the bounding box
        for use in the SIP fit. Minimum number of points is 3.

    crpix : list of float, None, optional
        Coordinates (1-based) of the reference point for the new FITS WCS.
        When not provided, i.e., when set to `None` (default) the reference
        pixel already specified in ``wcsinfo`` will be reused. If ``wcsinfo``
        does not contain ``crpix`` information, then the reference pixel will
        be chosen near the center of the bounding box for axes corresponding
        to the celestial frame.

    projection : str, `~astropy.modeling.projections.Pix2SkyProjection`, optional
        Projection to be used for the created FITS WCS. It can be specified
        as a string of three characters specifying a FITS projection code
        from Table 13 in
        `Representations of World Coordinates in FITS <https://doi.org/10.1051/0004-6361:20021326>`_
        (Paper I), Greisen, E. W., and Calabretta, M. R., A & A, 395,
        1061-1075, 2002. Alternatively, it can be an instance of one of the
        `astropy's Pix2Sky_* <https://docs.astropy.org/en/stable/modeling/reference_api.html#module-astropy.modeling.projections>`_
        projection models inherited from
        :py:class:`~astropy.modeling.projections.Pix2SkyProjection`.

    imwcs : `gwcs.wcs.WCS`, None, optional
        Imaging GWCS object for WFSS mode whose FITS WCS approximation should
        be computed and stored in the ``datamodel.meta.wcsinfo`` field.
        When ``imwcs`` is `None` then WCS from ``datamodel.meta.wcs``
        will be used.

        .. warning::

            Used with WFSS modes only. For other modes, supplying a different
            WCS from ``datamodel.meta.wcs`` will result in the GWCS and
            FITS WCS descriptions to diverge.

    **kwargs : dict, optional
        Additional parameters to be passed to
        :py:meth:`~gwcs.wcs.WCS.to_fits_sip`. These may include:

        * bounding_box : tuple, None, optional
          A pair of tuples, each consisting of two numbers. Represents the
          range of pixel values in both dimensions
          ((xmin, xmax), (ymin, ymax)).

        * verbose : bool, optional
          Print progress of fits.

    Returns
    -------
    `~astropy.io.fits.Header`
        FITS header with all SIP WCS keywords (``NAXIS*`` keywords removed).

    Raises
    ------
    ValueError
        If the WCS is not at least 2D, an exception will be raised. If the
        specified accuracy (both forward and inverse, both rms and maximum)
        is not achieved an exception will be raised.

    Notes
    -----
    Use of this requires a judicious choice of required accuracies.
    Attempts to use higher degrees (~7 or higher) will typically fail due
    to floating point problems that arise with high powers.

    For more details, see :py:meth:`~gwcs.wcs.WCS.to_fits_sip`.
    """
    if crpix is None:
        crpix = [datamodel.meta.wcsinfo.crpix1, datamodel.meta.wcsinfo.crpix2]
    # An incomplete reference pixel is treated as "not provided", letting
    # to_fits_sip choose one itself.
    if None in crpix:
        crpix = None

    # For WFSS modes the imaging WCS is passed as an argument.
    # For imaging modes it is retrieved from the datamodel.
    if imwcs is None:
        imwcs = datamodel.meta.wcs

    # limit default 'degree' ranges to 1 .. _MAX_SIP_DEGREE - 1:
    if degree is None:
        degree = range(1, _MAX_SIP_DEGREE)
    if inv_degree is None:
        inv_degree = range(1, _MAX_SIP_DEGREE)

    hdr = imwcs.to_fits_sip(
        max_pix_error=max_pix_error,
        degree=degree,
        max_inv_pix_error=max_inv_pix_error,
        inv_degree=inv_degree,
        npoints=npoints,
        crpix=crpix,
        projection=projection,
        **kwargs,
    )

    # update meta.wcsinfo with FITS keywords except for naxis*
    del hdr["naxis*"]

    # maintain convention of lowercase keys
    hdr_dict = {k.lower(): v for k, v in hdr.items()}

    # delete naxis, cdelt, pc and any previous SIP keywords from wcsinfo so
    # that stale values do not survive alongside the new solution
    rm_keys = [
        "naxis",
        "cdelt1",
        "cdelt2",
        "pc1_1",
        "pc1_2",
        "pc2_1",
        "pc2_2",
        "a_order",
        "b_order",
        "ap_order",
        "bp_order",
    ]

    rm_keys.extend(
        f"{s}_{i}_{j}" for i in range(10) for j in range(10) for s in ["a", "b", "ap", "bp"]
    )

    for key in rm_keys:
        if key in datamodel.meta.wcsinfo.instance:
            del datamodel.meta.wcsinfo.instance[key]

    # update meta.wcs_info with fit keywords
    datamodel.meta.wcsinfo.instance.update(hdr_dict)

    return hdr
def wfss_imaging_wcs(wfss_model, imaging, bbox=None, **kwargs):
    """
    Add a FITS WCS approximation for imaging mode to WFSS headers.

    The procedure is:

    - An imaging GWCS is built from ``imaging(wfss_model, reference_files)``.
    - A bounding box is assigned to it.
    - It is approximated with a FITS WCS + SIP via ``update_fits_wcsinfo``,
      which writes the result into ``wfss_model.meta.wcsinfo``.

    Parameters
    ----------
    wfss_model : `~ImageModel`
        Input WFSS model (NRC or NIS).
    imaging : func, callable
        The ``imaging`` function in the ``niriss`` or ``nircam`` modules.
    bbox : tuple or None
        The bounding box over which to approximate the distortion solution.
        Typically this is based on the shape of the direct image. When `None`:

        - If both subarray start values are known and are not both 1, the
          bounding box comes from ``bounding_box_from_subarray``.
        - Otherwise, it comes from ``wcs_bbox_from_shape`` applied to
          ``wfss_model.data.shape``.
    **kwargs : dict
        Additional parameters to be passed to update_fits_wcsinfo().
    """
    subarray = wfss_model.meta.subarray
    xstart = subarray.xstart
    ystart = subarray.ystart

    imwcs = WCS(imaging(wfss_model, get_wcs_reference_files(wfss_model)))

    if bbox is None:
        start_known = xstart is not None and ystart is not None
        if start_known and (xstart, ystart) != (1, 1):
            bbox = bounding_box_from_subarray(wfss_model)
        else:
            bbox = wcs_bbox_from_shape(wfss_model.data.shape)
    imwcs.bounding_box = bbox

    # bounding_box=None makes to_fits_sip use imwcs.bounding_box set above.
    update_fits_wcsinfo(wfss_model, projection="TAN", imwcs=imwcs, bounding_box=None, **kwargs)


def get_wcs_reference_files(datamodel):
    """
    Retrieve names of WCS reference files for NIS_WFSS and NRC_WFSS modes.

    Parameters
    ----------
    datamodel : ImageModel
        Input WFSS file (NRC or NIS).

    Returns
    -------
    dict
        Mapping between reftype (keys) and reference file name (vals), for
        every entry of ``AssignWcsStep.reference_file_types``. A reftype
        whose lookup yields ``"N/A"`` (ignoring surrounding whitespace) maps
        to `None`.
    """
    # Local import; presumably avoids a circular import with the step
    # module -- TODO confirm.
    from jwst.assign_wcs import AssignWcsStep

    step = AssignWcsStep()
    found = {
        reftype: step.get_reference_file(datamodel, reftype)
        for reftype in AssignWcsStep.reference_file_types
    }
    return {reftype: (None if name.strip() == "N/A" else name) for reftype, name in found.items()}