Source code for jwst.extract_1d.source_location

import logging

import numpy as np
from gwcs.wcstools import grid_from_bounding_box
from scipy.interpolate import interp1d
from stdatamodels.jwst.transforms.models import IdealToV2V3

from jwst.assign_wcs.util import wcs_bbox_from_shape

# Public API of this module.
__all__ = ["middle_from_wcs", "location_from_wcs", "trace_from_wcs", "nod_pair_location"]

# Dispersion-direction codes: the functions below compare ``dispaxis``
# (read from ``meta.wcsinfo.dispersion_direction``) against these values.
HORIZONTAL = 1
"""Horizontal dispersion axis."""
VERTICAL = 2
"""Vertical dispersion axis."""

log = logging.getLogger(__name__)


def middle_from_wcs(wcs, bounding_box, dispaxis):
    """
    Calculate the effective middle of the spectral region.

    The WCS is evaluated along one cut across the dispersion direction, placed
    at the middle of the bounding box in the dispersion direction. The middle
    wavelength is the mean of the finite wavelengths along that cut, and the
    cross-dispersion middle is the position at which that wavelength occurs.
    Non-finite wavelengths (e.g. pixels outside the slit) are ignored.

    Parameters
    ----------
    wcs : `~gwcs.wcs.WCS`
        WCS for the input data model, containing detector to wavelength
        transforms.
    bounding_box : tuple
        A pair of tuples, each consisting of two numbers.
        Represents the range of useful pixel values in both dimensions,
        ((xmin, xmax), (ymin, ymax)).
    dispaxis : int
        Dispersion axis.

    Returns
    -------
    middle_disp : float
        Middle pixel in the dispersion axis.
    middle_xdisp : float
        Middle pixel in the cross-dispersion axis. NaN if the WCS returned no
        finite wavelengths along the cut.
    middle_wavelength : float
        Wavelength at the middle pixel. NaN if the WCS returned no finite
        wavelengths along the cut.
    """
    if dispaxis == HORIZONTAL:
        # Width (height) in the cross-dispersion direction, from the start of
        # the 2-D cutout (or of the full image) to the upper limit of the bounding box.
        xd_width = int(round(bounding_box[1][1]))  # must be an int
        # Middle of the bounding_box in the dispersion direction.
        middle_disp = (bounding_box[0][0] + bounding_box[0][1]) / 2.0
        # 1-D vector of cross-dispersion (y) pixel indices
        xdisp_pix = np.arange(xd_width, dtype=np.float64)
        x = np.full(xd_width, middle_disp)
        y = xdisp_pix
    else:
        # Cross-dispersion total width of bounding box; must be an int
        xd_width = int(round(bounding_box[0][1]))
        # Middle of the bounding_box in the dispersion direction.
        middle_disp = (bounding_box[1][0] + bounding_box[1][1]) / 2.0
        # 1-D vector of cross-dispersion (x) pixel indices
        xdisp_pix = np.arange(xd_width, dtype=np.float64)
        x = xdisp_pix
        y = np.full(xd_width, middle_disp)

    # Get all the wavelengths at the middle dispersion element
    _, _, center_wavelengths = wcs(x, y)
    center_wavelengths = np.asarray(center_wavelengths)

    finite = np.isfinite(center_wavelengths)
    if not np.any(finite):
        # Nothing to average or interpolate; avoid nanmean's empty-slice
        # warning and np.interp's error on empty sample points.
        return middle_disp, np.nan, np.nan

    # Average to get the middle wavelength
    middle_wavelength = np.nanmean(center_wavelengths)

    # Only finite samples take part in the comparison and the interpolation:
    # a single NaN would otherwise make allclose() False even when all valid
    # wavelengths are identical, sending constant data into np.interp.
    valid_wl = center_wavelengths[finite]
    valid_xd = xdisp_pix[finite]

    # Find the effective index in cross-dispersion coordinates for the
    # averaged wavelength to get the cross-dispersion center
    if np.allclose(valid_wl, middle_wavelength):
        middle_xdisp = np.mean(valid_xd)
    else:
        order = np.argsort(valid_wl)
        middle_xdisp = np.interp(middle_wavelength, valid_wl[order], valid_xd[order])

    return middle_disp, middle_xdisp, middle_wavelength


