# Licensed under a 3-clause BSD style license - see LICENSE.rst
import warnings
from dataclasses import dataclass, field
import numpy as np
from astropy import units as u
from astropy.nddata import NDData
from astropy.utils.decorators import deprecated_attribute
from specutils import Spectrum1D
from specreduce.core import _ImageParser
from specreduce.extract import _ap_weight_image
from specreduce.tracing import Trace, FlatTrace
__all__ = ['Background']
[docs]
@dataclass
class Background(_ImageParser):
"""
Determine the background from an image for subtraction.
Example: ::
trace = FlatTrace(image, trace_pos)
bg = Background.two_sided(image, trace, bkg_sep, width=bkg_width)
subtracted_image = image - bg
Parameters
----------
image : `~astropy.nddata.NDData`-like or array-like
image with 2-D spectral image data
traces : trace, int, float (single or list)
Individual or list of trace object(s) (or integers/floats to define
FlatTraces) to extract the background. If None, a FlatTrace at the
center of the image (according to `disp_axis`) will be used.
width : float
width of extraction aperture in pixels
statistic: string
statistic to use when computing the background. 'average' will
account for partial pixel weights, 'median' will include all partial
pixels.
disp_axis : int
dispersion axis
crossdisp_axis : int
cross-dispersion axis
"""
# required so numpy won't call __rsub__ on individual elements
# https://stackoverflow.com/a/58409215
__array_ufunc__ = None
image: NDData
traces: list = field(default_factory=list)
width: float = 5
statistic: str = 'average'
disp_axis: int = 1
crossdisp_axis: int = 0
# TO-DO: update bkg_array with Spectrum1D alternative (is bkg_image enough?)
bkg_array = deprecated_attribute('bkg_array', '1.3')
def __post_init__(self):
"""
Determine the background from an image for subtraction.
Parameters
----------
image : `~astropy.nddata.NDData`-like or array-like
image with 2-D spectral image data
traces : List
list of trace objects (or integers to define FlatTraces) to
extract the background
width : float
width of each background aperture in pixels
statistic: string
statistic to use when computing the background. 'average' will
account for partial pixel weights, 'median' will include all partial
pixels.
disp_axis : int
dispersion axis
crossdisp_axis : int
cross-dispersion axis
"""
self.image = self._parse_image(self.image)
if self.width < 0:
raise ValueError("width must be positive")
if self.width == 0:
self._bkg_array = np.zeros(self.image.shape[self.disp_axis])
return
self._set_traces()
bkg_wimage = np.zeros_like(self.image.data, dtype=np.float64)
for trace in self.traces:
windows_max = trace.trace.data.max() + self.width/2
windows_min = trace.trace.data.min() - self.width/2
if windows_max >= self.image.shape[self.crossdisp_axis]:
warnings.warn("background window extends beyond image boundaries " +
f"({windows_max} >= {self.image.shape[self.crossdisp_axis]})")
if windows_min < 0:
warnings.warn("background window extends beyond image boundaries " +
f"({windows_min} < 0)")
# pass trace.trace.data to ignore any mask on the trace
bkg_wimage += _ap_weight_image(trace,
self.width,
self.disp_axis,
self.crossdisp_axis,
self.image.shape)
if np.any(bkg_wimage > 1):
raise ValueError("background regions overlapped")
if np.any(np.sum(bkg_wimage, axis=self.crossdisp_axis) == 0):
raise ValueError("background window does not remain in bounds across entire dispersion axis") # noqa
if self.statistic == 'median':
# make it clear in the expose image that partial pixels are fully-weighted
bkg_wimage[bkg_wimage > 0] = 1
self.bkg_wimage = bkg_wimage
# mask user-highlighted and invalid values (if any) before taking stats
or_mask = (np.logical_or(~np.isfinite(self.image.data), self.image.mask)
if self.image.mask is not None
else ~np.isfinite(self.image.data))
if self.statistic == 'average':
image_ma = np.ma.masked_array(self.image.data, mask=or_mask)
self._bkg_array = np.ma.average(image_ma,
weights=self.bkg_wimage,
axis=self.crossdisp_axis).data
elif self.statistic == 'median':
med_mask = np.logical_or(self.bkg_wimage == 0, or_mask)
image_ma = np.ma.masked_array(self.image.data, mask=med_mask)
self._bkg_array = np.ma.median(image_ma, axis=self.crossdisp_axis).data
else:
raise ValueError("statistic must be 'average' or 'median'")
def _set_traces(self):
"""Determine `traces` from input. If an integer/float or list if int/float
is passed in, use these to construct FlatTrace objects. These values
must be positive. If None (which is initialized to an empty list),
construct a FlatTrace using the center of image (according to disp.
axis). Otherwise, any Trace object or list of Trace objects can be
passed in."""
if self.traces == []:
# assume a flat trace at the image center if nothing is passed in.
trace_pos = self.image.shape[self.disp_axis] / 2.
self.traces = [FlatTrace(self.image, trace_pos)]
if isinstance(self.traces, Trace):
# if just one trace, turn it into iterable.
self.traces = [self.traces]
return
# finally, if float/int is passed in convert to FlatTrace(s)
if isinstance(self.traces, (float, int)): # for a single number
self.traces = [self.traces]
if np.all([isinstance(x, (float, int)) for x in self.traces]):
self.traces = [FlatTrace(self.image, trace_pos) for trace_pos in self.traces]
return
else:
if not np.all([isinstance(x, Trace) for x in self.traces]):
raise ValueError('`traces` must be a `Trace` object or list of '
'`Trace` objects, a number or list of numbers to '
'define FlatTraces, or None to use a FlatTrace in '
'the middle of the image.')
[docs]
@classmethod
def two_sided(cls, image, trace_object, separation, **kwargs):
"""
Determine the background from an image for subtraction centered around
an input trace.
Example: ::
trace = FitTrace(image, guess=trace_pos)
bg = Background.two_sided(image, trace, bkg_sep, width=bkg_width)
Parameters
----------
image : `~astropy.nddata.NDData`-like or array-like
Image with 2-D spectral image data. Assumes cross-dispersion
(spatial) direction is axis 0 and dispersion (wavelength)
direction is axis 1.
trace_object: `~specreduce.tracing.Trace`
estimated trace of the spectrum to center the background traces
separation: float
separation from ``trace_object`` for the background regions
width : float
width of each background aperture in pixels
statistic: string
statistic to use when computing the background. 'average' will
account for partial pixel weights, 'median' will include all partial
pixels.
disp_axis : int
dispersion axis
crossdisp_axis : int
cross-dispersion axis
"""
image = _ImageParser._get_data_from_image(image) if image is not None else cls.image
kwargs['traces'] = [trace_object-separation, trace_object+separation]
return cls(image=image, **kwargs)
[docs]
@classmethod
def one_sided(cls, image, trace_object, separation, **kwargs):
"""
Determine the background from an image for subtraction above
or below an input trace.
Example: ::
trace = FitTrace(image, guess=trace_pos)
bg = Background.one_sided(image, trace, bkg_sep, width=bkg_width)
Parameters
----------
image : `~astropy.nddata.NDData`-like or array-like
Image with 2-D spectral image data. Assumes cross-dispersion
(spatial) direction is axis 0 and dispersion (wavelength)
direction is axis 1.
trace_object: `~specreduce.tracing.Trace`
estimated trace of the spectrum to center the background traces
separation: float
separation from ``trace_object`` for the background, positive will be
above the trace, negative below.
width : float
width of each background aperture in pixels
statistic: string
statistic to use when computing the background. 'average' will
account for partial pixel weights, 'median' will include all partial
pixels.
disp_axis : int
dispersion axis
crossdisp_axis : int
cross-dispersion axis
"""
image = _ImageParser._get_data_from_image(image) if image is not None else cls.image
kwargs['traces'] = [trace_object+separation]
return cls(image=image, **kwargs)
[docs]
def bkg_image(self, image=None):
"""
Expose the background tiled to the dimension of ``image``.
Parameters
----------
image : `~astropy.nddata.NDData`-like or array-like, optional
Image with 2-D spectral image data. Assumes cross-dispersion
(spatial) direction is axis 0 and dispersion (wavelength)
direction is axis 1. If None, will extract the background
from ``image`` used to initialize the class. [default: None]
Returns
-------
`~specutils.Spectrum1D` object with same shape as ``image``.
"""
image = self._parse_image(image)
return Spectrum1D(np.tile(self._bkg_array,
(image.shape[0], 1)) * image.unit,
spectral_axis=image.spectral_axis)
[docs]
def bkg_spectrum(self, image=None):
"""
Expose the 1D spectrum of the background.
Parameters
----------
image : `~astropy.nddata.NDData`-like or array-like, optional
Image with 2-D spectral image data. Assumes cross-dispersion
(spatial) direction is axis 0 and dispersion (wavelength)
direction is axis 1. If None, will extract the background
from ``image`` used to initialize the class. [default: None]
Returns
-------
spec : `~specutils.Spectrum1D`
The background 1-D spectrum, with flux expressed in the same
units as the input image (or u.DN if none were provided) and
the spectral axis expressed in pixel units.
"""
bkg_image = self.bkg_image(image)
try:
return bkg_image.collapse(np.nansum, axis=self.crossdisp_axis)
except u.UnitTypeError:
# can't collapse with a spectral axis in pixels because
# SpectralCoord only allows frequency/wavelength equivalent units...
ext1d = np.nansum(bkg_image.flux, axis=self.crossdisp_axis)
return Spectrum1D(ext1d, bkg_image.spectral_axis)
[docs]
def sub_image(self, image=None):
"""
Subtract the computed background from ``image``.
Parameters
----------
image : nddata-compatible image or None
image with 2-D spectral image data. If None, will extract
the background from ``image`` used to initialize the class.
Returns
-------
`~specutils.Spectrum1D` object with same shape as ``image``.
"""
image = self._parse_image(image)
# a compare_wcs argument is needed for Spectrum1D.subtract() in order to
# avoid a TypeError from SpectralCoord when image's spectral axis is in
# pixels. it is not needed when image's spectral axis has physical units
kwargs = ({'compare_wcs': None} if image.spectral_axis.unit == u.pix
else {})
# https://docs.astropy.org/en/stable/nddata/mixins/ndarithmetic.html
return image.subtract(self.bkg_image(image), **kwargs)
[docs]
def sub_spectrum(self, image=None):
"""
Expose the 1D spectrum of the background-subtracted image.
Parameters
----------
image : nddata-compatible image or None
image with 2-D spectral image data. If None, will extract
the background from ``image`` used to initialize the class.
Returns
-------
spec : `~specutils.Spectrum1D`
The background 1-D spectrum, with flux expressed in the same
units as the input image (or u.DN if none were provided) and
the spectral axis expressed in pixel units.
"""
sub_image = self.sub_image(image=image)
try:
return sub_image.collapse(np.nansum, axis=self.crossdisp_axis)
except u.UnitTypeError:
# can't collapse with a spectral axis in pixels because
# SpectralCoord only allows frequency/wavelength equivalent units...
ext1d = np.nansum(sub_image.flux, axis=self.crossdisp_axis)
return Spectrum1D(ext1d, spectral_axis=sub_image.spectral_axis)
def __rsub__(self, image):
"""
Subtract the background from an image.
"""
return self.sub_image(image)