Source code for jwst.extract_1d.psf_profile

import logging

import numpy as np
from scipy import ndimage, optimize
from stdatamodels.jwst.datamodels import SpecPsfModel

from jwst.extract_1d.extract1d import extract1d
from jwst.extract_1d.source_location import middle_from_wcs, nod_pair_location, trace_from_wcs

__all__ = ["psf_profile"]

log = logging.getLogger(__name__)

HORIZONTAL = 1
VERTICAL = 2
"""Dispersion direction, predominantly horizontal or vertical."""

NOD_PAIR_PATTERN = ["ALONG-SLIT-NOD", "2-POINT-NOD"]


def open_psf(psf_refname, exp_type):
    """
    Open the PSF reference file.

    Parameters
    ----------
    psf_refname : str
        The name of the psf reference file.
    exp_type : str
        The exposure type of the data.

    Returns
    -------
    psf_model : SpecPsfModel
        The EPSF model.
    """
    if exp_type == "MIR_LRS-FIXEDSLIT":
        # The information we read in from the PSF file is:
        # center_col: psf_model.meta.psf.center_col
        # supersample factor: psf_model.meta.psf.subpix
        # psf : psf_model.data (2d)
        # wavelength of PSF planes: psf_model.wave
        psf_model = SpecPsfModel(psf_refname)

    else:
        # So far, only MIRI LRS has a PSF datamodel defined. For any other
        # exposure type, try using the same model to open the reference file.
        try:
            psf_model = SpecPsfModel(psf_refname)
        except (ValueError, AttributeError):
            raise NotImplementedError(
                f"PSF file for EXP_TYPE {exp_type} could not be read as SpecPsfModel."
            ) from None
    return psf_model
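

# Illustrative sketch (not part of the original module): how ``open_psf`` might
# be called for MIRI LRS data, and the attributes the profile-building code
# reads from the returned model.  The reference filename passed in is assumed
# to resolve to a valid SpecPsfModel product (normally obtained through CRDS).
def _example_open_psf(psf_refname):
    psf_model = open_psf(psf_refname, "MIR_LRS-FIXEDSLIT")
    log.debug("PSF center column: %s", psf_model.meta.psf.center_col)
    log.debug("PSF supersample factor: %s", psf_model.meta.psf.subpix)
    log.debug("PSF plane wavelengths: %s", psf_model.wave)
    return psf_model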


def _normalize_profile(profile, dispaxis):
    """Normalize a spatial profile along the cross-dispersion axis."""
    if dispaxis == HORIZONTAL:
        psum = np.nansum(profile, axis=0)
        nz = psum != 0
        profile[:, nz] = profile[:, nz] / psum[nz]
        profile[:, ~nz] = 0.0
    else:
        psum = np.nansum(profile, axis=1)
        nz = psum != 0
        profile[nz, :] = profile[nz, :] / psum[nz, None]
        profile[~nz, :] = 0.0
    profile[~np.isfinite(profile)] = 0.0
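

# Illustrative sketch (not part of the original module): demonstrates the
# in-place, column-wise normalization performed by ``_normalize_profile``
# for HORIZONTAL dispersion.  Columns that sum to zero (or contain only NaN)
# are zeroed out rather than divided.
def _example_normalize_profile():
    profile = np.array([[1.0, 0.0, np.nan], [3.0, 0.0, np.nan]])
    _normalize_profile(profile, HORIZONTAL)
    # profile is now [[0.25, 0.0, 0.0], [0.75, 0.0, 0.0]]
    return profile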


def _make_cutout_profile(
    xidx, yidx, psf_subpix, psf_data, dispaxis, extra_shift=0.0, nod_offset=None
):
    """
    Make a spatial profile corresponding to the data cutout.

    Input index values should already contain the shift to the trace location
    in the cross-dispersion direction and any wavelength shifts necessary
    in the dispersion direction.

    Parameters
    ----------
    xidx : ndarray of float
        Index array for x values.
    yidx : ndarray of float
        Index array for y values.
    psf_subpix : float
        Scaling factor for pixel size in the PSF data.
    psf_data : ndarray of float
        2D PSF model.
    dispaxis : int
        Dispersion axis.
    extra_shift : float, optional
        An extra shift for the primary trace location, to be added to the
        cross-dispersion indices.
    nod_offset : float, optional
        If not None, a negative trace is added to the spatial profile,
        with a cross-dispersion shift of `nod_offset`.

    Returns
    -------
    profiles : list of ndarray of float
        2D spatial profiles containing the primary trace and, optionally,
        a negative trace for a nod pair.  The profiles are normalized along
        the cross-dispersion axis.
    """
    # Add an extra spatial shift to the primary trace
    if dispaxis == HORIZONTAL:
        xmap = xidx
        ymap = yidx + extra_shift * psf_subpix
    else:
        xmap = xidx + extra_shift * psf_subpix
        ymap = yidx
    sprofile = ndimage.map_coordinates(psf_data, [ymap, xmap], order=1)
    _normalize_profile(sprofile, dispaxis)

    if nod_offset is None:
        return [sprofile]

    # Make an additional profile for the negative nod if desired
    if dispaxis == HORIZONTAL:
        ymap += psf_subpix * nod_offset
    else:
        xmap += psf_subpix * nod_offset

    nod_profile = ndimage.map_coordinates(psf_data, [ymap, xmap], order=1)
    _normalize_profile(nod_profile, dispaxis)

    return [sprofile, nod_profile * -1]
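

# Illustrative sketch (not part of the original module): builds the subsampled
# index grids expected by ``_make_cutout_profile`` for VERTICAL dispersion (the
# MIRI LRS case), mirroring the setup performed in ``psf_profile`` below.
# ``wave_idx`` (PSF wavelength indices for each cutout pixel) and ``trace``
# (source x positions in cutout coordinates) are assumed to be precomputed.
def _example_cutout_profile(psf_model, cutout_shape, wave_idx, trace):
    _y, _x = np.mgrid[: cutout_shape[0], : cutout_shape[1]]
    psf_subpix = psf_model.meta.psf.subpix
    psf_shift = psf_model.meta.psf.center_col - (trace * psf_subpix)
    xidx = _x * psf_subpix + psf_shift[:, None]
    yidx = wave_idx
    return _make_cutout_profile(xidx, yidx, psf_subpix, psf_model.data, VERTICAL)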


def _profile_residual(
    shifts_to_optimize, cutout, cutout_var, xidx, yidx, psf_subpix, psf_data, dispaxis, fit_bkg=True
):
    """
    Residual function to minimize for optimizing trace locations.

    Call `_make_cutout_profile` to generate a profile from input parameters.
    Call `extract1d` to generate a scene model from the data and the new profile.
    Compute a residual value from the sum of (model - cutout) ** 2 / cutout_var.

    Parameters
    ----------
    shifts_to_optimize : list of float
        The first value is used as the `extra_shift` parameter to
        `_make_cutout_profile`.  If two are provided, the second value is
        used as the `nod_offset` parameter to `_make_cutout_profile`.
        If only one value is provided, no nod offset is applied.
    cutout : ndarray
        Input data array, trimmed to the bounding box.
    cutout_var : ndarray
        Input variance, matching the data array, used for weighting the
        output residuals.
    xidx : ndarray of float
        Index array for x values.
    yidx : ndarray of float
        Index array for y values.
    psf_subpix : float
        Scaling factor for pixel size in the PSF data.
    psf_data : ndarray of float
        2D PSF model.
    dispaxis : int
        Dispersion axis.
    fit_bkg : bool, optional
        If True, background subtraction is performed during extraction.

    Returns
    -------
    float
        The residual value to minimize.
    """
    if len(shifts_to_optimize) > 1:
        nod_offset = shifts_to_optimize[1]
    else:
        nod_offset = None
    sprofiles = _make_cutout_profile(
        xidx,
        yidx,
        psf_subpix,
        psf_data,
        dispaxis,
        extra_shift=shifts_to_optimize[0],
        nod_offset=nod_offset,
    )
    extract_kwargs = {
        "extraction_type": "optimal",
        "fit_bkg": fit_bkg,
        "bkg_fit_type": "poly",
        "bkg_order": 0,
    }
    if dispaxis == HORIZONTAL:
        empty_var = np.zeros_like(cutout)
        result = extract1d(cutout, sprofiles, cutout_var, empty_var, empty_var, **extract_kwargs)
        model = result[-1]
    else:
        sprofiles = [profile.T for profile in sprofiles]
        empty_var = np.zeros_like(cutout.T)
        result = extract1d(
            cutout.T, sprofiles, cutout_var.T, empty_var, empty_var, **extract_kwargs
        )
        model = result[-1].T
    return np.nansum((model - cutout) ** 2 / cutout_var)
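

# Illustrative sketch (not part of the original module): shows how
# ``_profile_residual`` can be minimized with a Nelder-Mead simplex to refine
# the trace location, as ``psf_profile`` does below.  Only the primary-trace
# shift is fit here; an initial nod offset could be appended as a second free
# parameter.
def _example_optimize_shift(cutout, cutout_var, xidx, yidx, psf_subpix, psf_data, dispaxis):
    result = optimize.minimize(
        _profile_residual,
        [0.0],
        (cutout, cutout_var, xidx, yidx, psf_subpix, psf_data, dispaxis),
        method="Nelder-Mead",
    )
    (extra_shift,) = result.x
    return extra_shift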


def psf_profile(
    input_model, trace, wl_array, psf_ref_name, optimize_shifts=True, model_nod_pair=True
):
    """
    Create a spatial profile from a PSF reference.

    Provides PSF-based profiles for point sources in slit-like data
    containing one positive trace and, optionally, one negative trace
    resulting from nod subtraction.  The location of the positive trace
    should be provided in the `trace` input parameter; the negative trace
    location will be guessed from the input metadata.

    If a negative trace is modeled, it is recommended that `optimize_shifts`
    also be set to True, to improve the initial guess for the trace location.

    Parameters
    ----------
    input_model : data model
        This can be either the input science file or one SlitModel out of
        a list of slits.
    trace : ndarray or None
        Array of source cross-dispersion position values, one for each
        dispersion element in the input model data.  If None, the source
        is assumed to be at the center of the slit.
    wl_array : ndarray
        Array of wavelength values, matching the input model data shape,
        for each pixel in the array.
    psf_ref_name : str
        PSF reference filename.
    optimize_shifts : bool, optional
        If True, the spatial location of the trace will be optimized by
        minimizing the residuals in a scene model compared to the data in
        the first integration of `input_model`.
    model_nod_pair : bool, optional
        If True, and if background subtraction has taken place, a negative
        PSF will be modeled at the mirrored spatial location of the
        positive trace.

    Returns
    -------
    profile : ndarray
        Spatial profile matching the input data.
    lower_limit : int
        Lower limit of the aperture in the cross-dispersion direction.
        For PSF profiles, this is always set to the lower edge of the
        bounding box, since the full array may have non-zero weight.
    upper_limit : int
        Upper limit of the aperture in the cross-dispersion direction.
        For PSF profiles, this is always set to the upper edge of the
        bounding box, since the full array may have non-zero weight.
    """
    # Read in reference files
    exp_type = input_model.meta.exposure.type
    psf_model = open_psf(psf_ref_name, exp_type)

    # Get the data cutout
    data_shape = input_model.data.shape[-2:]
    dispaxis = input_model.meta.wcsinfo.dispersion_direction
    wcs = input_model.meta.wcs
    bbox = wcs.bounding_box
    y0 = int(np.ceil(bbox[1][0]))
    y1 = int(np.ceil(bbox[1][1]))
    x0 = int(np.ceil(bbox[0][0]))
    x1 = int(np.ceil(bbox[0][1]))
    if input_model.data.ndim == 3:
        # use the first integration only
        cutout = input_model.data[0, y0:y1, x0:x1]
        cutout_var = input_model.var_rnoise[0, y0:y1, x0:x1]
    else:
        cutout = input_model.data[y0:y1, x0:x1]
        cutout_var = input_model.var_rnoise[y0:y1, x0:x1]
    cutout_wl = wl_array[y0:y1, x0:x1]

    # Get the nominal center of the cutout
    middle_disp, middle_xdisp, middle_wl = middle_from_wcs(wcs, bbox, dispaxis)

    # Get the effective index into the 1D PSF wavelengths from the data wavelengths
    psf_wave = psf_model.wave
    sort_idx = np.argsort(psf_wave)
    valid_wave = np.isfinite(psf_wave[sort_idx])
    wave_idx = np.interp(
        cutout_wl, psf_wave[sort_idx][valid_wave], sort_idx[valid_wave], left=np.nan, right=np.nan
    )

    if trace is None:
        # Don't try to model a negative pair if we don't have a trace to start
        if model_nod_pair:
            log.warning("Cannot model a negative nod without position information")
            model_nod_pair = False

        # Set the location to the middle of the cross-dispersion
        # all the way across the array
        location = middle_xdisp
        if dispaxis == HORIZONTAL:
            trace = trace_from_wcs(
                exp_type, data_shape, bbox, wcs, middle_disp, middle_xdisp, dispaxis
            )
        else:
            trace = trace_from_wcs(
                exp_type, data_shape, bbox, wcs, middle_xdisp, middle_disp, dispaxis
            )
    else:
        # Nominal location from the middle dispersion point
        location = trace[int(np.round(middle_disp))]

    # Trim the trace to the data cutout
    if dispaxis == HORIZONTAL:
        trace = trace[x0:x1]
    else:
        trace = trace[y0:y1]

    # Check if we need to add a negative nod pair trace
    nod_offset = None
    if model_nod_pair:
        nod_subtracted = str(input_model.meta.cal_step.bkg_subtract) == "COMPLETE"
        pattype_ok = str(input_model.meta.dither.primary_type) in NOD_PAIR_PATTERN
        if not nod_subtracted:
            log.info("Input data was not nod-subtracted. A negative trace will not be modeled.")
        elif not pattype_ok:
            log.info("Input data was not a two-point nod. A negative trace will not be modeled.")
        else:
            nod_center = nod_pair_location(input_model, middle_wl)
            if np.isnan(nod_center) or (np.abs(location - nod_center) < 2):
                log.warning("Nod center could not be estimated from the WCS.")
                log.warning("The negative nod will not be modeled.")
            else:
                if not optimize_shifts:
                    log.warning("Negative nod locations are currently approximations only.")
                    log.warning(
                        "PSF location optimization is recommended when negative nods are modeled."
                    )
                nod_offset = location - nod_center

    # Get an index grid for the data cutout
    cutout_shape = cutout.shape
    _y, _x = np.mgrid[: cutout_shape[0], : cutout_shape[1]]

    # Scale the trace location to the subsampled psf and
    # add the wavelength and spatial shifts to the coordinates to map to
    psf_subpix = psf_model.meta.psf.subpix
    psf_location = trace - bbox[0][0]
    if dispaxis == HORIZONTAL:
        psf_shift = psf_model.meta.psf.center_row - (psf_location * psf_subpix)
        xidx = wave_idx
        yidx = _y * psf_subpix + psf_shift
    else:
        psf_shift = psf_model.meta.psf.center_col - (psf_location * psf_subpix)
        xidx = _x * psf_subpix + psf_shift[:, None]
        yidx = wave_idx

    # If desired, add additional spatial shifts to the starting locations of
    # the primary trace (and negative nod pair trace if necessary)
    if optimize_shifts:
        log.info("Optimizing trace locations")
        if nod_offset is None:
            (extra_shift,) = optimize.minimize(
                _profile_residual,
                [0.0],
                (cutout, cutout_var, xidx, yidx, psf_subpix, psf_model.data, dispaxis),
                method="Nelder-Mead",
            ).x
        else:
            extra_shift, nod_offset = optimize.minimize(
                _profile_residual,
                [0.0, nod_offset],
                (cutout, cutout_var, xidx, yidx, psf_subpix, psf_model.data, dispaxis),
                method="Nelder-Mead",
            ).x
        location -= extra_shift
    else:
        extra_shift = 0.0

    log.info(f"Centering profile on spectrum at {location:.2f}, wavelength {middle_wl:.2f}")
    if nod_offset is not None:
        log.info(
            f"Also modeling a negative trace at {location - nod_offset:.2f} "
            f"(offset: {nod_offset:.2f})"
        )

    # Make a spatial profile from the shifted PSF data
    sprofiles = _make_cutout_profile(
        xidx,
        yidx,
        psf_subpix,
        psf_model.data,
        dispaxis,
        extra_shift=extra_shift,
        nod_offset=nod_offset,
    )

    # Make the output profile, matching the input data
    output_y = _y + y0
    output_x = _x + x0
    valid = (output_y >= 0) & (output_y < y1) & (output_x >= 0) & (output_x < x1)
    profiles = []
    for sprofile in sprofiles:
        profile = np.full(data_shape, 0.0)
        profile[output_y[valid], output_x[valid]] = sprofile[valid]
        profiles.append(profile)

    if dispaxis == HORIZONTAL:
        limits = (y0, y1)
    else:
        limits = (x0, x1)

    return profiles, *limits
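

# Illustrative sketch (not part of the original module): a typical call to
# ``psf_profile`` from calling code such as the extract_1d step.  The trace,
# wavelength array, and PSF reference filename are assumed to be supplied by
# the caller.
def _example_psf_profile(input_model, trace, wl_array, psf_ref_name):
    profiles, lower, upper = psf_profile(
        input_model, trace, wl_array, psf_ref_name, optimize_shifts=True, model_nod_pair=True
    )
    log.debug("Cross-dispersion aperture limits: %s to %s", lower, upper)
    return profiles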