# Licensed under a 3-clause BSD style license - see ../../licenses/LICENSE.rst
import warnings
from dataclasses import dataclass, field
import numpy as np
from astropy import units as u
from astropy.modeling import models, Model
from astropy.nddata import CCDData
from astropy.stats import gaussian_fwhm_to_sigma
from astropy.utils.decorators import deprecated
from astropy.wcs import WCS
from specutils import Spectrum
from specreduce.calibration_data import load_pypeit_calibration_lines
# Names exported by ``from <module> import *``: the ``SynthImage`` builder and
# the three deprecated ``make_2d_*`` convenience wrappers defined in this file.
__all__ = [
    "SynthImage",
    "make_2d_trace_image",
    "make_2d_arc_image",
    "make_2d_spec_image",
]
# Model classes accepted as an arc layer's ``tilt_func``. ``ArcLayer.render``
# checks the supplied model against this tuple with ``isinstance`` and raises
# ``ValueError`` for anything else.
_ALLOWED_TILT = (
    models.Legendre1D,
    models.Chebyshev1D,
    models.Polynomial1D,
    models.Hermite1D,
)
@dataclass(frozen=True)
class _RenderContext:
    """
    Read-only bundle of image geometry handed to each layer's ``render`` call.

    ``SynthImage._render`` creates one instance per rendering pass and passes
    the same object to every layer, so all layers draw onto an identical grid.
    """

    # Number of pixels along X (image columns).
    nx: int
    # Number of pixels along Y (image rows).
    ny: int
    # Column-index grid from ``np.meshgrid(np.arange(nx), np.arange(ny))``.
    xx: np.ndarray
    # Row-index grid from the same ``np.meshgrid`` call.
    yy: np.ndarray
    # WCS resolved by ``SynthImage._resolve_wcs``; ``None`` when no layer needs
    # one and none was supplied.
    wcs: WCS | None
    # 0 when the first WCS axis is the spectral one, otherwise 1 (also 1 when
    # there is no WCS).
    disp_axis: int
@dataclass(frozen=True)
class BackgroundLayer:
    """Uniform additive pedestal: every pixel receives ``level`` counts."""

    # Background value; converted with ``float()`` when the layer is rendered.
    level: float

    def render(self, ctx: "_RenderContext") -> np.ndarray:
        """Return a float array of shape ``(ctx.ny, ctx.nx)`` filled with ``level``."""
        pedestal = np.empty((ctx.ny, ctx.nx), dtype=float)
        pedestal.fill(float(self.level))
        return pedestal