def location_from_wcs(input_model, slit, make_trace=True):
    """
    Get the cross-dispersion location of the spectrum, based on the WCS.

    None values will be returned if there was insufficient information
    available, e.g. if the wavelength attribute or wcs function is not
    defined.

    Parameters
    ----------
    input_model : DataModel
        The input science model containing metadata information.
    slit : DataModel or None
        One slit from a MultiSlitModel (or similar), or None.
        The WCS and target coordinates will be retrieved from `slit`
        unless `slit` is None. In that case, they will be retrieved
        from `input_model`.
    make_trace : bool, optional
        If True, the source position will be calculated for each
        dispersion element and returned in `trace`. If False, None is returned.

    Returns
    -------
    middle : int or None
        Pixel coordinate in the dispersion direction within the 2-D
        cutout (or the entire input image) at the middle of the WCS
        bounding box. This is the point at which to determine the
        nominal extraction location, in case it varies along the
        spectrum. The offset will then be the difference between
        `location` (below) and the nominal location.
    middle_wl : float or None
        The wavelength at pixel `middle`.
    location : float or None
        Pixel coordinate in the cross-dispersion direction within the
        spectral image that is at the planned target location.
        The spectral extraction region should be centered here.
    trace : ndarray or None
        An array of source positions, one per dispersion element, corresponding
        to the location at each point in the wavelength array. If the
        input data is resampled, the trace corresponds directly to the
        location. If the trace could not be generated, or `make_trace` is
        False, None is returned. The trace is only computed once `location`
        has been validated, so no work is spent on rejected locations.
    """
    # WCS, data shape and dispersion direction come from the slit when given.
    source = input_model if slit is None else slit
    shape = source.data.shape[-2:]
    wcs = source.meta.wcs
    dispaxis = source.meta.wcsinfo.dispersion_direction

    bb = wcs.bounding_box  # ((x0, x1), (y0, y1))
    if bb is None:
        bb = wcs_bbox_from_shape(shape)

    # Cross-dispersion limits of the bounding box; locations outside are rejected.
    if dispaxis == HORIZONTAL:
        lower = bb[1][0]
        upper = bb[1][1]
    else:
        lower = bb[0][0]
        upper = bb[0][1]

    # Get the wavelengths for the valid data in the sky transform,
    # average to get the middle wavelength
    middle, _, middle_wl = middle_from_wcs(wcs, bb, dispaxis)
    middle = int(np.round(middle))

    exp_type = input_model.meta.exposure.type
    if exp_type in ["NRS_FIXEDSLIT", "NRS_MSASPEC", "NRS_BRIGHTOBJ"]:
        log.info("Using source_xpos and source_ypos to center extraction.")
        xpos = source.source_xpos
        ypos = source.source_ypos

        slit2det = wcs.get_transform("slit_frame", "detector")
        if "gwa" in wcs.available_frames:
            # Input is not resampled, wavelengths need to be meters
            _, location = slit2det(xpos, ypos, middle_wl * 1e-6)
        else:
            _, location = slit2det(xpos, ypos, middle_wl)

        def build_trace():
            return _nirspec_trace_from_wcs(shape, bb, wcs, xpos, ypos)

    elif exp_type == "MIR_LRS-FIXEDSLIT":
        log.info("Using dithered_ra and dithered_dec to center extraction.")
        try:
            dithra = source.meta.dither.dithered_ra
            dithdec = source.meta.dither.dithered_dec
            location, _ = wcs.backward_transform(dithra, dithdec, middle_wl)
        except (AttributeError, TypeError):
            log.warning("Dithered pointing location not found in wcsinfo.")
            return None, None, None, None

        def build_trace():
            return _miri_trace_from_wcs(shape, bb, wcs, dithra, dithdec)

    else:
        log.warning(f"Source position cannot be found for EXP_TYPE {exp_type}")
        return None, None, None, None

    if np.isnan(location):
        log.warning("Source position could not be determined from WCS.")
        return None, None, None, None

    # If the target is at the edge of the image or at the edge of the
    # non-NaN area, we can't use the WCS to find the
    # location of the target spectrum.
    if location < lower or location > upper:
        log.warning(
            f"WCS implies the target is at {location:.2f}, which is outside the bounding box,"
        )
        log.warning("so we can't get spectrum location using the WCS")
        return None, None, None, None

    # Only now is the (expensive, full-grid) trace worth computing.
    trace = build_trace() if make_trace else None
    return middle, middle_wl, location, trace


