# Source code for jwst.resample.resample_spec

import logging
import math
import warnings

import numpy as np
from astropy import coordinates as coord
from astropy import units as u
from astropy.modeling.fitting import LinearLSQFitter
from astropy.modeling.models import (
    Const1D,
    Linear1D,
    Mapping,
    Pix2Sky_TAN,
    RotateNative2Celestial,
    Tabular1D,
)
from astropy.stats import sigma_clip
from astropy.utils.exceptions import AstropyUserWarning
from gwcs import WCS, wcstools
from gwcs import coordinate_frames as cf
from stdatamodels.jwst import datamodels

from jwst.assign_wcs.util import compute_scale, wcs_bbox_from_shape, wrap_ra
from jwst.datamodels import ModelLibrary
from jwst.resample import resample_utils
from jwst.resample.resample import ResampleImage

# Module-level logger, named after this module.
log = logging.getLogger(__name__)


# Public API of this module: only ResampleSpec is exported by ``import *``.
__all__ = ["ResampleSpec"]


class ResampleSpec(ResampleImage):
    """
    Resample spectral data.

    Notes
    -----
    This routine performs the following operations::

      1. Extracts parameter settings from input model, such as pixfrac,
         weight type, exposure time (if relevant), and kernel, and merges
         them with any user-provided values.
      2. Creates output WCS based on input images and defines the mapping
         function between all input arrays and the output array.
      3. Updates output data model with output arrays from drizzle, including
         a record of metadata from all input models.
    """

    def __init__(self, input_models, good_bits=0, output_wcs=None, wcs_pars=None, **kwargs):
        """
        Initialize the ResampleSpec object.

        Parameters
        ----------
        input_models : list
            List of data models, one for each input image.
        good_bits : int
            Bit values that should be considered good when creating a mask.
        output_wcs : dict, `~gwcs.wcs.WCS`, or None
            Output WCS parameters. When a dict, it must contain a ``"wcs"``
            key and may contain ``"pixel_scale"``, ``"pixel_area"`` and
            ``"pixel_scale_ratio"``.
        wcs_pars : dict or None
            Additional parameters for WCS (``output_shape``, ``pixel_scale``,
            ``pixel_scale_ratio``); only used when no output WCS is given.
        **kwargs : dict
            Additional parameters to be passed into
            ``ResampleImage.__init__()``. See the docstring of that method
            for more details.
        """
        shape = None
        pixel_scale = None
        pixel_area = None
        pixel_scale_ratio = 1.0

        if isinstance(output_wcs, dict):
            output_wcs_dict = {k: v for k, v in output_wcs.items() if k != "wcs"}
            output_wcs = output_wcs["wcs"]
            pixel_scale = output_wcs_dict.get("pixel_scale")
            pixel_area = output_wcs_dict.get("pixel_area")
            pixel_scale_ratio = output_wcs_dict.get("pixel_scale_ratio", 1.0)
            shape = output_wcs.array_shape
        else:
            output_wcs_dict = None

        if output_wcs is None and wcs_pars is not None:
            shape = wcs_pars.get("output_shape")
            pixel_scale = wcs_pars.get("pixel_scale")
            pixel_scale_ratio = wcs_pars.get("pixel_scale_ratio", 1.0)

        # Keep pixel_scale and pixel_area mutually consistent:
        # after this, both are None or both are set.
        if pixel_scale is None and pixel_area is not None:
            pixel_scale = math.sqrt(pixel_area)
        elif pixel_scale is not None and pixel_area is None:
            pixel_area = pixel_scale**2

        # Get an average input pixel scale for parameter calculations
        disp_axis = input_models[0].meta.wcsinfo.dispersion_direction
        input_pixscale0 = 3600.0 * compute_spectral_pixel_scale(
            input_models[0].meta.wcs, disp_axis=disp_axis
        )
        if np.isnan(input_pixscale0):
            log.warning("Input pixel scale could not be determined.")
            if pixel_scale is not None:
                log.warning(
                    "Output pixel scale setting is not supported without an "
                    "input pixel scale. Setting pixel_scale=None."
                )
                pixel_scale = None
                pixel_area = None

        nominal_area = input_models[0].meta.photometry.pixelarea_steradians
        if nominal_area is None:
            log.warning("Nominal pixel area not set in input data.")
            # Only warn about discarding a pixel scale when one was actually set.
            if pixel_scale is not None:
                log.warning(
                    "Output pixel scale setting is not supported without a "
                    "nominal input pixel area. Setting pixel_scale=None."
                )
                pixel_scale = None
                pixel_area = None

        if output_wcs:
            # Use user-supplied reference WCS for the resampled image:
            if pixel_area is None:
                if nominal_area is None:
                    log.warning("Unable to compute output pixel area from 'output_wcs'.")
                    output_pix_area = None
                else:
                    # Compare input and output spatial scale to update nominal area
                    output_pscale = 3600.0 * compute_spectral_pixel_scale(
                        output_wcs, disp_axis=disp_axis
                    )
                    if np.isnan(output_pscale) or np.isnan(input_pixscale0):
                        log.warning("Output pixel scale could not be determined.")
                        output_pix_area = None
                    else:
                        log.debug(
                            f"Setting output pixel area from the approximate "
                            f"output spatial scale: {output_pscale}"
                        )
                        # Only one spatial dimension, so the area scales
                        # linearly with the spatial pixel scale.
                        output_pix_area = output_pscale * nominal_area / input_pixscale0
            else:
                log.debug(f"Using output pixel area: {pixel_area}")
                output_pix_area = pixel_area

            # Set the pixel scale ratio for scaling reasons.
            # NOTE(review): pixel_area may have been derived above as
            # pixel_scale**2 (arcsec**2) while nominal_area is in steradians;
            # confirm the units agree before relying on this ratio.
            if output_pix_area is None:
                pixel_scale_ratio = 1.0
            else:
                pixel_scale_ratio = nominal_area / output_pix_area

            # Set the output shape if specified
            if shape is not None:
                output_wcs.array_shape = shape
        else:
            if pixel_scale is not None and nominal_area is not None:
                log.info(f"Specified output pixel scale: {pixel_scale} arcsec.")

                # Set the pscale ratio from the input pixel scale
                # (pixel scale ratio is input / output)
                if pixel_scale_ratio != 1.0:
                    log.warning(
                        "Ignoring input pixel_scale_ratio in favor of explicit pixel_scale."
                    )
                pixel_scale_ratio = input_pixscale0 / pixel_scale
                log.info(f"Computed output pixel scale ratio: {pixel_scale_ratio:.5g}")

            # Define output WCS based on all inputs, including a reference WCS.
            # These functions internally use pixel_scale_ratio to accommodate
            # user settings.
            # Any other customizations (crpix, crval, rotation) are ignored.
            if resample_utils.is_sky_like(input_models[0].meta.wcs.output_frame):
                if input_models[0].meta.instrument.name != "NIRSPEC":
                    output_wcs = self.build_interpolated_output_wcs(
                        input_models, pixel_scale_ratio=pixel_scale_ratio
                    )
                else:
                    output_wcs = self.build_nirspec_output_wcs(
                        input_models, good_bits=good_bits, pixel_scale_ratio=pixel_scale_ratio
                    )
            else:
                output_wcs = self.build_nirspec_lamp_output_wcs(
                    input_models, pixel_scale_ratio=pixel_scale_ratio
                )

            # Use the nominal output pixel area in sr if available,
            # scaling for user-set pixel_scale ratio if needed.
            if nominal_area is not None:
                # Note that there is only one spatial dimension so the
                # pixel_scale_ratio is not squared.
                output_pix_area = nominal_area / pixel_scale_ratio
            else:
                output_pix_area = None

        # Stored for update_output_model
        self._spec_output_pix_area = output_pix_area

        if pixel_scale is None:
            log.info(f"Specified output pixel scale ratio: {pixel_scale_ratio}.")
            pixel_scale = 3600.0 * compute_spectral_pixel_scale(output_wcs, disp_axis=disp_axis)
            log.info(f"Computed output pixel scale: {pixel_scale:.5g} arcsec.")

        if output_wcs_dict is None:
            output_wcs_dict = {}
        output_wcs_dict["wcs"] = output_wcs
        output_wcs_dict["pixel_scale"] = pixel_scale
        output_wcs_dict["pixel_scale_ratio"] = pixel_scale_ratio

        library = ModelLibrary(input_models, on_disk=False)
        super().__init__(
            library, good_bits=good_bits, output_wcs=output_wcs_dict, wcs_pars=None, **kwargs
        )
        self.intermediate_suffix = "outlier_s2d"

    def create_output_jwst_model(self, ref_input_model=None):
        """
        Create a new blank model and update its meta with info from ``ref_input_model``.

        Parameters
        ----------
        ref_input_model : `~jwst.datamodels.JwstDataModel`, optional
            The reference input model from which to copy meta data.

        Returns
        -------
        SlitModel
            A new blank model with updated meta data and the output WCS.
        """
        output_model = datamodels.SlitModel(None)

        # update meta data and wcs
        if ref_input_model is not None:
            output_model.update(ref_input_model)
        output_model.meta.wcs = self.output_wcs
        return output_model

    def update_output_model(self, model, info_dict):
        """
        Add spectroscopy-specific meta information to the output model.

        Parameters
        ----------
        model : SlitModel
            The output model to be updated.
        info_dict : dict
            A dictionary containing information about the resampling process.
        """
        super().update_output_model(model, info_dict)

        if self._spec_output_pix_area is None:
            model.meta.photometry.pixelarea_steradians = None
            model.meta.photometry.pixelarea_arcsecsq = None
        else:
            model.meta.photometry.pixelarea_steradians = self._spec_output_pix_area
            # np.rad2deg(3600) is the number of arcsec per radian.
            model.meta.photometry.pixelarea_arcsecsq = (
                self._spec_output_pix_area * np.rad2deg(3600) ** 2
            )

        # TODO: this is helpful info that should be stored in products.
        #  Not storing this at this time in order to reduce the number of
        #  failures in the regression tests.
        # model.meta.resample.pixel_scale_ratio
        # model.meta.resample.pixfrac
        # model.meta.resample.weight_type
        # model.meta.resample.pointings
        # model.meta.cal_step.resample

    def build_nirspec_output_wcs(
        self, input_models, refmodel=None, good_bits=None, pixel_scale_ratio=1.0
    ):
        """
        Create a spatial/spectral WCS covering the footprint of the input.

        Creates the output frame by linearly fitting RA, Dec along the slit
        and producing a lookup table to interpolate wavelengths in the
        dispersion direction.

        For NIRSpec, the output WCS must also provide slit coordinates
        to support source location in the spectral extraction step.
        To do so, this step creates a lookup table for virtual slit
        coordinates, corresponding to the slit y-position at the center
        of the array in the input reference model. Slit x-position is
        set to zero for all pixels.

        Frames available in the output WCS are:

            - `detector`: image x, y
            - `slit_frame`: slit x, slit y, wavelength
            - `world`: RA, Dec, wavelength

        Parameters
        ----------
        input_models : list
            List of data models, one for each input image.
        refmodel : `~jwst.datamodels.JwstDataModel`, optional
            The reference input image from which the fiducial WCS is created.
            If not specified, the first model with more than 100 good pixels
            (finite, non-zero, and passing the DQ mask) is used; if none
            qualifies, the first model in ``input_models``.
        good_bits : int, optional
            Bit values considered good when building the DQ mask.
        pixel_scale_ratio : float, optional
            The ratio of the input pixel scale to the output pixel scale.

        Returns
        -------
        output_wcs : `~gwcs.wcs.WCS`
            A GWCS object defining the output frame WCS.

        Raises
        ------
        ValueError
            If no output wavelengths could be determined.
        """
        all_wcs = [m.meta.wcs for m in input_models if m is not refmodel]
        if refmodel:
            all_wcs.insert(0, refmodel.meta.wcs)
        else:
            # Use the first model with a reasonable amount of good data
            # as the reference model
            for model in input_models:
                dq_mask = resample_utils.build_mask(model.dq, good_bits)
                good = np.isfinite(model.data) & (model.data != 0) & dq_mask
                if np.sum(good) > 100:
                    refmodel = model
                    break

        # If no good data was found, use the first model.
        if refmodel is None:
            refmodel = input_models[0]

        # Make a copy of the data array for internal manipulation
        refmodel_data = refmodel.data.copy()

        # Renormalize to the minimum value, for best results when
        # computing the weighted mean below
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=RuntimeWarning, message="All-NaN")
            refmodel_data -= np.nanmin(refmodel_data)

        # Save the wcs of the reference model
        refwcs = refmodel.meta.wcs

        # Set up the transforms that are needed
        s2d = refwcs.get_transform("slit_frame", "detector")
        d2s = refwcs.get_transform("detector", "slit_frame")
        if "moving_target" in refwcs.available_frames:
            s2w = refwcs.get_transform("slit_frame", "moving_target")
            w2s = refwcs.get_transform("moving_target", "slit_frame")
        else:
            s2w = refwcs.get_transform("slit_frame", "world")
            w2s = refwcs.get_transform("world", "slit_frame")

        # Estimate position of the target without relying on the meta.target:
        # compute the mean spatial and wavelength coords weighted
        # by the spectral intensity
        bbox = refwcs.bounding_box
        grid = wcstools.grid_from_bounding_box(bbox)
        _, s, lam = np.array(d2s(*grid))

        # Find invalid values
        good = np.isfinite(s) & np.isfinite(lam) & np.isfinite(refmodel_data)
        refmodel_data[~good] = np.nan

        # Reject the worst outliers in the data
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore", category=AstropyUserWarning, message=".*automatically clipped.*"
            )
            weights = sigma_clip(refmodel_data, masked=True, sigma=100.0)
        weights = np.ma.filled(weights, fill_value=0.0)
        if not np.all(weights == 0.0):
            wmean_s = np.average(s[good], weights=weights[good])
        else:
            wmean_s = np.nanmean(s)
        wmean_l = np.nanmean(lam)

        # Transform the weighted means into target RA/Dec
        # (at the center of the slit in x)
        targ_ra, targ_dec, _ = s2w(0, wmean_s, wmean_l)

        sx, sy = s2d(0, wmean_s, wmean_l)
        log.debug(f"Fiducial RA, Dec, wavelength: {targ_ra}, {targ_dec}, {wmean_l}")
        log.debug(f"Index at fiducial center: x={sx}, y={sy}")

        # Estimate spatial sampling from the reference model
        # at the center of the array
        lam_center_idx = int(np.mean(bbox, axis=1)[0])
        log.debug(f"Center of dispersion axis: {lam_center_idx}")
        grid_center = grid[0][:, lam_center_idx], grid[1][:, lam_center_idx]
        ra_ref, dec_ref, _ = np.array(refwcs(*grid_center))

        # Convert ra and dec to tangent projection
        tan = Pix2Sky_TAN()
        native2celestial = RotateNative2Celestial(targ_ra, targ_dec, 180)
        undist2sky = tan | native2celestial
        x_tan, y_tan = undist2sky.inverse(ra_ref, dec_ref)

        is_nan = np.isnan(x_tan) | np.isnan(y_tan)
        x_tan = x_tan[~is_nan]
        y_tan = y_tan[~is_nan]

        # Estimate the spatial sampling from the tangent projection
        # offset from center
        fitter = LinearLSQFitter()
        fit_model = Linear1D()

        xstop = x_tan.shape[0] * pixel_scale_ratio
        x_idx = np.linspace(0, xstop, x_tan.shape[0], endpoint=False)
        ystop = y_tan.shape[0] * pixel_scale_ratio
        y_idx = np.linspace(0, ystop, y_tan.shape[0], endpoint=False)
        pix_to_xtan = fitter(fit_model, x_idx, x_tan)
        pix_to_ytan = fitter(fit_model, y_idx, y_tan)

        # Check whether sampling is more along RA or along Dec
        swap_xy = abs(pix_to_xtan.slope) < abs(pix_to_ytan.slope)
        log.debug(f"Swap xy: {swap_xy}")

        # Get output wavelengths from all data
        ref_lam = _find_nirspec_output_sampling_wavelengths(all_wcs)
        n_lam = len(ref_lam)
        if not n_lam:
            raise ValueError("Not enough data to construct output WCS.")

        # Find the spatial extent in x/y tangent
        min_tan_x, max_tan_x, min_tan_y, max_tan_y = self._max_spatial_extent(
            all_wcs, undist2sky.inverse
        )
        diff_y = np.abs(max_tan_y - min_tan_y)
        diff_x = np.abs(max_tan_x - min_tan_x)

        pix_to_tan_slope_y = np.abs(pix_to_ytan.slope)
        slope_sign_y = np.sign(pix_to_ytan.slope)
        pix_to_tan_slope_x = np.abs(pix_to_xtan.slope)
        slope_sign_x = np.sign(pix_to_xtan.slope)

        # Image size in spatial dimension from the maximum slope
        # and tangent offset span, plus one pixel to make sure
        # we catch all the data
        if swap_xy:
            ny = int(np.ceil(diff_y / pix_to_tan_slope_y)) + 1
        else:
            ny = int(np.ceil(diff_x / pix_to_tan_slope_x)) + 1

        # Correct the intercept for the new minimum value.
        # Also account for integer pixel size to make sure the
        # data is centered in the array.
        offset_y = ny / 2 * pix_to_tan_slope_y - diff_y / 2
        offset_x = ny / 2 * pix_to_tan_slope_x - diff_x / 2

        zero_value_y = min_tan_y if slope_sign_y > 0 else max_tan_y
        zero_value_x = min_tan_x if slope_sign_x > 0 else max_tan_x

        pix_to_ytan.intercept = zero_value_y - slope_sign_y * offset_y
        pix_to_xtan.intercept = zero_value_x - slope_sign_x * offset_x

        # Now set up the final transforms

        # For wavelengths, extrapolate 1/2 pixel at the edges and
        # make tabular model w/inverse
        pixel_coord = list(range(n_lam))
        if len(pixel_coord) > 1:
            # left: half a pixel before the first sample
            slope = (ref_lam[1] - ref_lam[0]) / pixel_coord[1]
            ref_lam.insert(0, -0.5 * slope + ref_lam[0])
            pixel_coord.insert(0, -0.5)
            # right: half a pixel past the last sample (mirror of the left edge)
            slope = (ref_lam[-1] - ref_lam[-2]) / (pixel_coord[-1] - pixel_coord[-2])
            ref_lam.append(0.5 * slope + ref_lam[-1])
            pixel_coord.append(pixel_coord[-1] + 0.5)
        else:
            ref_lam = 3 * ref_lam
            pixel_coord = [-0.5, 0, 0.5]
        wavelength_transform = Tabular1D(
            points=pixel_coord, lookup_table=ref_lam, bounds_error=False, fill_value=np.nan
        )

        # For spatial coordinates, map detector pixels to tangent offset,
        # then to world coordinates (RA, Dec, wavelength in um).
        # Make sure the inverse returns the axis with the larger slope,
        # in case the smaller one is close to zero
        mapping = Mapping((1, 1, 0))
        if swap_xy:
            mapping.inverse = Mapping((2, 1))
        else:
            mapping.inverse = Mapping((2, 0))
        pix2world = mapping | (pix_to_xtan & pix_to_ytan | undist2sky) & wavelength_transform

        # For NIRSpec, slit coordinates are still needed to locate the
        # planned source position.  Since the slit is now rectified,
        # return the central slit coords for all x, converting from pixels
        # to world coordinates, then back to slit units.
        slit_center = w2s(*pix2world(np.full(ny, n_lam // 2), np.arange(ny)))[1]

        # Make a 1D lookup table for all ny.
        # Allow linear extrapolation at the edges.
        slit_transform = Tabular1D(
            points=np.arange(ny), lookup_table=slit_center, bounds_error=False, fill_value=None
        )

        # In the transform, the first slit coordinate is always set to 0
        # to represent the "horizontal" center of the slit
        # (if we imagine the slit to be vertical in the usual
        # X-Y 2D cartesian frame).
        mapping = Mapping((0, 1, 0))
        inv_mapping = Mapping((2, 1), n_inputs=3)
        inv_mapping.inverse = mapping
        mapping.inverse = inv_mapping
        zero_model = Const1D(0)
        zero_model.inverse = zero_model

        # Final detector to slit transform (x, y -> slit_x, slit_y, wavelength)
        det2slit = mapping | zero_model & slit_transform & wavelength_transform

        # The slit to world coordinates is just the inverse of the slit transform,
        # piped back into the pixel to world transform
        slit2world = det2slit.inverse | pix2world

        # Create coordinate frames: detector, slit_frame, and world
        det = cf.Frame2D(name="detector", axes_order=(0, 1))
        slit_spatial = cf.Frame2D(
            name="slit_spatial", axes_order=(0, 1), unit=("", ""), axes_names=("x_slit", "y_slit")
        )
        spec = cf.SpectralFrame(
            name="spectral", axes_order=(2,), unit=(u.micron,), axes_names=("wavelength",)
        )
        slit_frame = cf.CompositeFrame([slit_spatial, spec], name="slit_frame")
        sky = cf.CelestialFrame(name="sky", axes_order=(0, 1), reference_frame=coord.ICRS())
        world = cf.CompositeFrame([sky, spec], name="world")

        pipeline = [(det, det2slit), (slit_frame, slit2world), (world, None)]
        output_wcs = WCS(pipeline)

        # Compute bounding box and output array shape.
        data_size = (ny, n_lam)
        output_wcs.bounding_box = wcs_bbox_from_shape(data_size)
        output_wcs.array_shape = data_size

        return output_wcs

    def _max_spatial_extent(self, wcs_list, transform):
        """
        Compute spatial coordinate limits for all nods in the tangent plane.

        WCS objects with no finite RA/Dec inside their bounding box are
        skipped, since they cannot contribute to the extent.

        Parameters
        ----------
        wcs_list : list
            List of WCS objects for all nods.
        transform : callable
            Function to convert RA, Dec to tangent plane coordinates.

        Returns
        -------
        min_x, max_x, min_y, max_y : float
            Minimum and maximum tangent-plane x and y values. Limits that
            were never updated remain at +/- infinity.
        """
        limits_x = [np.inf, -np.inf]
        limits_y = [np.inf, -np.inf]
        for wcs in wcs_list:
            x, y = wcstools.grid_from_bounding_box(wcs.bounding_box)
            ra, dec, _ = wcs(x, y)

            good = np.logical_and(np.isfinite(ra), np.isfinite(dec))
            if not np.any(good):
                # np.min/np.max would raise on an empty selection.
                continue
            ra = ra[good]
            dec = dec[good]

            xtan, ytan = transform(ra, dec)
            for tan_all, limits in zip([xtan, ytan], [limits_x, limits_y], strict=True):
                min_tan = np.min(tan_all)
                max_tan = np.max(tan_all)
                if min_tan < limits[0]:
                    limits[0] = min_tan
                if max_tan > limits[1]:
                    limits[1] = max_tan

        return *limits_x, *limits_y

    def build_interpolated_output_wcs(self, input_models, pixel_scale_ratio=1.0):
        """
        Create a spatial/spectral WCS output frame using all the input models.

        Creates output frame by linearly fitting RA, Dec along the slit and
        producing a lookup table to interpolate wavelengths in the dispersion
        direction.

        Frames available in the output WCS are:

            - `detector`: image x, y
            - `world`: RA, Dec, wavelength

        Parameters
        ----------
        input_models : list
            List of data models, one for each input image.
        pixel_scale_ratio : float
            The ratio of the input pixel scale to the output pixel scale.

        Returns
        -------
        output_wcs : `~gwcs.wcs.WCS` object
            A GWCS object defining the output frame WCS.
        """
        # for each input model convert slit x,y to ra,dec,lam
        # use first input model to set spatial scale
        # use center of appended ra and dec arrays to set up
        # center of final ra,dec
        # append all ra,dec, wavelength array for each slit
        # use first model to initialize wavelength array
        # append wavelengths that fall outside the endpoint of
        # of wavelength array when looping over additional data

        all_wavelength = []
        all_ra_slit = []
        all_dec_slit = []
        all_wcs = [m.meta.wcs for m in input_models]

        for im, model in enumerate(input_models):
            wcs = model.meta.wcs
            bbox = wcs.bounding_box
            grid = wcstools.grid_from_bounding_box(bbox)
            ra, dec, lam = np.array(wcs(*grid))

            # Handle vertical (MIRI). The following 2 variables are
            # 0 or 1, i.e. zero-indexed in x,y WCS order
            spectral_axis = find_dispersion_axis(model)
            spatial_axis = spectral_axis ^ 1

            # Compute the wavelength array, trimming NaNs from the ends
            # In many cases, a whole slice is NaNs, so ignore those warnings
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=RuntimeWarning, message="All-NaN")
                wavelength_array = np.nanmedian(lam, axis=spectral_axis)
            wavelength_array = wavelength_array[~np.isnan(wavelength_array)]

            # We need to estimate the spatial sampling to use for the output WCS.
            # It is assumed the spatial sampling is the same for all the input
            # models. So we can use the first input model to set the spatial
            # sampling.

            # Steps to do this for first input model:
            # 1. Find the middle of the spectrum in wavelength
            # 2. Pull out the ra and dec at the center of the slit.
            # 3. Find the mean ra,dec and the center of the slit this will
            #    represent the tangent point
            # 4. Convert ra,dec -> tangent plane projection: x_tan,y_tan
            # 5. using x_tan, y_tan perform a linear fit to find spatial sampling
            # first input model sets initializes wavelength array and defines
            # the spatial scale of the output wcs
            if im == 0:
                all_wavelength = np.append(all_wavelength, wavelength_array)

                # find the center ra and dec for this slit at central wavelength
                lam_center_index = int((bbox[spectral_axis][1] - bbox[spectral_axis][0]) / 2)
                if spatial_axis == 0:
                    # MIRI LRS spectral = 1, the spatial axis = 0
                    ra_slice = ra[lam_center_index, :]
                    dec_slice = dec[lam_center_index, :]
                else:
                    ra_slice = ra[:, lam_center_index]
                    dec_slice = dec[:, lam_center_index]

                # wrap RA if near zero
                ra_center_pt = np.nanmean(wrap_ra(ra_slice))
                dec_center_pt = np.nanmean(dec_slice)

                # convert ra and dec to tangent projection
                tan = Pix2Sky_TAN()
                native2celestial = RotateNative2Celestial(ra_center_pt, dec_center_pt, 180)
                undist2sky1 = tan | native2celestial

                # Filter out RuntimeWarnings due to computed NaNs in the WCS
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", RuntimeWarning)
                    # at this center of slit find x,y tangent projection - x_tan, y_tan
                    x_tan, y_tan = undist2sky1.inverse(ra, dec)

                # pull out data from center
                if spectral_axis == 0:
                    x_tan_array = x_tan.T[lam_center_index]
                    y_tan_array = y_tan.T[lam_center_index]
                else:
                    # MIRI LRS Spectral Axis = 1, the WCS x axis is spatial
                    x_tan_array = x_tan[lam_center_index]
                    y_tan_array = y_tan[lam_center_index]

                x_tan_array = x_tan_array[~np.isnan(x_tan_array)]
                y_tan_array = y_tan_array[~np.isnan(y_tan_array)]

                # estimate the spatial sampling
                fitter = LinearLSQFitter()
                fit_model = Linear1D()

                xstop = x_tan_array.shape[0] * pixel_scale_ratio
                x_idx = np.linspace(0, xstop, x_tan_array.shape[0], endpoint=False)
                ystop = y_tan_array.shape[0] * pixel_scale_ratio
                y_idx = np.linspace(0, ystop, y_tan_array.shape[0], endpoint=False)
                pix_to_xtan = fitter(fit_model, x_idx, x_tan_array)
                pix_to_ytan = fitter(fit_model, y_idx, y_tan_array)

            # append all ra and dec values to use later to find min and max
            # ra and dec
            ra_use = ra[~np.isnan(ra)].flatten()
            dec_use = dec[~np.isnan(dec)].flatten()
            all_ra_slit = np.append(all_ra_slit, ra_use)
            all_dec_slit = np.append(all_dec_slit, dec_use)

            # now check wavelength array to see if we need to add to it
            this_minw = np.min(wavelength_array)
            this_maxw = np.max(wavelength_array)
            all_minw = np.min(all_wavelength)
            all_maxw = np.max(all_wavelength)

            if this_minw < all_minw:
                addpts = wavelength_array[wavelength_array < all_minw]
                all_wavelength = np.append(all_wavelength, addpts)
            if this_maxw > all_maxw:
                addpts = wavelength_array[wavelength_array > all_maxw]
                all_wavelength = np.append(all_wavelength, addpts)

        # done looping over set of models
        all_ra = np.hstack(all_ra_slit)
        all_dec = np.hstack(all_dec_slit)
        all_wave = np.hstack(all_wavelength)
        all_wave = all_wave[~np.isnan(all_wave)]
        all_wave = np.sort(all_wave, axis=None)

        # Tabular interpolation model, pixels -> lambda
        wavelength_array = np.unique(all_wave)

        # Check if the data is MIRI LRS FIXED Slit. If it is then
        # the wavelength array needs to be flipped so that the resampled
        # dispersion direction matches the dispersion direction on the detector.
        if input_models[0].meta.exposure.type == "MIR_LRS-FIXEDSLIT":
            wavelength_array = np.flip(wavelength_array, axis=None)

        step = 1
        stop = wavelength_array.shape[0]
        points = np.arange(0, stop, step)
        pix_to_wavelength = Tabular1D(
            points=points,
            lookup_table=wavelength_array,
            bounds_error=False,
            fill_value=None,
            name="pix2wavelength",
        )

        # Tabular models need an inverse explicitly defined.
        # If the wavelength array is descending instead of ascending, both
        # points and lookup_table need to be reversed in the inverse transform
        # for scipy.interpolate to work properly
        points = wavelength_array
        lookup_table = np.arange(0, stop, step)

        if not np.all(np.diff(wavelength_array) > 0):
            points = points[::-1]
            lookup_table = lookup_table[::-1]
        pix_to_wavelength.inverse = Tabular1D(
            points=points,
            lookup_table=lookup_table,
            bounds_error=False,
            fill_value=None,
            name="wavelength2pix",
        )

        # For the input mapping, duplicate the spatial coordinate
        mapping = Mapping((spatial_axis, spatial_axis, spectral_axis))

        # Sometimes the slit is perpendicular to the RA or Dec axis.
        # For example, if the slit is perpendicular to RA, that means
        # the slope of pix_to_xtan will be nearly zero, so make sure
        # mapping.inverse uses pix_to_ytan.inverse.  The auto definition
        # of mapping.inverse is to use the 2nd spatial coordinate, i.e. Dec.
        swap_xy = abs(pix_to_xtan.slope) < abs(pix_to_ytan.slope)
        if swap_xy:
            # Account for vertical or horizontal dispersion on detector
            mapping.inverse = Mapping((2, 1) if spatial_axis else (1, 2))

        # The final transform
        # redefine the ra, dec center tangent point to include all data

        # check if all_ra crosses 0 degrees - this makes it hard to
        # define the min and max ra correctly
        all_ra = wrap_ra(all_ra)
        ra_min = np.amin(all_ra)
        ra_max = np.amax(all_ra)
        ra_center_final = (ra_max + ra_min) / 2.0

        dec_min = np.amin(all_dec)
        dec_max = np.amax(all_dec)
        dec_center_final = (dec_max + dec_min) / 2.0

        tan = Pix2Sky_TAN()
        if len(input_models) == 1:
            # single model use ra_center_pt to be consistent
            # with how resample was done before
            ra_center_final = ra_center_pt
            dec_center_final = dec_center_pt

        native2celestial = RotateNative2Celestial(ra_center_final, dec_center_final, 180)
        undist2sky = tan | native2celestial

        # Use all the wcs
        min_tan_x, max_tan_x, min_tan_y, max_tan_y = self._max_spatial_extent(
            all_wcs, undist2sky.inverse
        )
        diff_y = np.abs(max_tan_y - min_tan_y)
        diff_x = np.abs(max_tan_x - min_tan_x)

        pix_to_tan_slope_y = np.abs(pix_to_ytan.slope)
        slope_sign_y = np.sign(pix_to_ytan.slope)
        pix_to_tan_slope_x = np.abs(pix_to_xtan.slope)
        slope_sign_x = np.sign(pix_to_xtan.slope)

        if swap_xy:
            ny = int(np.ceil(diff_y / pix_to_tan_slope_y))
        else:
            ny = int(np.ceil(diff_x / pix_to_tan_slope_x))

        # Center the tangent-plane offsets on the array
        offset_y = 0.5 * (ny - 1) * pix_to_tan_slope_y
        offset_x = 0.5 * (ny - 1) * pix_to_tan_slope_x

        pix_to_ytan.intercept = -slope_sign_y * offset_y
        pix_to_xtan.intercept = -slope_sign_x * offset_x

        # define the output wcs
        transform = mapping | (pix_to_xtan & pix_to_ytan | undist2sky) & pix_to_wavelength

        det = cf.Frame2D(name="detector", axes_order=(0, 1))
        sky = cf.CelestialFrame(name="sky", axes_order=(0, 1), reference_frame=coord.ICRS())
        spec = cf.SpectralFrame(
            name="spectral", axes_order=(2,), unit=(u.micron,), axes_names=("wavelength",)
        )
        world = cf.CompositeFrame([sky, spec], name="world")

        pipeline = [(det, transform), (world, None)]

        output_wcs = WCS(pipeline)

        # compute the output array size in WCS axes order, i.e. (x, y)
        output_array_size = [0, 0]
        output_array_size[spectral_axis] = int(np.ceil(len(wavelength_array)))
        output_array_size[spatial_axis] = ny

        # turn the size into a numpy shape in (y, x) order
        output_wcs.array_shape = output_array_size[::-1]
        output_wcs.pixel_shape = output_array_size
        bounding_box = wcs_bbox_from_shape(output_array_size[::-1])
        output_wcs.bounding_box = bounding_box

        return output_wcs

    def build_nirspec_lamp_output_wcs(self, input_models, pixel_scale_ratio):
        """
        Create a spatial/spectral WCS output frame for NIRSpec lamp mode.

        Creates output frame by linearly fitting x_msa, y_msa along the slit and
        producing a lookup table to interpolate wavelengths in the dispersion
        direction.  Only the first input model is used.

        Frames available in the output WCS are:

            - `detector`: image x, y
            - `world`: MSA x, MSA y, wavelength

        Parameters
        ----------
        input_models : list
            List of data models, one for each input image.
        pixel_scale_ratio : float
            The ratio of the input pixel scale to the output pixel scale.

        Returns
        -------
        output_wcs : `~gwcs.wcs.WCS` object
            A GWCS object defining the output frame WCS.
        """
        model = input_models[0]
        wcs = model.meta.wcs
        bbox = wcs.bounding_box
        grid = wcstools.grid_from_bounding_box(bbox)
        x_msa, y_msa, lam = np.array(wcs(*grid))

        # Handle vertical (MIRI) or horizontal (NIRSpec) dispersion.  The
        # following 2 variables are 0 or 1, i.e. zero-indexed in x,y WCS order
        spectral_axis = find_dispersion_axis(model)
        spatial_axis = spectral_axis ^ 1

        # Compute the wavelength array, trimming NaNs from the ends
        # In many cases, a whole slice is NaNs, so ignore those warnings
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)
            wavelength_array = np.nanmedian(lam, axis=spectral_axis)
        wavelength_array = wavelength_array[~np.isnan(wavelength_array)]

        # Find the center MSA x and y for this slit at central wavelength
        lam_center_index = int((bbox[spectral_axis][1] - bbox[spectral_axis][0]) / 2)
        x_msa_array = x_msa.T[lam_center_index]
        y_msa_array = y_msa.T[lam_center_index]
        x_msa_array = x_msa_array[~np.isnan(x_msa_array)]
        y_msa_array = y_msa_array[~np.isnan(y_msa_array)]

        # Estimate and fit the spatial sampling
        fitter = LinearLSQFitter()
        fit_model = Linear1D()
        xstop = x_msa_array.shape[0] * pixel_scale_ratio
        x_idx = np.linspace(0, xstop, x_msa_array.shape[0], endpoint=False)
        ystop = y_msa_array.shape[0] * pixel_scale_ratio
        y_idx = np.linspace(0, ystop, y_msa_array.shape[0], endpoint=False)
        pix_to_x_msa = fitter(fit_model, x_idx, x_msa_array)
        pix_to_y_msa = fitter(fit_model, y_idx, y_msa_array)

        step = 1
        stop = wavelength_array.shape[0]
        points = np.arange(0, stop, step)
        pix_to_wavelength = Tabular1D(
            points=points,
            lookup_table=wavelength_array,
            bounds_error=False,
            fill_value=None,
            name="pix2wavelength",
        )

        # Tabular models need an inverse explicitly defined.
        # If the wavelength array is descending instead of ascending, both
        # points and lookup_table need to be reversed in the inverse transform
        # for scipy.interpolate to work properly
        points = wavelength_array
        lookup_table = np.arange(0, stop, step)

        if not np.all(np.diff(wavelength_array) > 0):
            points = points[::-1]
            lookup_table = lookup_table[::-1]
        pix_to_wavelength.inverse = Tabular1D(
            points=points,
            lookup_table=lookup_table,
            bounds_error=False,
            fill_value=None,
            name="wavelength2pix",
        )

        # For the input mapping, duplicate the spatial coordinate
        mapping = Mapping((spatial_axis, spatial_axis, spectral_axis))
        mapping.inverse = Mapping((2, 1))

        # The final transform
        # define the output wcs
        transform = mapping | pix_to_x_msa & pix_to_y_msa & pix_to_wavelength

        det = cf.Frame2D(name="detector", axes_order=(0, 1))
        sky = cf.Frame2D(name=f"resampled_{model.meta.wcs.output_frame.name}", axes_order=(0, 1))
        spec = cf.SpectralFrame(
            name="spectral", axes_order=(2,), unit=(u.micron,), axes_names=("wavelength",)
        )
        world = cf.CompositeFrame([sky, spec], name="world")

        pipeline = [(det, transform), (world, None)]

        output_wcs = WCS(pipeline)

        # Compute the output array size and bounding box
        output_array_size = [0, 0]
        output_array_size[spectral_axis] = len(wavelength_array)
        x_size = len(x_msa_array)
        output_array_size[spatial_axis] = int(np.ceil(x_size * pixel_scale_ratio))

        # turn the size into a numpy shape in (y, x) order
        output_wcs.array_shape = output_array_size[::-1]
        output_wcs.pixel_shape = output_array_size
        bounding_box = wcs_bbox_from_shape(output_array_size[::-1])
        output_wcs.bounding_box = bounding_box

        return output_wcs
def find_dispersion_axis(refmodel): """ Find the dispersion axis (0-indexed) of the given 2D wavelength array. Parameters ---------- refmodel : `~stdatamodels.DataModel` The input data model. Returns ------- dispaxis : int The dispersion axis (0-indexed). """ dispaxis = refmodel.meta.wcsinfo.dispersion_direction # Change from 1 --> X and 2 --> Y to 0 --> X and 1 --> Y. return dispaxis - 1 def _find_nirspec_output_sampling_wavelengths(wcs_list): refwcs = wcs_list[0] bbox = refwcs.bounding_box grid = wcstools.grid_from_bounding_box(bbox) ra, dec, lambdas = refwcs(*grid) ref_lam = sorted(np.nanmedian(lambdas[:, np.any(np.isfinite(lambdas), axis=0)], axis=0)) lam1 = ref_lam[0] lam2 = ref_lam[-1] min_delta = np.fabs(np.ediff1d(ref_lam).min()) image_lam = [] for w in wcs_list[1:]: bbox = w.bounding_box grid = wcstools.grid_from_bounding_box(bbox) ra, dec, lambdas = w(*grid) lam = sorted(np.nanmedian(lambdas[:, np.any(np.isfinite(lambdas), axis=0)], axis=0)) image_lam.append((lam, np.min(lam), np.max(lam))) min_delta = min(min_delta, np.fabs(np.ediff1d(ref_lam).min())) # The code below is optimized for the case when wavelength is an increasing # function of the pixel index along the X-axis. It will not work correctly # if this assumption does not hold. 
# Estimate overlaps between ranges and decide in which order to combine # them: while image_lam: best_overlap = -np.inf best_wcs = 0 for k, (_lam, lmin, lmax) in enumerate(image_lam): overlap = min(lam2, lmax) - max(lam1, lmin) if best_overlap < overlap: best_overlap = overlap best_wcs = k lam, lmin, lmax = image_lam.pop(best_wcs) if lmax < lam1: ref_lam = lam + ref_lam lam1 = lmin elif lmin > lam2: ref_lam.extend(lam) lam2 = lmax else: lam_ar = np.array(lam) if lmin < lam1: idx = np.flatnonzero(lam_ar < lam1) ref_lam = lam_ar[idx].tolist() + ref_lam lam1 = ref_lam[0] if lmax > lam2: idx = np.flatnonzero(lam_ar > lam2) ref_lam = ref_lam + lam_ar[idx].tolist() lam2 = ref_lam[-1] # In the resampled WCS, if two wavelengths are closer to each other # than 1/10 of the minimum difference between two wavelengths, # remove one of the points. ediff = np.fabs(np.ediff1d(ref_lam)) idx = np.flatnonzero(ediff < max(0.1 * min_delta, 1e2 * np.finfo(1.0).eps)) for i in idx[::-1]: del ref_lam[i] return ref_lam def compute_spectral_pixel_scale(wcs, fiducial=None, disp_axis=1): """ Compute an approximate spatial pixel scale for spectral data. Parameters ---------- wcs : `gwcs.wcs.WCS` Spatial/spectral WCS. fiducial : tuple of float, optional (RA, Dec, wavelength) taken as the fiducial reference. If not specified, the center of the array is used. disp_axis : int Dispersion axis for the data. Assumes the same convention as `wcsinfo.dispersion_direction` (1 for NIRSpec, 2 for MIRI). Returns ------- pixel_scale : float The spatial scale in degrees. """ # Get the coordinates for the center of the array if fiducial is None: center_x, center_y = np.mean(wcs.bounding_box, axis=1) fiducial = wcs(center_x, center_y) pixel_scale = compute_scale(wcs, fiducial, disp_axis=disp_axis) return float(pixel_scale)