[docs]
class SynthImage:
    """
    Immutable, composable builder for synthetic 2D spectroscopic images.

    Build an image by chaining ``add_*`` methods, then render it with one of the
    ``to_*`` terminal methods. Each ``add_*`` returns a *new* builder of the
    same class as the one it was called on; the original is never mutated, so a
    base configuration can be safely branched.

    Parameters
    ----------
    nx
        Size of the image along the X (dispersion) axis.
    ny
        Size of the image along the Y (spatial) axis.
    wcs
        Optional 2D WCS with a single spectral axis. If not provided and a
        layer needs wavelengths (arc lines, or a source with a ``spectrum``),
        a linear ``WAVE``/``PIXEL`` WCS is built from ``extent`` and
        ``wave_unit``.
    extent
        Beginning and end wavelengths used to build the default WCS when
        ``wcs`` is not supplied and a layer needs one.
    wave_unit
        Wavelength unit for the default WCS.
    seed
        Seed for the random number generator used by the noise layers. If
        ``None``, noise is non-deterministic.

    Examples
    --------
    Build a traced continuum source with background and noise, then render it::

        from astropy.modeling import models
        from specreduce.utils.synth_data import SynthImage

        image = (
            SynthImage(nx=1024, ny=400, seed=42)
            .add_background(5)
            .add_source(profile=models.Moffat1D(amplitude=20, alpha=0.1))
            .add_poisson_noise()
            .add_read_noise(3)
            .to_ccddata()
        )

    See the :ref:`synth_data` guide for more examples, including tilted arc
    lines and modeling a non-linear dispersion relation.
    """

    def __init__(
        self,
        nx: int = 3000,
        ny: int = 1000,
        wcs: WCS | None = None,
        extent=(3500, 7000),
        wave_unit: u.Unit = u.Angstrom,
        seed: int | None = None,
    ):
        self.nx = nx
        self.ny = ny
        self._wcs = wcs
        self._extent = extent
        self._wave_unit = wave_unit
        self._seed = seed
        # Signal layers are kept in an immutable tuple and rendered in the
        # order they were added.
        self._layers = ()
        self._poisson = False
        self._read_noise = None

    def _clone(self, **changes) -> "SynthImage":
        """
        Return a shallow copy of this builder with ``changes`` applied.

        The copy is created without running ``__init__`` and has the same class
        as ``self``, so chaining ``add_*`` calls on a subclass keeps returning
        that subclass together with any extra instance attributes it defines.
        """
        new = object.__new__(type(self))
        new.__dict__.update(self.__dict__)
        for key, value in changes.items():
            setattr(new, key, value)
        return new

    def add_background(self, level: float) -> "SynthImage":
        """Add a constant background level (counts)."""
        return self._clone(_layers=self._layers + (BackgroundLayer(level),))

    def add_source(
        self,
        profile: Model | None = None,
        trace_center: float | None = None,
        trace_order: int = 3,
        trace_coeffs: dict | None = None,
        spectrum: Spectrum | None = None,
    ) -> "SynthImage":
        """
        Add a source with a Chebyshev-traced spatial profile.

        Parameters
        ----------
        profile
            Astropy model describing the cross-dispersion spatial profile of the
            source. Defaults to ``Moffat1D(amplitude=10, alpha=0.1)``.
        trace_center
            Central cross-dispersion position of the trace. Defaults to the
            middle of the image.
        trace_order
            Polynomial order of the Chebyshev trace.
        trace_coeffs
            Coefficients of the Chebyshev trace, e.g. ``{"c0": 0, "c1": 50}``.
        spectrum
            Optional 1D `~specutils.Spectrum` describing the dispersion-axis flux
            of the source. Its flux is resampled onto the image wavelength grid
            (requiring a resolvable WCS) and normalized so its peak within the
            image extent is one. Wavelengths outside the spectrum's range are set
            to zero. The normalized flux multiplies the spatial profile column by
            column, so the source amplitude varies with wavelength. When ``None``
            (default) the source has a flat continuum.

        Returns
        -------
        SynthImage
            A new builder with the source layer appended.
        """
        if profile is None:
            profile = models.Moffat1D(amplitude=10, alpha=0.1)
        layer = SourceLayer(profile, trace_center, trace_order, trace_coeffs, spectrum)
        return self._clone(_layers=self._layers + (layer,))

    def add_arcs(
        self,
        linelists=("HeI",),
        line_fwhm: float = 5.0,
        amplitude_scale: float = 1.0,
        wave_air: bool = False,
        tilt_func: Model | None = None,
    ) -> "SynthImage":
        """
        Add emission lines from one or more pypeit calibration line lists.

        Parameters
        ----------
        linelists
            A single line-list name or an iterable of names; a bare string is
            treated as a one-element list.
        line_fwhm
            Full width at half maximum of each rendered line, in pixels.
        amplitude_scale
            Factor multiplied into every line amplitude from the line list.
        wave_air
            Forwarded to ``load_pypeit_calibration_lines`` when the layer is
            rendered.
        tilt_func
            1D polynomial model giving the cross-dispersion tilt of the lines.
            Defaults to ``Legendre1D(degree=0)``.

        Returns
        -------
        SynthImage
            A new builder with the arc layer appended.
        """
        if tilt_func is None:
            tilt_func = models.Legendre1D(degree=0)
        if isinstance(linelists, str):
            linelists = (linelists,)
        layer = ArcLayer(
            tuple(linelists), line_fwhm, amplitude_scale, wave_air, tilt_func
        )
        return self._clone(_layers=self._layers + (layer,))

    def add_skylines(self, linelists="OH_GMOS", **kwargs) -> "SynthImage":
        """Add night-sky airglow emission lines (OH lists), wrapping ``add_arcs``."""
        return self.add_arcs(linelists, **kwargs)

    def add_poisson_noise(self) -> "SynthImage":
        """
        Apply Poisson noise to the rendered signal (requires photutils).

        photutils is imported only when the image is rendered, so calling this
        method does not itself need the package.
        """
        return self._clone(_poisson=True)

    def add_read_noise(self, sigma: float) -> "SynthImage":
        """Add Gaussian read noise of standard deviation ``sigma`` (counts)."""
        return self._clone(_read_noise=sigma)

    # Alternate spelling of ``add_read_noise``.
    add_rdnoise = add_read_noise

    def _resolve_wcs(self):
        """
        Work out the WCS and dispersion axis to render with.

        Returns
        -------
        wcs
            The user-supplied WCS if there is one; otherwise a default linear
            WCS when some layer needs wavelengths; otherwise ``None``.
        disp_axis
            ``0`` when the first WCS axis is spectral, else ``1`` (also ``1``
            when there is no WCS).

        Raises
        ------
        ValueError
            If a supplied WCS does not have exactly one spectral axis or is not
            two-dimensional, or if a default WCS is needed but ``extent`` or
            ``wave_unit`` is unusable.
        """
        # Arc layers and spectrum-modulated sources map between pixels and
        # wavelengths, so they cannot be rendered without a WCS.
        needs_wcs = any(
            isinstance(layer, ArcLayer)
            or (isinstance(layer, SourceLayer) and layer.spectrum is not None)
            for layer in self._layers
        )
        if self._wcs is not None:
            wcs = self._wcs
            if wcs.spectral.naxis != 1:
                raise ValueError("Provided WCS must have a spectral axis.")
            if wcs.naxis != 2:
                raise ValueError("WCS must have NAXIS=2 for a 2D image.")
        elif needs_wcs:
            if self._extent is None:
                raise ValueError("Must specify either a wavelength extent or a WCS.")
            if len(self._extent) != 2:
                raise ValueError("Wavelength extent must be of length 2.")
            if u.get_physical_type(self._wave_unit) != "length":
                raise ValueError("Wavelength unit must be a length unit.")
            # Linear dispersion along the first axis, plain pixels on the second.
            wcs = WCS(naxis=2)
            wcs.wcs.ctype[0] = "WAVE"
            wcs.wcs.ctype[1] = "PIXEL"
            wcs.wcs.cunit[0] = self._wave_unit
            wcs.wcs.cunit[1] = u.pixel
            wcs.wcs.crval[0] = self._extent[0]
            wcs.wcs.cdelt[0] = (self._extent[1] - self._extent[0]) / self.nx
            wcs.wcs.crval[1] = 0
            wcs.wcs.cdelt[1] = 1
        else:
            wcs = None
        if wcs is None:
            disp_axis = 1
        else:
            is_spectral = [a["coordinate_type"] == "spectral" for a in wcs.get_axis_types()]
            disp_axis = 0 if is_spectral[0] else 1
        return wcs, disp_axis

    def _render(self):
        """
        Render all layers and noise; return ``(signal, wcs)``.

        Layers are summed in the order they were added. Poisson noise, when
        requested, is applied to the summed signal before read noise is added.
        Both noise steps draw from one generator seeded with ``seed``.
        """
        wcs, disp_axis = self._resolve_wcs()
        x = np.arange(self.nx)
        y = np.arange(self.ny)
        xx, yy = np.meshgrid(x, y)
        ctx = _RenderContext(self.nx, self.ny, xx, yy, wcs, disp_axis)
        signal = np.zeros((self.ny, self.nx))
        for layer in self._layers:
            signal = signal + layer.render(ctx)
        rng = np.random.default_rng(self._seed)
        if self._poisson:
            # Imported lazily so photutils is only required when Poisson noise
            # is actually requested.
            from photutils.datasets import apply_poisson_noise

            signal = apply_poisson_noise(signal, seed=rng)
        if self._read_noise is not None:
            signal = signal + rng.normal(0.0, self._read_noise, size=signal.shape)
        return signal, wcs

    def to_array(self) -> np.ndarray:
        """Render and return the image as a plain ``numpy.ndarray`` (counts)."""
        return self._render()[0]

    def to_ccddata(self) -> CCDData:
        """Render and return the image as a `~astropy.nddata.CCDData`."""
        data, wcs = self._render()
        return CCDData(data, unit=u.count, wcs=wcs)

    def to_spectrum(self) -> Spectrum:
        """
        Render and return the image as a `~specutils.Spectrum`.

        The resolved WCS is attached when there is one; otherwise the last
        array axis is declared as the spectral axis.
        """
        data, wcs = self._render()
        if wcs is not None:
            return Spectrum(flux=data * u.count, wcs=wcs)
        return Spectrum(flux=data * u.count, spectral_axis_index=data.ndim - 1)