def _resample_trace(disp_pos, xdisp_pos, start, npix, full_length):
    """
    Interpolate projected trace points onto integer dispersion pixels.

    Parameters
    ----------
    disp_pos : ndarray
        Detector coordinates of the projected trace points along the
        dispersion direction.
    xdisp_pos : ndarray
        Detector coordinates of the same points across the dispersion direction.
    start : int
        First detector pixel (dispersion direction) at which to evaluate the trace.
    npix : int
        Number of consecutive pixels to evaluate, starting at `start`.
    full_length : int
        Length of the output array along the dispersion direction.

    Returns
    -------
    full_trace : ndarray of float
        Cross-dispersion position for each dispersion pixel. NaN outside
        ``[start, start + npix)`` or when fewer than two finite points exist.
    """
    full_trace = np.full(full_length, np.nan)
    disp_pos = np.asarray(disp_pos)
    xdisp_pos = np.asarray(xdisp_pos)

    # Non-finite projections would poison the interpolation/extrapolation.
    good = np.isfinite(disp_pos) & np.isfinite(xdisp_pos)
    if np.count_nonzero(good) < 2:
        return full_trace

    interp_trace = interp1d(disp_pos[good], xdisp_pos[good], fill_value="extrapolate")

    # Evaluate at the absolute detector pixels the values are stored at
    # (the projected positions are absolute detector coordinates), keeping
    # only pixels that fall inside the output array.
    pixels = np.arange(start, start + npix)
    pixels = pixels[(pixels >= 0) & (pixels < full_length)]
    full_trace[pixels] = interp_trace(pixels)
    return full_trace


def _nirspec_trace_from_wcs(shape, bounding_box, wcs_ref, source_xpos, source_ypos):
    """
    Calculate NIRSpec source trace from WCS.

    The source trace is calculated by projecting the recorded source
    positions source_xpos/ypos from the NIRSpec "slit_frame" onto
    detector pixels.

    Parameters
    ----------
    shape : tuple of int
        2D shape for the full input data array, (ny, nx).
    bounding_box : tuple
        A pair of tuples, each consisting of two numbers.
        Represents the range of useful pixel values in both dimensions,
        ((xmin, xmax), (ymin, ymax)).
    wcs_ref : `~gwcs.wcs.WCS`
        WCS for the input data model, containing slit and detector
        transforms.
    source_xpos : float
        Slit position, in the x direction, for the target.
    source_ypos : float
        Slit position, in the y direction, for the target.

    Returns
    -------
    trace : ndarray of float
        Fractional pixel positions in the y (cross-dispersion direction)
        of the trace for each x (dispersion direction) pixel; NaN outside
        the bounding box.
    """
    x, y = grid_from_bounding_box(bounding_box)
    nx = int(bounding_box[0][1] - bounding_box[0][0])

    # Calculate the wavelengths in the slit frame because they are in
    # meters for cal files and um for s2d files
    d2s = wcs_ref.get_transform("detector", "slit_frame")
    _, _, slit_wavelength = d2s(x, y)

    # Make an initial array of wavelengths that will cover the wavelength range of the data
    wave_vals = np.linspace(np.nanmin(slit_wavelength), np.nanmax(slit_wavelength), nx)
    # Get arrays of the source position in the slit
    pos_x = np.full(nx, source_xpos)
    pos_y = np.full(nx, source_ypos)

    # Grab the wcs transform between the slit frame where we know the
    # source position and the detector frame
    s2d = wcs_ref.get_transform("slit_frame", "detector")

    # Calculate the expected center of the source trace
    trace_x, trace_y = s2d(pos_x, pos_y, wave_vals)

    # Interpolate onto the detector columns covered by the bounding box
    x0 = int(np.ceil(bounding_box[0][0]))
    return _resample_trace(trace_x, trace_y, x0, nx, shape[1])


