import matplotlib.pyplot as plt
import numpy as np
import astropy.units as u
from astropy.modeling import models
from specutils import Spectrum
from specreduce.utils.synth_data import SynthImage

# Input spectrum: a linearly rising continuum plus two Gaussian emission lines.
# The wavelength bounds are named once so the continuum slope stays tied to the
# sampled range if that range is ever changed.
wave_min, wave_max = 4000, 7000  # Angstrom
wave = np.linspace(wave_min, wave_max, 1000) * u.Angstrom
wave_aa = wave.value  # plain ndarray in Angstrom, for the analytic expressions

# Continuum rises linearly from 1.0 at wave_min to 1.3 at wave_max.
continuum = 1.0 + 0.3 * (wave_aa - wave_min) / (wave_max - wave_min)

# Emission lines as (amplitude, centre [Angstrom], sigma [Angstrom]); each one is
# a Gaussian added on top of the continuum.
emission_lines = [(2.0, 5000, 20), (1.5, 6300, 30)]
emission = sum(
    amp * np.exp(-0.5 * ((wave_aa - centre) / sigma) ** 2)
    for amp, centre, sigma in emission_lines
)

spectrum = Spectrum(flux=(continuum + emission) * u.count, spectral_axis=wave)

# Build the synthetic 2D frame one step at a time. The extent (4000, 7000)
# matches the wavelength range sampled by `wave`; the seed is fixed, presumably
# so the added noise is reproducible between runs.
synth = SynthImage(nx=1024, ny=300, extent=(4000, 7000), seed=42)
synth = synth.add_background(5)
# Moffat1D profile used to spread the source; presumably the cross-dispersion
# profile (the plot below labels the image's y axis as cross-dispersion).
moffat_profile = models.Moffat1D(amplitude=50, alpha=0.1)
synth = synth.add_source(profile=moffat_profile, spectrum=spectrum)
synth = synth.add_poisson_noise()
image = synth.to_ccddata()

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 5))

# Top panel: the 1D input spectrum. Plot bare values taken from the Spectrum
# itself and put the units in the axis labels, so matplotlib does not have to
# interpret astropy Quantities and the labels always match the data's units.
spectral_axis = spectrum.spectral_axis
ax1.plot(spectral_axis.value, spectrum.flux.value)
ax1.set_xlabel(f"Wavelength ({spectral_axis.unit})")
ax1.set_ylabel(f"Input flux ({spectrum.flux.unit})")

# Bottom panel: the synthetic 2D frame. Pass the underlying ndarray explicitly
# rather than relying on CCDData's implicit array conversion.
ax2.imshow(image.data, origin="lower", aspect="auto")
ax2.set_xlabel("Dispersion axis (pix)")
ax2.set_ylabel("Cross-disp. axis (pix)")

fig.tight_layout()