@dataclass(frozen=True)
class SourceLayer:
    """
    A source whose spatial profile follows a Chebyshev trace.

    The dispersion axis is the X (column) axis, matching the historical
    ``make_2d_trace_image`` behaviour (the trace ignores any WCS). When a
    ``spectrum`` is supplied, its normalized flux modulates the profile along
    the dispersion axis; otherwise the source has a flat continuum.
    """

    # Cross-dispersion profile model, evaluated on the distance from the trace.
    profile: Model
    # Row of the trace centre; ``None`` selects the middle of the image.
    trace_center: float | None = None
    # Degree passed to ``Chebyshev1D``.
    trace_order: int = 3
    # Chebyshev coefficients by name (``"c0"``, ``"c1"``, ...); ``None`` selects
    # a default curved trace.
    trace_coeffs: dict | None = None
    # Optional spectrum that scales the profile column by column.
    spectrum: Spectrum | None = None

    def render(self, ctx: _RenderContext) -> np.ndarray:
        """Evaluate the traced profile on the image grid and return it."""
        center = self.trace_center
        if center is None:
            center = ctx.ny / 2

        coeffs = self.trace_coeffs
        if coeffs is None:
            # Default curved trace. Coefficients above the requested order are
            # left out so that trace_order < 2 still builds a valid model.
            coeffs = {}
            for idx, value in enumerate((0, 50, 100)):
                if idx <= self.trace_order:
                    coeffs[f"c{idx}"] = value

        polynomial = models.Chebyshev1D(degree=self.trace_order, **coeffs)
        # The trace polynomial is evaluated on the column index scaled by nx.
        distance = ctx.yy - center + polynomial(ctx.xx / ctx.nx)
        rendered = self.profile(distance)
        if self.spectrum is None:
            return rendered
        # Broadcast the per-column flux across all rows.
        return rendered * self._normalized_flux(ctx)[np.newaxis, :]

    def _normalized_flux(self, ctx: _RenderContext) -> np.ndarray:
        """
        Resample the source spectrum onto the image's dispersion-axis wavelengths.

        Returns a 1D array of length ``ctx.nx``: the spectrum's flux interpolated
        onto the image wavelength grid, zero outside the spectrum's range, and
        normalized so its peak within the image extent is one. If that peak is
        not positive the resampled flux is returned without scaling.
        """
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message="No observer defined on WCS.*")
            pixel_index = np.arange(ctx.nx)
            image_waves = ctx.wcs.spectral.pixel_to_world(pixel_index)

        source_waves = self.spectrum.spectral_axis.to(
            image_waves.unit, equivalencies=u.spectral()
        )
        source_flux = np.asarray(self.spectrum.flux.value, dtype=float)

        # np.interp needs increasing sample points, so sort by wavelength first.
        ascending = np.argsort(source_waves.value)
        on_grid = np.interp(
            image_waves.value,
            source_waves.value[ascending],
            source_flux[ascending],
            left=0.0,
            right=0.0,
        )

        peak = on_grid.max()
        return on_grid / peak if peak > 0 else on_grid
