Source code for skyscapes.scene.system

"""scene.System -- the top-level PyTree container.

Holds a star, a heterogeneous tuple of planets, an optional disk, and the
Kepler-solver callable (static so JIT doesn't re-trace). The tuple makes
the system variadic: a 1-planet system and an 8-planet system have
compatible shapes as long as per-planet arrays broadcast correctly.
"""

from __future__ import annotations

from collections.abc import Callable

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array

from .._repr import indent
from ..disk import AbstractDisk
from .planet import Planet
from .star import AbstractStar


[docs] class System(eqx.Module): """Astrophysical scene: star + tuple of planets + optional disk. Attributes: star: Host star (``AbstractStar``). planets: Variable-length tuple of ``Planet``. trig_solver: Scalar Kepler-trig solver (static; see ``orbix.kepler.shortcuts.grid.get_grid_solver``). Required -- callers must provide a built solver, not None. disk: Optional extended-source disk (``AbstractDisk | None``). midplane_inc_deg: System midplane inclination [deg] in the barycentric -> sky frame. Default 0.0 means "midplane = sky" and is intentionally indistinguishable from a real face-on system at this inclination; this ambiguity is acceptable because the field is diagnostic-only after load (no runtime hot path consults it). Populated by ``io.from_exovista`` from the FITS star header. After load, frame rotation has already been baked into each ``Planet``'s orbital elements. midplane_pa_deg: System midplane position angle [deg]. Same semantics as ``midplane_inc_deg``. """ star: AbstractStar planets: tuple[Planet, ...] trig_solver: Callable = eqx.field(static=True) disk: AbstractDisk | None = None midplane_inc_deg: float = 0.0 midplane_pa_deg: float = 0.0 @property def n_planets(self) -> int: """Total number of planets across all composed ``Planet`` modules.""" return sum(p.n_planets for p in self.planets)
[docs] def positions(self, t_jd: Array) -> Array: """Concatenated on-sky positions, shape ``(2, K_total, T)``.""" per_planet = [ p.position_arcsec(self.trig_solver, t_jd, star=self.star) for p in self.planets ] return jnp.concatenate(per_planet, axis=1)
[docs] def contrasts(self, wavelength_nm: Array, t_jd: Array) -> Array: """Per-planet contrast, shape ``(K_total, T)``.""" per_planet = [ p.contrast(self.trig_solver, wavelength_nm, t_jd, star=self.star) for p in self.planets ] return jnp.concatenate(per_planet, axis=0)
[docs] def planet_flux_densities(self, wavelength_nm: Array, t_jd: Array) -> Array: """Per-planet flux density [ph/s/m^2/nm], shape ``(K_total, T)``.""" per_planet = [ p.spec_flux_density(self.trig_solver, wavelength_nm, t_jd, star=self.star) for p in self.planets ] return jnp.concatenate(per_planet, axis=0)
[docs] def alpha_dMag(self, t_jd: Array) -> tuple[Array, Array]: """Per-planet projected separation + dMag, each shape ``(K_total, T)``.""" per_planet = [ p.alpha_dMag(self.trig_solver, t_jd, star=self.star) for p in self.planets ] alpha = jnp.concatenate([a for a, _ in per_planet], axis=0) dMag = jnp.concatenate([m for _, m in per_planet], axis=0) return alpha, dMag
[docs] def __repr__(self) -> str: """Tree-shaped summary: star + planets + disk + midplane geometry.""" n_modules = len(self.planets) n_total = self.n_planets lines = [ ( f"System(n_planet_modules={n_modules}, " f"K_total={n_total}, " f"midplane: i={self.midplane_inc_deg:.2f} deg, " f"PA={self.midplane_pa_deg:.2f} deg)" ), indent("star: " + repr(self.star)), ] for i, p in enumerate(self.planets): lines.append(indent(f"planet[{i}]:")) lines.append(indent(repr(p), prefix=" ")) if self.disk is not None: lines.append(indent("disk: " + repr(self.disk))) else: lines.append(" disk: None") return "\n".join(lines)