import warnings
from typing import Sequence, Literal
import matplotlib.pyplot as plt
import numpy as np
from astropy.modeling import models, fitting
from astropy.nddata import StdDevUncertainty, NDData
from numpy import ndarray, repeat, tile
from scipy.optimize import minimize
from scipy.spatial import KDTree
from specutils import Spectrum
from specreduce.core import _ImageParser
from specreduce.line_matching import find_arc_lines
from specreduce.tilt_solution import TiltSolution
from specreduce.tracing import Trace
__all__ = ["TiltCorrection"]
[docs]
class TiltCorrection:
def __init__(
self,
arc_frames: NDData | Sequence[NDData],
trace: Trace | None = None,
cdisp_ref_pixel: float | None = None,
disp_ref_pixel: float | None = None,
n_cdisp_samples: int = 10,
cdisp_sample_lims: tuple[float, float] | None = None,
cdisp_samples: Sequence[float] | None = None,
disp_axis: int = 1,
mask_treatment: Literal[
"apply",
"ignore",
"propagate",
"zero_fill",
"nan_fill",
"apply_mask_only",
"apply_nan_only",
] = "apply",
):
"""A class for 2D spectral tilt correction.
This class provides tools for correcting spectral tilt (curvature) in 2D
spectroscopic data using arc lamp frames. It identifies arc lines across the
cross-dispersion axis and fits a 2D polynomial transformation from a
tilt-corrected (rectified) coordinate space to detector space. The resulting
`~specreduce.tilt_solution.TiltSolution` can then be used to resample science
frames onto a rectified grid.
Parameters
----------
arc_frames
A sequence of arc frames as `~astropy.nddata.NDData` instances.
trace
A trace object representing the spectrum trace. If provided, it will be used to
determine the reference positions along the dispersion and cross-dispersion axes.
cdisp_ref_pixel
A reference pixel position along the cross-dispersion axis. Should be close to the
spectrum trace's average cross-dispersion position for the best results.
disp_ref_pixel
A reference pixel position along the dispersion axis. Should be close to the
center of the frame along the dispersion axis for best results.
n_cdisp_samples
Number of cross-dispersion (CD) samples to use.
cdisp_sample_lims
Tuple specifying the limits for calculating cross-dispersion sampling.
cdisp_samples
A list of cross-dispersion locations to use. Overrides ``n_cdisp_samples``
if provided.
disp_axis
The index of the image's dispersion axis.
mask_treatment
Specifies how to handle masked or non-finite values in the input image.
The accepted values are:
- ``apply``: The image remains unchanged, and any existing mask is combined with a mask
derived from non-finite values.
- ``ignore``: The image remains unchanged, and any existing mask is dropped.
- ``propagate``: The image remains unchanged, and any masked or non-finite pixel
causes the mask to extend across the entire cross-dispersion axis.
- ``zero_fill``: Pixels that are either masked or non-finite are replaced with 0.0,
and the mask is dropped.
- ``nan_fill``: Pixels that are either masked or non-finite are replaced with nan,
and the mask is dropped.
- ``apply_mask_only``: The image and mask are left unmodified.
- ``apply_nan_only``: The image is left unmodified, the old mask is dropped,
and a new mask is created based on non-finite values.
"""
self.disp_axis = disp_axis
self.mask_treatment = mask_treatment
# IMPLEMENTATION NOTES: the code assumes that the image-parsing routines ensure that the
# cross-dispersion axis is aligned with the y-axis (1st array dimension) and the
# dispersion axis with the x-axis (2nd array dimension). However, this should not be
# visible to the end-user. The rectified spectra are returned with the original axis
# alignment given by the ``disp_axis`` argument. Also, I've decided to use `x` and `y`
# naming instead of `col` and `row` because this leads to (slightly) more readable code.
# The methods that are visible to the end-user use `disp` and `cdisp` naming. -HP
if not isinstance(arc_frames, Sequence):
arc_frames = [arc_frames]
# An ugly hack that should be changed after the refactoring of image parsing.
ip = _ImageParser()
self.arc_frames = []
for f in arc_frames:
im = ip._parse_image(f, disp_axis=disp_axis, mask_treatment=mask_treatment)
self.arc_frames.append(NDData(im.flux, uncertainty=im.uncertainty, mask=im.mask))
self.nframes = len(arc_frames)
self._ny, self._nx = self.arc_frames[0].data.shape
self._lines_ref: Sequence[ndarray] | None = None
self._samples_rec_x: Sequence[ndarray] | None = None
self._samples_rec_y: Sequence[ndarray] | None = None
self._samples_det_x: Sequence[ndarray] | None = None
self._samples_det_y: Sequence[ndarray] | None = None
self._arc_spectra: Sequence[Spectrum] | None = None
self._trees: Sequence[KDTree] | None = None
if trace is not None:
self.trace = trace
disp_ref_pixel = self.trace.trace.size // 2
cdisp_ref_pixel = int(self.trace.trace[disp_ref_pixel])
else:
if cdisp_ref_pixel is None:
raise ValueError("cdisp_ref_position must be provided if trace is not provided.")
if disp_ref_pixel is None:
disp_ref_pixel = self.arc_frames[0].data.shape[disp_axis] // 2
self.ref_pixel = (cdisp_ref_pixel, disp_ref_pixel) # Reference pixel (y, x)
self._shift = models.Shift(-self.ref_pixel[1]) & models.Shift(-self.ref_pixel[0])
# Calculate the cross-dispersion axis sample positions
slims = cdisp_sample_lims if cdisp_sample_lims is not None else (0, self._ny)
if cdisp_samples is not None:
self.cd_samples = np.array(cdisp_samples)
else:
self.cd_samples = slims[0] + np.round(
np.arange(1, n_cdisp_samples + 1) * (slims[1] - slims[0]) / (n_cdisp_samples + 1)
).astype(int)
self.ncd = self.cd_samples.size
self.solution: TiltSolution | None = None
[docs]
def find_arc_lines(self, fwhm: float, noise_factor: float = 5.0) -> None:
"""Find arc lines from the provided arc frames for all cross-dispersion samples.
This method locates spectral arc lines from the provided arc frames, calculates
their centroids, and organizes them into reference lists and sample arrays
for further analysis.
Parameters
----------
fwhm
Full width at half maximum of the spectral line to be detected, used
by the line-finding algorithm.
noise_factor
A multiplier for noise thresholding in the line-finding process.
"""
self._arc_spectra = []
self._samples_rec_x = []
self._lines_ref = []
samples_x = []
samples_y = []
with warnings.catch_warnings():
warnings.simplefilter("ignore")
for i, d in enumerate(self.arc_frames):
self._arc_spectra.append([])
samples_x.append([])
samples_y.append([])
# Find the line centroids for the reference row
spectrum = Spectrum(
d.data[self.ref_pixel[0]] * d.unit,
uncertainty=d.uncertainty[self.ref_pixel[0]].represent_as(StdDevUncertainty),
)
lines = find_arc_lines(spectrum, fwhm, noise_factor=noise_factor)
self._lines_ref.append(lines["centroid"].value)
# Find the line centroids for the sample rows
for s in self.cd_samples:
spectrum = Spectrum(
d.data[s] * d.unit,
uncertainty=d.uncertainty[s].represent_as(StdDevUncertainty),
)
lines = find_arc_lines(spectrum, fwhm, noise_factor=noise_factor)
samples_x[i].append(lines["centroid"].value)
samples_y[i].append(np.full(len(lines), s))
self._arc_spectra[i].append(spectrum)
self._samples_det_x = [np.concatenate(lpx) for lpx in samples_x]
self._samples_det_y = [np.concatenate(lpy) for lpy in samples_y]
self._samples_rec_y = [repeat(self.cd_samples, lref.size) for lref in self._lines_ref]
self._samples_rec_x = [tile(lref, self.cd_samples.size) for lref in self._lines_ref]
self._trees = [
KDTree(np.vstack([lx, ly]).T)
for lx, ly in zip(self._samples_det_x, self._samples_det_y)
]
[docs]
def fit(
self, degree: int = 3, method: str = "Powell", max_distance: float = 10
) -> TiltSolution:
"""Fit a 2D polynomial transformation from tilt-corrected space to detector space.
The transformation is calculated by minimizing the sum of distances between transformed
samples and their corresponding detector-space targets. The minimization is performed in
two stages: an initial minimization of a kd-tree based sum of line-line distances using
`scipy.optimize.minimize` and a refinement using least-squares optimization of matched
lines.
Parameters
----------
degree
The degree of the final 2D polynomial model.
method
The optimization method used during the initial fitting stage.
max_distance
The maximum allowable distance to constrain the minimization.
"""
model = self._shift | models.Polynomial2D(3)
coeffs = np.zeros(10)
coeffs[0] = self.ref_pixel[1]
coeffs[1] = 1
transformed_points = [tile(a, (2, 1)).T.astype("d") for a in self._samples_rec_y]
def minfun(x):
coeffs[4:] = x
total_distance = 0.0
for i, t in enumerate(self._trees):
transformed_points[i][:, 0] = model.evaluate(
self._samples_rec_x[i],
self._samples_rec_y[i],
-self.ref_pixel[1],
-self.ref_pixel[0],
*coeffs,
)
total_distance += np.clip(t.query(transformed_points[i])[0], 0, max_distance).sum()
return total_distance
res = minimize(minfun, np.zeros(6), method=method)
coeffs[4:] = res.x
self.solution = TiltSolution(
self._shift
| models.Polynomial2D(
model[-1].degree,
**{model[-1].param_names[i]: coeffs[i] for i in range(coeffs.size)},
),
image_shape=(self._ny, self._nx),
)
# Calculate the final fit using least-squares optimization between matched lines
self.refine_fit(degree)
return self.solution
[docs]
def refine_fit(self, degree: int = 4, match_distance_bound: float = 5.0) -> None:
"""Refine the tilt-corrected space -> detector space transformation model parameters.
Refines the polynomial fit model parameters for matching features with a specified
degree and match distance bound. The refinement includes matching lines, updating a
polynomial model, and optimizing the parameters using a least squares fitter
The derivative is recalculated after the optimization.
Parameters
----------
degree
Degree of the polynomial used in the Polynomial2D model.
match_distance_bound
Maximum acceptable distance between features to be considered a match.
"""
if self.solution is None:
raise ValueError("The solution must be calculated before it can be refined.")
rx, ry, ox = self.match_lines(match_distance_bound)
model = self._shift | models.Polynomial2D(
degree,
**{
n: getattr(self.solution.c2d[-1], n).value
for n in self.solution.c2d[-1].param_names
},
)
model.offset_0.fixed = True
model.offset_1.fixed = True
for i in range(degree + 1):
model.fixed[f"c{i}_0_2"] = True
fitter = fitting.LMLSQFitter()
self.solution.c2d = fitter(model, rx, ry, ox)
[docs]
def match_lines(
self, max_distance: float = 5, concatenate: bool = True
) -> tuple[ndarray, ndarray, ndarray] | tuple[list[ndarray], list[ndarray], list[ndarray]]:
"""Match the reference arc line locations with the detector-space targets.
Parameters
----------
max_distance
Specifies the maximum allowed distance for matching lines. Matches beyond this distance
will be ignored.
concatenate
Specifies whether to concatenate the matched lines.
Returns
-------
tuple of numpy.ndarray
A tuple containing three concatenated numpy arrays representing:
- x-coordinates of matched rectified-space lines.
- y-coordinates of matched rectified-space lines.
- x-coordinates of matched detector-space lines.
"""
if self.solution is None:
raise ValueError("The solution must be calculated before line matching.")
matched_det_x = []
matched_rec_x = []
matched_rec_y = []
for iframe, tree in enumerate(self._trees):
x_mapped = self.solution.c2d(self._samples_rec_x[iframe], self._samples_rec_y[iframe])
l, ix = tree.query(
np.array([x_mapped, self._samples_rec_y[iframe]]).T,
distance_upper_bound=max_distance,
)
m = np.isfinite(l)
matched_det_x.append(tree.data[ix[m], 0])
matched_rec_x.append(self._samples_rec_x[iframe][m])
matched_rec_y.append(self._samples_rec_y[iframe][m])
if concatenate:
return (
np.concatenate(matched_rec_x),
np.concatenate(matched_rec_y),
np.concatenate(matched_det_x),
)
else:
return matched_rec_x, matched_rec_y, matched_det_x
[docs]
def plot_wavelength_contours(
self,
n_disp: int = 50,
n_cdisp: int = 100,
disp_values: Sequence[float] | None = None,
ax: plt.Axes | None = None,
figsize: tuple[float, float] | None = None,
line_args: dict | None = None,
):
"""Plot wavelength contour lines in detector space.
Parameters
----------
n_disp
The number of dispersion-axis lines.
n_cdisp
The number of cross-dispersion axis points for each disp-axis line.
disp_values
A sequence specifying the dispersion-axis coordinates explicitly. If not
provided, it will be automatically calculated based on the arc frame
dimensions.
ax
The Matplotlib Axes on which to plot. If None, a new figure and Axes
are created.
figsize
Tuple specifying the size of the figure to create, applicable only if
``ax`` is None.
line_args
A dictionary of line properties (e.g., color, linewidth, linestyle).
These properties modify the default styling provided for grid lines.
If None, default styles are used. Default is None.
Returns
-------
figure : matplotlib.figure.Figure
The Matplotlib figure containing the plot. If an Axes instance is passed
to ``ax``, the associated figure is returned.
"""
largs = {"c": "k", "lw": 0.5, "alpha": 0.5, "ls": "--"}
if line_args is not None:
largs.update(line_args)
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
else:
fig = ax.figure
if disp_values is None:
disp_values = tile(
np.linspace(0, self.arc_frames[0].data.shape[1], n_disp), (n_cdisp, 1)
)
else:
n_disp = len(disp_values)
rows = tile(np.linspace(0, self.arc_frames[0].data.shape[0], n_cdisp)[:, None], (1, n_disp))
ax.plot(self.solution.c2d(disp_values, rows), rows, **largs)
return fig
[docs]
def plot_fit_quality(
self, figsize=None, max_match_distance: float = 5, rlim: tuple[float, float] | None = None
):
"""Plot fit quality diagnostics showing residuals of the tilt correction.
Creates a three-panel figure with a scatter plot of matched line positions
and marginal residual plots along the dispersion and cross-dispersion axes.
Parameters
----------
figsize
Tuple specifying the size of the figure. If None, the default Matplotlib
figure size is used.
max_match_distance
Maximum distance for matching lines, passed to `match_lines`.
rlim
Residual axis limits as a tuple (min, max). Applied to both marginal
residual plots. If None, limits are set automatically.
Returns
-------
matplotlib.figure.Figure
The Matplotlib figure containing the diagnostic plots.
"""
fig = plt.Figure(figsize=figsize, layout="constrained")
gs = plt.GridSpec(2, 2, width_ratios=(4, 1), height_ratios=(1, 3), figure=fig)
ax1 = fig.add_subplot(gs[1, 0])
ax2 = fig.add_subplot(gs[0, 0])
ax3 = fig.add_subplot(gs[1, 1])
rxs, rys, dxs = self.match_lines(max_match_distance, concatenate=False)
for i, (rx, ry, dx) in enumerate(zip(rxs, rys, dxs)):
residuals = dx - self.solution.corr_to_det(rx, ry)[0]
ax1.scatter(rx, ry, s=50 * abs(residuals), label=f"Arc {i+1}")
ax2.plot(rx, residuals, ".")
ax3.plot(residuals, ry, ".")
ax1.legend(loc="upper right")
ax2.set_xlim(ax1.get_xlim())
ax3.set_ylim(ax1.get_ylim())
plt.setp(ax2.get_xticklabels(), visible=False)
plt.setp(ax3.get_yticklabels(), visible=False)
plt.setp(ax1, xlabel="Dispersion axis [pix]", ylabel="Cross-dispersion axis [pix]")
ax2.set_ylabel("Residuals [pix]")
ax3.set_xlabel("Residuals [pix]")
if rlim is not None:
ax2.set_ylim(rlim)
ax3.set_xlim(rlim)
return fig