def _miri_trace_from_wcs(shape, bounding_box, wcs_ref, source_ra, source_dec):
    """
    Calculate MIRI LRS fixed slit source trace from WCS.

    The source trace is calculated by projecting the recorded source
    positions dithered_ra/dec from the world frame onto detector pixels.

    Parameters
    ----------
    shape : tuple of int
        2D shape for the full input data array, (ny, nx).
    bounding_box : tuple
        A pair of tuples, each consisting of two numbers.
        Represents the range of useful pixel values in both dimensions,
        ((xmin, xmax), (ymin, ymax)).
    wcs_ref : `~gwcs.wcs.WCS`
        WCS for the input data model, containing sky and detector
        transforms, forward and backward.
    source_ra : float
        RA coordinate for the target.
    source_dec : float
        Dec coordinate for the target.

    Returns
    -------
    trace : ndarray of float
        Fractional pixel positions in the x (cross-dispersion direction)
        of the trace for each y (dispersion direction) pixel; NaN outside
        the bounding box.
    """
    x, y = grid_from_bounding_box(bounding_box)
    ny = int(bounding_box[1][1] - bounding_box[1][0])

    # Calculate the wavelengths for the full array
    _, _, slit_wavelength = wcs_ref(x, y)

    # Make an initial array of wavelengths that will cover the wavelength range of the data
    wave_vals = np.linspace(np.nanmin(slit_wavelength), np.nanmax(slit_wavelength), ny)
    # Get arrays of the source position
    pos_ra = np.full(ny, source_ra)
    pos_dec = np.full(ny, source_dec)

    # Calculate the expected center of the source trace
    trace_x, trace_y = wcs_ref.backward_transform(pos_ra, pos_dec, wave_vals)

    # Interpolate onto the detector rows covered by the bounding box
    y0 = int(np.ceil(bounding_box[1][0]))
    return _resample_trace(trace_y, trace_x, y0, ny, shape[0])


def trace_from_wcs(exp_type, shape, bounding_box, wcs_ref, source_x, source_y, dispaxis):
    """
    Calculate a source trace from WCS.

    A fixed source position is projected onto detector pixels to obtain the
    source location at every dispersion element. MIRI LRS fixed slit data
    go through the sky frame and NIRSpec data through the slit frame, giving a
    curved trace. Every other mode gets a flat trace: the cross-dispersion
    source coordinate repeated over the bounding-box span, NaN elsewhere.

    Parameters
    ----------
    exp_type : str
        Exposure type for the input data.
    shape : tuple of int
        2D shape for the full input data array, (ny, nx).
    bounding_box : tuple
        A pair of tuples, each consisting of two numbers.
        Represents the range of useful pixel values in both dimensions,
        ((xmin, xmax), (ymin, ymax)).
    wcs_ref : `~gwcs.wcs.WCS`
        WCS for the input data model, containing sky and detector
        transforms, forward and backward.
    source_x : float
        X pixel coordinate for the target.
    source_y : float
        Y pixel coordinate for the target.
    dispaxis : int
        Dispersion axis.

    Returns
    -------
    trace : ndarray of float
        Pixel positions in the cross-dispersion direction
        of the trace for each dispersion pixel.
    """
    if exp_type == "MIR_LRS-FIXEDSLIT":
        # Detector pixel -> sky, then project the sky position back per wavelength.
        ra, dec, _ = wcs_ref(source_x, source_y)
        return _miri_trace_from_wcs(shape, bounding_box, wcs_ref, ra, dec)

    if exp_type.startswith("NRS"):
        # Detector pixel -> slit frame, then project the slit position per wavelength.
        to_slit = wcs_ref.get_transform("detector", "slit_frame")
        slit_x, slit_y, _ = to_slit(source_x, source_y)
        return _nirspec_trace_from_wcs(shape, bounding_box, wcs_ref, slit_x, slit_y)

    # Flat trace: select the dispersion-axis length, its bounding-box range,
    # and the cross-dispersion coordinate to repeat along it.
    if dispaxis == HORIZONTAL:
        n_full, disp_range, xdisp_value = shape[1], bounding_box[0], source_y
    else:
        n_full, disp_range, xdisp_value = shape[0], bounding_box[1], source_x

    lo, hi = disp_range
    flat = np.full(n_full, np.nan)
    first = int(np.ceil(lo))
    flat[first : first + int(hi - lo)] = xdisp_value
    return flat