@dataclass(frozen=True)
class ArcLayer:
    """
    Emission lines from one or more pypeit calibration line lists.

    Requires a resolvable WCS (supplied on ``SynthImage`` or built from its
    ``extent``). ``tilt_func`` applies a cross-dispersion tilt to simulate
    curved lines.
    """

    # Names of the line lists handed to ``load_pypeit_calibration_lines``.
    linelists: tuple
    # Line FWHM; converted to a Gaussian sigma when rendering.
    line_fwhm: float = 5.0
    # Factor applied to every amplitude read from the line list.
    amplitude_scale: float = 1.0
    # Forwarded unchanged to ``load_pypeit_calibration_lines``.
    wave_air: bool = False
    # Polynomial model describing the tilt; degree-0 Legendre by default.
    tilt_func: Model = field(default_factory=lambda: models.Legendre1D(degree=0))

    def render(self, ctx: _RenderContext) -> np.ndarray:
        """
        Return an image with one Gaussian per line in the loaded line list.

        Raises
        ------
        ValueError
            If ``tilt_func`` is set but is not one of the supported 1D
            polynomial model classes.
        """
        tilt = self.tilt_func
        grid_x, grid_y = ctx.xx, ctx.yy
        along_first_axis = ctx.disp_axis == 0

        if tilt is not None:
            if not isinstance(tilt, _ALLOWED_TILT):
                raise ValueError(
                    "The only tilt functions currently supported are 1D polynomials "
                    "from astropy.models."
                )
            # The tilt is evaluated on the centred, size-normalized coordinate
            # of one axis and added to the coordinate grid of the other.
            if along_first_axis:
                grid_x = grid_x + tilt((grid_y - ctx.ny / 2) / ctx.ny)
            else:
                grid_y = grid_y + tilt((grid_x - ctx.nx / 2) / ctx.nx)

        image = np.zeros((ctx.ny, ctx.nx))
        lines = load_pypeit_calibration_lines(list(self.linelists), wave_air=self.wave_air)
        if lines is None:
            return image

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message="No observer defined on WCS.*")
            pixel_positions = ctx.wcs.spectral.world_to_pixel(lines["wavelength"])

        sigma = gaussian_fwhm_to_sigma * self.line_fwhm
        # The coordinate grid the Gaussians are evaluated on is the same for
        # every line, so pick it once.
        disp_coords = grid_x if along_first_axis else grid_y
        for position, amplitude in zip(pixel_positions, lines["amplitude"]):
            gaussian = models.Gaussian1D(
                amplitude=amplitude * self.amplitude_scale,
                mean=position,
                stddev=sigma,
            )
            image += gaussian(disp_coords)
        return image
