Source code for skyscapes.scene.star
"""Star models for the scene hierarchy.
``AbstractStar`` declares a ``Ms_kg`` / ``dist_pc`` pair and the
``spec_flux_density`` hook. ``FlatStar`` is a flat-spectrum stand-in
useful for ETC runs. ``Star`` wraps an ``interpax.Interpolator2D``
over (wavelength, time) built from Jansky flux data, matching the legacy
``skyscapes._legacy.Star`` semantics.
Note: ``from __future__ import annotations`` is deliberately NOT used
here -- it stringifies annotations, which breaks Equinox's metaclass
handling of ``AbstractVar`` type parameters.
"""
from abc import abstractmethod
import equinox as eqx
import interpax
import jax
import jax.numpy as jnp
from equinox import AbstractVar
from hwoutils.conversions import jy_to_photons_per_nm_per_m2
from jaxtyping import Array
[docs]
class AbstractStar(eqx.Module):
"""Abstract stellar source.
Attributes:
Ms_kg: Stellar mass in kilograms.
dist_pc: Distance to the star in parsecs.
"""
Ms_kg: AbstractVar[float]
dist_pc: AbstractVar[float]
[docs]
@abstractmethod
def spec_flux_density(
self,
wavelength_nm: Array,
time_jd: Array,
) -> Array:
"""Return spectral flux density in ph/s/m^2/nm."""
[docs]
class FlatStar(AbstractStar):
"""Flat-spectrum star -- constant flux independent of wavelength or time."""
Ms_kg: float
dist_pc: float
flux_phot_per_nm_m2: float
[docs]
def spec_flux_density(
self,
wavelength_nm: Array,
time_jd: Array,
) -> Array:
"""Constant flux, broadcast to wavelength_nm's shape.
``time_jd`` is part of the AbstractStar interface but ignored here.
"""
wl = jnp.asarray(wavelength_nm)
return jnp.full_like(wl, self.flux_phot_per_nm_m2, dtype=wl.dtype)
[docs]
def __repr__(self) -> str:
"""Compact one-line summary of mass, distance, and flux."""
m_solar = self.Ms_kg / 1.988409870698051e30
return (
f"FlatStar(M={m_solar:.3f} Msun, dist={self.dist_pc:.3g} pc, "
f"flux={self.flux_phot_per_nm_m2:.3g} ph/s/m^2/nm)"
)
[docs]
class Star(AbstractStar):
"""Time- and wavelength-dependent star backed by an interpax 2D spline."""
Ms_kg: float
dist_pc: float
ra_deg: float
dec_deg: float
diameter_arcsec: float
luminosity_lsun: float
_wavelengths_nm: Array
_times_jd: Array
_flux_density_phot: Array
_flux_interp: interpax.Interpolator2D
def __init__(
self,
*,
Ms_kg: float,
dist_pc: float,
wavelengths_nm: Array,
times_jd: Array,
flux_density_jy: Array,
ra_deg: float = 0.0,
dec_deg: float = 0.0,
diameter_arcsec: float = 0.0,
luminosity_lsun: float = 1.0,
):
"""Store stellar metadata and pre-build the flux-density interpolant."""
self.Ms_kg = Ms_kg
self.dist_pc = dist_pc
self.ra_deg = ra_deg
self.dec_deg = dec_deg
self.diameter_arcsec = diameter_arcsec
self.luminosity_lsun = luminosity_lsun
self._wavelengths_nm = wavelengths_nm
self._times_jd = times_jd
self._flux_density_phot = jax.vmap(
jy_to_photons_per_nm_per_m2, in_axes=(1, None), out_axes=1
)(flux_density_jy, wavelengths_nm)
self._flux_interp = interpax.Interpolator2D(
wavelengths_nm, times_jd, self._flux_density_phot, method="cubic"
)
[docs]
def spec_flux_density(
self,
wavelength_nm: Array,
time_jd: Array,
) -> Array:
"""Scalar or array spectral flux density [ph/s/m^2/nm]."""
return self._flux_interp(wavelength_nm, time_jd)
[docs]
def __repr__(self) -> str:
"""One-line summary of metadata + wavelength/time grid extent."""
m_solar = self.Ms_kg / 1.988409870698051e30
wl_min = float(jnp.min(self._wavelengths_nm))
wl_max = float(jnp.max(self._wavelengths_nm))
n_t = int(self._times_jd.shape[0])
return (
f"Star(M={m_solar:.3f} Msun, "
f"dist={self.dist_pc:.3g} pc, "
f"L={self.luminosity_lsun:.3g} Lsun, "
f"diam={self.diameter_arcsec:.3g} arcsec, "
f"RA/Dec=({self.ra_deg:.4f}, {self.dec_deg:.4f}) deg, "
f"wl={wl_min:.0f}-{wl_max:.0f} nm, n_times={n_t})"
)