def _nod_pair_from_dither(input_model, middle_wl, dispaxis):
    """
    Estimate a nod pair location from the dither offsets.

    Expected location is at the opposite spatial offset from the input
    model: the dither offset is negated along the cross-dispersion axis
    only, converted to V2/V3 and then to sky coordinates, and projected back
    onto the detector at `middle_wl`. Requires 'v2v3' transform in the WCS,
    so is only available for unresampled data.

    Parameters
    ----------
    input_model : DataModel
        Model containing WCS and dither data.
    middle_wl : float
        Wavelength at the middle of the array.
    dispaxis : int
        Dispersion axis.

    Returns
    -------
    nod_location : float
        The expected location of the negative trace, in the
        cross-dispersion direction, at the middle wavelength. NaN if the WCS
        has no 'v2v3' frame, or if a dither offset or any of the wcsinfo
        values needed for the ideal-to-V2/V3 transform is None.
    """
    wcs = input_model.meta.wcs
    if "v2v3" not in wcs.available_frames:
        return np.nan

    wcsinfo = input_model.meta.wcsinfo
    x_offset = input_model.meta.dither.x_offset
    y_offset = input_model.meta.dither.y_offset
    v2v3_params = (wcsinfo.v3yangle, wcsinfo.v2_ref, wcsinfo.v3_ref, wcsinfo.vparity)

    # Unset metadata would otherwise fail on negation or model construction;
    # report "no estimate" instead, as nod_pair_location documents.
    if x_offset is None or y_offset is None or any(p is None for p in v2v3_params):
        return np.nan

    # Mirror the offset across the cross-dispersion axis only.
    if dispaxis == HORIZONTAL:
        y_offset = -y_offset
    else:
        x_offset = -x_offset

    idltov23 = IdealToV2V3(*v2v3_params)
    dithered_v2, dithered_v3 = idltov23(x_offset, y_offset)

    # v23toworld requires a wavelength along with v2, v3, but value does not affect return
    v23toworld = wcs.get_transform("v2v3", "world")
    dithered_ra, dithered_dec, _ = v23toworld(dithered_v2, dithered_v3, 0.0)

    x, y = wcs.backward_transform(dithered_ra, dithered_dec, middle_wl)
    return y if dispaxis == HORIZONTAL else x


def _nod_pair_from_slitpos(input_model, middle_wl):
    """
    Estimate a nod pair location from the source slit position.

    Expected location is at the opposite spatial position from the input
    model: the slit y position is negated and projected from the slit frame
    onto the detector at `middle_wl`. Requires 'slit_frame' transform in the
    WCS. Implemented only for NIRSpec, assuming horizontal dispersion axis.

    Parameters
    ----------
    input_model : DataModel
        Model containing WCS and dither data.
    middle_wl : float
        Wavelength at the middle of the array.

    Returns
    -------
    nod_location : float
        The expected location of the negative trace, in the
        cross-dispersion direction, at the middle wavelength. NaN if the
        source slit position is not set (None).
    """
    xpos = input_model.source_xpos
    ypos = input_model.source_ypos
    if xpos is None or ypos is None:
        return np.nan
    ypos = -ypos

    wcs = input_model.meta.wcs
    slit2det = wcs.get_transform("slit_frame", "detector")
    if "gwa" in wcs.available_frames:
        # Input is not resampled, wavelengths need to be meters
        _, location = slit2det(xpos, ypos, middle_wl * 1e-6)
    else:
        _, location = slit2det(xpos, ypos, middle_wl)

    return location


def nod_pair_location(input_model, middle_wl):
    """
    Estimate a nod pair location from the WCS.

    For MIRI, it will guess the location from the dither offsets.
    For NIRSpec, it will guess from the slit position.
    For anything else (including an unset exposure type), or if the
    estimate fails, it will return NaN for the location.

    Parameters
    ----------
    input_model : DataModel
        Model containing WCS and dither data.
    middle_wl : float
        Wavelength at the middle of the array.

    Returns
    -------
    nod_location : float
        The expected location of the negative trace, in the
        cross-dispersion direction, at the middle wavelength.
    """
    exp_type = input_model.meta.exposure.type

    if exp_type == "MIR_LRS-FIXEDSLIT":
        dispaxis = input_model.meta.wcsinfo.dispersion_direction
        return _nod_pair_from_dither(input_model, middle_wl, dispaxis)

    # Guard against an unset (None) exposure type, which has no startswith().
    if isinstance(exp_type, str) and exp_type.startswith("NRS"):
        return _nod_pair_from_slitpos(input_model, middle_wl)

    return np.nan