[docs]
@deprecated("1.10", alternative="SynthImage")
def make_2d_trace_image(
    nx: int = 3000,
    ny: int = 1000,
    background: int | float = 5,
    trace_center: int | float | None = None,
    trace_order: int = 3,
    trace_coeffs: dict | None = None,
    profile: Model = None,
    add_noise: bool = True,
) -> CCDData:
    """
    Deprecated. Use :class:`SynthImage` instead.

    Equivalent to ``SynthImage(nx, ny).add_background(background)
    .add_source(...)`` followed by ``.add_poisson_noise()`` when ``add_noise``,
    then ``.to_ccddata()``.

    Parameters
    ----------
    nx, ny
        Image size along X and Y.
    background
        Constant background level in counts.
    trace_center, trace_order, trace_coeffs
        Forwarded to ``SynthImage.add_source``.
    profile
        Cross-dispersion profile model. Defaults to
        ``Moffat1D(amplitude=10, alpha=0.1)``.
    add_noise
        If true, Poisson noise is applied. No seed is set, so the noise is not
        reproducible.

    Returns
    -------
    `~astropy.nddata.CCDData`
        The rendered image in counts.
    """
    source_profile = (
        models.Moffat1D(amplitude=10, alpha=0.1) if profile is None else profile
    )
    builder = SynthImage(nx=nx, ny=ny).add_background(background)
    builder = builder.add_source(
        profile=source_profile,
        trace_center=trace_center,
        trace_order=trace_order,
        trace_coeffs=trace_coeffs,
    )
    if add_noise:
        builder = builder.add_poisson_noise()
    return builder.to_ccddata()
