# Source code for specreduce.wavelength_calibration

from functools import cached_property

import numpy as np
from astropy import units as u
from astropy.modeling.fitting import LMLSQFitter, LinearLSQFitter
from astropy.modeling.models import Linear1D
from astropy.table import QTable, hstack
from gwcs import coordinate_frames as cf
from gwcs import wcs
from specutils import Spectrum1D

# Public API of this module: only the calibration class is exported.
__all__ = ['WavelengthCalibration1D']


def _check_arr_monotonic(arr):
    # returns True if ``arr`` is either strictly increasing or strictly
    # decreasing, otherwise returns False.

    sorted_increasing = np.all(arr[1:] >= arr[:-1])
    sorted_decreasing = np.all(arr[1:] <= arr[:-1])
    return sorted_increasing or sorted_decreasing


class WavelengthCalibration1D:
    """
    Wavelength calibration for a one-dimensional spectrum.

    Fits ``input_model`` to matched line pixel centers and wavelengths and
    exposes the fitted model, the fit residuals, and a GWCS object built from
    the fitted model. The fit products are cached and recomputed after the
    input spectrum or input model is replaced.
    """

    def __init__(self, input_spectrum, matched_line_list=None, line_pixels=None,
                 line_wavelengths=None, catalog=None, input_model=Linear1D(), fitter=None):
        """
        Parameters
        ----------
        input_spectrum: `~specutils.Spectrum1D`
            A one-dimensional Spectrum1D calibration spectrum from an arc lamp or similar.
        matched_line_list: `~astropy.table.QTable`, optional
            An `~astropy.table.QTable` table with (minimally) columns named
            "pixel_center" and "wavelength" with known corresponding line pixel centers
            and wavelengths populated.
        line_pixels: list, array, `~astropy.table.QTable`, optional
            List or array of line pixel locations to anchor the wavelength solution fit.
            Can also be input as an `~astropy.table.QTable` table with (minimally) a column
            named "pixel_center".
        line_wavelengths: `~astropy.units.Quantity`, `~astropy.table.QTable`, optional
            `astropy.units.Quantity` array of line wavelength values corresponding to the
            line pixels defined in ``line_pixels``, assumed to be in the same
            order. Can also be input as an `~astropy.table.QTable` with (minimally)
            a "wavelength" column.
        catalog: list, str, `~astropy.table.QTable`, optional
            The name of a catalog of line wavelengths to load and use in automated and
            template-matching line matching. NOTE: This option is currently not implemented.
        input_model: `~astropy.modeling.Model`
            The model to fit for the wavelength solution. Defaults to a linear model.
        fitter: `~astropy.modeling.fitting.Fitter`, optional
            The fitter to use in optimizing the model fit. Defaults to
            `~astropy.modeling.fitting.LinearLSQFitter` if the model to fit is linear
            or `~astropy.modeling.fitting.LMLSQFitter` if the model to fit is non-linear.

        Raises
        ------
        ValueError
            If the inputs are of the wrong type, are missing required columns,
            have mismatched lengths, or are not strictly monotonic.
        NotImplementedError
            If ``catalog`` is specified.

        Note that either ``matched_line_list`` or ``line_pixels`` must be specified,
        and if ``matched_line_list`` is not input, at least one of ``line_wavelengths``
        or ``catalog`` must be specified.
        """
        self._input_spectrum = input_spectrum
        self._input_model = input_model
        # Names of the cached_property attributes that _clear_cache() resets by default.
        self._cached_properties = ['fitted_model', 'residuals', 'wcs']
        self.fitter = fitter
        self._potential_wavelengths = None
        self._catalog = catalog
        # No catalogs are available yet (see the NotImplementedError below); the
        # attribute is still defined so that the ``available_catalogs`` property
        # returns an empty list instead of raising AttributeError.
        self._available_catalogs = []

        if not isinstance(input_spectrum, Spectrum1D):
            raise ValueError('Input spectrum must be Spectrum1D.')

        # We use either line_pixels or matched_line_list to create self._matched_line_list,
        # and check that various requirements are fulfilled by the input args.
        if matched_line_list is not None:
            pixel_arg = "matched_line_list"
            if not isinstance(matched_line_list, QTable):
                raise ValueError("matched_line_list must be an astropy.table.QTable.")
            self._matched_line_list = matched_line_list
        elif line_pixels is not None:
            pixel_arg = "line_pixels"
            if isinstance(line_pixels, (list, np.ndarray)):
                self._matched_line_list = QTable([line_pixels], names=["pixel_center"])
            elif isinstance(line_pixels, QTable):
                self._matched_line_list = line_pixels
            else:
                # Without this branch an unsupported type would leave
                # self._matched_line_list unset and fail later with an
                # unhelpful AttributeError.
                raise ValueError("line_pixels must be a list, array, or "
                                 "astropy.table.QTable.")
        else:
            raise ValueError("Either matched_line_list or line_pixels must be specified.")

        if "pixel_center" not in self._matched_line_list.columns:
            raise ValueError(f"{pixel_arg} must have a 'pixel_center' column.")

        # Unitless pixel centers are interpreted as pixels.
        if self._matched_line_list["pixel_center"].unit is None:
            self._matched_line_list["pixel_center"].unit = u.pix

        # check that pixels are monotonic
        if not _check_arr_monotonic(self._matched_line_list["pixel_center"]):
            raise ValueError('Pixels must be strictly increasing or decreasing.')

        # now that pixels have been determined from input, figure out wavelengths.
        if (line_wavelengths is None and catalog is None
                and "wavelength" not in self._matched_line_list.columns):
            raise ValueError("You must specify at least one of line_wavelengths, "
                             "catalog, or 'wavelength' column in matched_line_list.")

        # Sanity checks on line_wavelengths value
        if line_wavelengths is not None:
            if "wavelength" in self._matched_line_list.columns:
                raise ValueError("Cannot specify line_wavelengths separately if there is"
                                 " a 'wavelength' column in matched_line_list.")
            if len(line_wavelengths) != len(self._matched_line_list):
                raise ValueError("If line_wavelengths is specified, it must have the same "
                                 f"length as {pixel_arg}")
            if not isinstance(line_wavelengths, (u.Quantity, QTable)):
                raise ValueError("line_wavelengths must be specified as an astropy.units.Quantity"
                                 " array or as an astropy.table.QTable")

            # make sure wavelengths (or freq) are monotonic and add wavelengths
            # to _matched_line_list
            if isinstance(line_wavelengths, u.Quantity):
                if not _check_arr_monotonic(line_wavelengths):
                    if str(line_wavelengths.unit.physical_type) == "frequency":
                        raise ValueError('Frequencies must be strictly increasing or decreasing.')
                    raise ValueError('Wavelengths must be strictly increasing or decreasing.')
                self._matched_line_list["wavelength"] = line_wavelengths
            else:
                # line_wavelengths is a QTable here (enforced by the type check above).
                if "wavelength" not in line_wavelengths.columns:
                    raise ValueError("line_wavelengths table must have a 'wavelength' column.")
                if not _check_arr_monotonic(line_wavelengths['wavelength']):
                    raise ValueError('Wavelengths must be strictly increasing or decreasing.')
                self._matched_line_list = hstack([self._matched_line_list, line_wavelengths])

        # Parse desired catalogs of lines for matching.
        if catalog is not None:
            # Catalog support is not implemented, so any catalog input is rejected.
            # TODO: once catalogs exist, validate the input here (a QTable must have
            # a 'wavelength' column; string names must be in the available catalogs).
            raise NotImplementedError("No catalogs are available yet, please input "
                                      "wavelengths with line_wavelengths or as a "
                                      f"column in {pixel_arg}")

    def identify_lines(self):
        """
        ToDo: Code matching algorithm between line pixel locations and potential line
        wavelengths from catalogs.
        """
        pass

    def _clear_cache(self, *attrs):
        """
        Convenience function for clearing the cache of cached properties.

        Parameters
        ----------
        *attrs : str
            Names of the cached attributes to clear. If none are given, every
            name listed in ``self._cached_properties`` is cleared.
        """
        if not attrs:
            attrs = self._cached_properties
        for attr in attrs:
            # cached_property stores its value in the instance __dict__;
            # removing the entry forces recomputation on next access.
            self.__dict__.pop(attr, None)

    @property
    def available_catalogs(self):
        """Line catalogs available for matching (currently always empty)."""
        return self._available_catalogs

    @property
    def input_spectrum(self):
        """The calibration spectrum supplied to this object."""
        return self._input_spectrum

    @input_spectrum.setter
    def input_spectrum(self, new_spectrum):
        # Cached fit products are cleared whenever a new calibration spectrum is provided.
        self._clear_cache()
        self._input_spectrum = new_spectrum

    @property
    def input_model(self):
        """The (unfitted) model used for the wavelength solution."""
        return self._input_model

    @input_model.setter
    def input_model(self, input_model):
        # A different model invalidates every cached fit product.
        self._clear_cache()
        self._input_model = input_model

    @cached_property
    def fitted_model(self):
        """
        Fit ``input_model`` to the matched line list and return the fitted model.

        The "pixel_center" column is the independent variable and the
        "wavelength" column the dependent one. The result is cached.
        """
        x = self._matched_line_list["pixel_center"]
        y = self._matched_line_list["wavelength"]

        if self.fitter is None:
            # Flexible defaulting if self.fitter is None
            if self.input_model.linear:
                fitter = LinearLSQFitter(calc_uncertainties=True)
            else:
                fitter = LMLSQFitter(calc_uncertainties=True)
        else:
            fitter = self.fitter

        # Fit the model
        return fitter(self.input_model, x, y)

    @cached_property
    def residuals(self):
        """
        Calculate fit residuals between matched line list pixel centers and
        wavelengths and the evaluated fit model.
        """
        x = self._matched_line_list["pixel_center"]
        y = self._matched_line_list["wavelength"]

        # Get the fit residuals by evaluating the model
        return y - self.fitted_model(x)

    @cached_property
    def wcs(self):
        """
        Build and return a GWCS object whose pipeline maps the pixel frame to
        the spectral frame through the fitted model. The result is cached.
        """
        pixel_frame = cf.CoordinateFrame(1, "SPECTRAL", [0,], axes_names=["x",],
                                         unit=[u.pix,])
        spectral_frame = cf.SpectralFrame(
            axes_names=["wavelength",],
            unit=[self._matched_line_list["wavelength"].unit,])

        pipeline = [(pixel_frame, self.fitted_model), (spectral_frame, None)]

        wcsobj = wcs.WCS(pipeline)

        return wcsobj

    def apply_to_spectrum(self, spectrum=None):
        """
        Return a new Spectrum1D with the wavelength calibration applied.

        Parameters
        ----------
        spectrum : `~specutils.Spectrum1D`, optional
            Spectrum to calibrate. Defaults to ``self.input_spectrum``.

        Returns
        -------
        `~specutils.Spectrum1D`
            A spectrum with the flux, mask and uncertainty of the input and
            ``self.wcs`` as its WCS.
        """
        # The WCS is a cached property, so this can be called on multiple
        # science sources without refitting.
        spectrum = self.input_spectrum if spectrum is None else spectrum
        updated_spectrum = Spectrum1D(spectrum.flux, wcs=self.wcs, mask=spectrum.mask,
                                      uncertainty=spectrum.uncertainty)
        return updated_spectrum