[docs]
@deprecated("1.10", alternative="SynthImage")
def make_2d_arc_image(
    nx: int = 3000,
    ny: int = 1000,
    wcs: WCS | None = None,
    extent=(3500, 7000),
    wave_unit: u.Unit = u.Angstrom,
    wave_air: bool = False,
    background: int | float = 5,
    line_fwhm: float = 5.0,
    linelists=("HeI",),
    amplitude_scale: float = 1.0,
    tilt_func: Model = None,
    add_noise: bool = True,
) -> CCDData:
    """
    Deprecated. Use :class:`SynthImage` with ``.add_arcs(...)`` instead.

    Builds ``SynthImage(nx, ny, wcs, extent, wave_unit)``, adds a constant
    background and one arc layer, optionally applies Poisson noise, and
    renders the result with ``.to_ccddata()``.

    Parameters
    ----------
    nx, ny
        Image size along X and Y.
    wcs, extent, wave_unit
        Forwarded to the :class:`SynthImage` constructor.
    wave_air, line_fwhm, linelists, amplitude_scale
        Forwarded to ``SynthImage.add_arcs``.
    background
        Constant background level in counts.
    tilt_func
        Tilt model for the lines. Defaults to ``Legendre1D(degree=0)``.
    add_noise
        If true, Poisson noise is applied. No seed is set, so the noise is not
        reproducible.

    Returns
    -------
    `~astropy.nddata.CCDData`
        The rendered image in counts.
    """
    arc_tilt = models.Legendre1D(degree=0) if tilt_func is None else tilt_func
    builder = SynthImage(nx=nx, ny=ny, wcs=wcs, extent=extent, wave_unit=wave_unit)
    builder = builder.add_background(background)
    builder = builder.add_arcs(
        linelists=linelists,
        line_fwhm=line_fwhm,
        amplitude_scale=amplitude_scale,
        wave_air=wave_air,
        tilt_func=arc_tilt,
    )
    if add_noise:
        builder = builder.add_poisson_noise()
    return builder.to_ccddata()
[docs]
@deprecated("1.10", alternative="SynthImage")
def make_2d_spec_image(
    nx: int = 3000,
    ny: int = 1000,
    wcs: WCS | None = None,
    extent=(6500, 9500),
    wave_unit: u.Unit = u.Angstrom,
    wave_air: bool = False,
    background: int | float = 5,
    line_fwhm: float = 5.0,
    linelists=("OH_GMOS",),
    amplitude_scale: float = 1.0,
    tilt_func: Model = None,
    trace_center: int | float | None = None,
    trace_order: int = 3,
    trace_coeffs: dict | None = None,
    source_profile: Model = None,
    add_noise: bool = True,
) -> CCDData:
    """
    Deprecated. Use :class:`SynthImage` with ``.add_arcs(...)`` and
    ``.add_source(...)`` instead.

    Builds ``SynthImage(nx, ny, wcs, extent, wave_unit)``, then adds, in this
    order, a constant background, one arc layer and one traced source.
    Poisson noise is applied when ``add_noise`` is true, and the result is
    rendered with ``.to_ccddata()``.

    Parameters
    ----------
    nx, ny
        Image size along X and Y.
    wcs, extent, wave_unit
        Forwarded to the :class:`SynthImage` constructor.
    wave_air, line_fwhm, linelists, amplitude_scale
        Forwarded to ``SynthImage.add_arcs``.
    background
        Constant background level in counts.
    tilt_func
        Tilt model for the lines. Defaults to ``Legendre1D(degree=0)``.
    trace_center, trace_order, trace_coeffs
        Forwarded to ``SynthImage.add_source``.
    source_profile
        Cross-dispersion profile model of the source. Defaults to
        ``Moffat1D(amplitude=10, alpha=0.1)``.
    add_noise
        If true, Poisson noise is applied. No seed is set, so the noise is not
        reproducible.

    Returns
    -------
    `~astropy.nddata.CCDData`
        The rendered image in counts.
    """
    arc_tilt = models.Legendre1D(degree=0) if tilt_func is None else tilt_func
    profile = (
        models.Moffat1D(amplitude=10, alpha=0.1)
        if source_profile is None
        else source_profile
    )
    builder = SynthImage(nx=nx, ny=ny, wcs=wcs, extent=extent, wave_unit=wave_unit)
    builder = builder.add_background(background)
    builder = builder.add_arcs(
        linelists=linelists,
        line_fwhm=line_fwhm,
        amplitude_scale=amplitude_scale,
        wave_air=wave_air,
        tilt_func=arc_tilt,
    )
    builder = builder.add_source(
        profile=profile,
        trace_center=trace_center,
        trace_order=trace_order,
        trace_coeffs=trace_coeffs,
    )
    if add_noise:
        builder = builder.add_poisson_noise()
    return builder.to_ccddata()