Physical models#
skyscapes.physical_model provides the planet-to-star contrast
model attached to every Planet. Given a phase angle, star-planet
distance, wavelength, and planet radius, a physical model returns the
dimensionless contrast \(f_p / f_*\) that scales the host star's
spectrum into the planet's flux. The submodule covers
atmospheric reflection, surface reflection, thermal emission, and
joint planet + atmosphere models.
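The scaling described above is an element-wise product. Here is a minimal NumPy sketch. It assumes the contrast has been evaluated on a wavelength grid and stacked into a (K, T, n_wl) array. That layout and the helper name planet_spectrum are illustrative and are not part of skyscapes.

```python
import numpy as np


def planet_spectrum(contrast, star_flux):
    """Scale a stellar spectrum by a contrast cube: f_p = (f_p / f_*) * f_*.

    contrast:  (K, T, n_wl) dimensionless flux ratio (assumed layout)
    star_flux: (n_wl,) stellar flux density in any unit
    returns:   (K, T, n_wl) planet flux density in the same unit as star_flux
    """
    contrast = np.asarray(contrast, dtype=float)
    star_flux = np.asarray(star_flux, dtype=float)
    # Broadcasting aligns the trailing wavelength axis and repeats the
    # stellar spectrum across the planet and time axes.
    return contrast * star_flux


# Two planets, three epochs, four wavelength bins; values are illustrative only.
contrast = np.full((2, 3, 4), 1e-10)
star_flux = np.array([1.0, 1.2, 1.1, 0.9])
planet_flux = planet_spectrum(contrast, star_flux)
```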
The hierarchy#
flowchart TB
APM[AbstractPhysicalModel]
LPM[LambertianPhysicalModel]
GPM[GridPhysicalModel]
EXO[ExoJaxPhysicalModel]
PPM[PrecomputedPhysicalModel]
APM --> LPM
APM --> GPM
APM --> EXO
APM --> PPM
All concrete classes implement a common contrast method that
returns the per-planet flux-ratio contrast at the given phase angle,
distance, wavelength, and planet radius (shape (K, T) for K
planets and T time steps). The differences are in physics fidelity
and computational cost.
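The following hypothetical sketch shows what an interface with this (K, T) shape convention could look like. ContrastModel and ConstantContrast are invented names for illustration. The real AbstractPhysicalModel may differ in its signature and in how it stores per-planet state.

```python
from abc import ABC, abstractmethod

import numpy as np


class ContrastModel(ABC):
    """Hypothetical stand-in for the shared contract (not the library class)."""

    @abstractmethod
    def contrast(self, phase_angle_rad, dist_AU, wavelength_nm, Rp_Rearth):
        """Return the flux ratio f_p / f_* with shape (K, T)."""


class ConstantContrast(ContrastModel):
    """Toy model: one fixed contrast per planet, ignoring geometry and wavelength."""

    def __init__(self, c):
        self.c = np.asarray(c, dtype=float)  # (K,) per-planet state

    def contrast(self, phase_angle_rad, dist_AU, wavelength_nm, Rp_Rearth):
        phase = np.asarray(phase_angle_rad, dtype=float)  # (K, T)
        # Broadcast the (K,) state across the T time steps.
        return np.broadcast_to(self.c[:, None], phase.shape)


model = ConstantContrast([1e-10, 2e-10])
out = model.contrast(np.zeros((2, 5)), np.ones((2, 5)), 550.0, np.ones(2))
```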
Each physical model carries per-planet state as JAX arrays of length
K. Planet radius and mass are passed in at call time from the host
Planet; the model does not store them. Heterogeneous planets
(different physical-model classes) compose into a System.planets
tuple, with one Planet per physical-model type.
Variants and when to use them#
LambertianPhysicalModel#
Lambert phase function with constant geometric albedo. No wavelength dependence; analytic and fast.
import jax.numpy as jnp
from skyscapes.physical_model import LambertianPhysicalModel
model = LambertianPhysicalModel(Ag=jnp.array([0.3]))
Use for: first-pass forward modeling, sandbox / debugging, closed-form sanity checks. Most-used variant in early-stage work.
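The Lambertian model reduces to a closed form, \(f_p / f_* = A_g\,\Phi(\alpha)\,(R_p / r)^2\), with the Lambert phase function \(\Phi(\alpha) = [\sin\alpha + (\pi - \alpha)\cos\alpha] / \pi\). The plain-NumPy sketch below implements that formula. It assumes a mean Earth radius of 6371 km for the unit conversion. The function names and constants are illustrative and are not LambertianPhysicalModel internals.

```python
import numpy as np

R_EARTH_KM = 6371.0      # assumed mean Earth radius
AU_KM = 1.495978707e8    # astronomical unit in km


def lambert_phase(alpha):
    """Lambert phase function: 1 at full phase (alpha = 0), 0 at alpha = pi."""
    alpha = np.asarray(alpha, dtype=float)
    return (np.sin(alpha) + (np.pi - alpha) * np.cos(alpha)) / np.pi


def lambertian_contrast(Ag, phase_angle_rad, dist_AU, Rp_Rearth):
    """f_p / f_* = Ag * Phi(alpha) * (Rp / r)^2 with output shape (K, T).

    Ag, Rp_Rearth: (K,) per-planet arrays; phase_angle_rad, dist_AU: (K, T).
    """
    Rp_AU = np.asarray(Rp_Rearth, dtype=float)[:, None] * R_EARTH_KM / AU_KM  # (K, 1)
    Ag = np.asarray(Ag, dtype=float)[:, None]                                  # (K, 1)
    return Ag * lambert_phase(phase_angle_rad) * (Rp_AU / dist_AU) ** 2


# Earth analogue at 1 AU, sampled at full phase and at quadrature.
c = lambertian_contrast(
    Ag=np.array([0.3]),
    phase_angle_rad=np.array([[0.0, np.pi / 2]]),
    dist_AU=np.array([[1.0, 1.0]]),
    Rp_Rearth=np.array([1.0]),
)
```

For an Earth analogue at 1 AU with \(A_g = 0.3\), this gives about 5.4e-10 at full phase. At quadrature the contrast falls by a factor of π.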
GridPhysicalModel#
Wraps an externally computed contrast cube as a 2-D interpax spline
over (wavelength, phase_angle) per planet. Differentiable through
the interpolation. This is the model from_exovista populates.
import jax.numpy as jnp
from skyscapes.physical_model import GridPhysicalModel
model = GridPhysicalModel(
    wavelengths_nm=jnp.linspace(400, 1000, 60),  # (n_wl,)
    phase_angle_deg=jnp.linspace(0, 180, 19),    # (n_phase,)
    contrast_grid=...,                           # (K, n_wl, n_phase)
)
Use for: tabulated atmospheres from grid-based RT codes;
ExoVista-loaded systems; comparison studies; when you have a
(wl, phase) -> contrast LUT and want it integrated with the rest of
the pipeline.
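The (wavelength, phase_angle) lookup can be illustrated without interpax. This sketch replaces the 2-D spline with bilinear interpolation built from two passes of np.interp. It is a simplified stand-in, not what GridPhysicalModel does. The names grid_contrast and the toy grid are hypothetical.

```python
import numpy as np


def grid_contrast(wl_grid, phase_grid, contrast_grid, wl, phase):
    """Look up contrast_grid (K, n_wl, n_phase) at one (wl, phase) point per planet.

    Bilinear interpolation stands in for the spline; returns shape (K,).
    Both grid axes must be increasing, as np.interp requires.
    """
    K = contrast_grid.shape[0]
    out = np.empty(K)
    for k in range(K):
        # Interpolate every wavelength row along phase first,
        along_phase = np.array([np.interp(phase, phase_grid, row) for row in contrast_grid[k]])
        # then interpolate the resulting (n_wl,) column along wavelength.
        out[k] = np.interp(wl, wl_grid, along_phase)
    return out


wl_grid = np.linspace(400.0, 1000.0, 60)    # (n_wl,)
phase_grid = np.linspace(0.0, 180.0, 19)    # (n_phase,) in degrees
WL, PH = np.meshgrid(wl_grid, phase_grid, indexing="ij")
toy = 1e-10 * (1 + WL / 1000.0) * (1 - PH / 180.0)  # illustrative (n_wl, n_phase) grid
contrast_grid = np.stack([toy, 2 * toy])              # (K=2, n_wl, n_phase)
c = grid_contrast(wl_grid, phase_grid, contrast_grid, wl=555.0, phase=37.0)
```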
ExoJaxPhysicalModel#
Full ExoJax 2-stream radiative transfer. Computes reflectivity from
first principles given molecular mass mixing ratios (H2O, CO2, CH4,
O2, O3, …). The full constructor exposes a lot of knobs; the
ergonomic entry point is ExoJaxPhysicalModel.from_default_setup:
import jax.numpy as jnp
from skyscapes.physical_model import ExoJaxPhysicalModel
model = ExoJaxPhysicalModel.from_default_setup(
    log_mmrs={"H2O": jnp.array([-4.0]), "CO2": jnp.array([-3.5])},
    T_eq_K=jnp.array([288.0]),
    T_alpha=jnp.array([0.0]),
    log_surface_albedo=jnp.array([-1.0]),
    log_gravity_cgs=jnp.array([2.99]),
    wavelength_min_nm=400.0,
    wavelength_max_nm=1000.0,
)
Use for: research-grade physics; retrievals on real atmospheric composition; joint atmosphere + orbit fitting workflows.
The first call triggers a HITRAN download and a 1-3 minute setup
(see PrecomputedPhysicalModel below to avoid this cost in inner
loops).
PrecomputedPhysicalModel#
Wraps an ExoJax model’s full per-planet reflectivity grid as a fast
lookup. Once you’ve built the underlying ExoJaxPhysicalModel,
freeze it for inner-loop reuse:
import jax.numpy as jnp
from skyscapes.physical_model import ExoJaxPhysicalModel, PrecomputedPhysicalModel
# One-shot RT computation
ej = ExoJaxPhysicalModel.from_default_setup(...)
cached = PrecomputedPhysicalModel.from_physical_model(ej)
# Use as a drop-in physical model
contrast = cached.contrast(
    phase_angle_rad=jnp.array([[jnp.pi / 3]]),  # (K, T)
    dist_AU=jnp.array([[1.0]]),                 # (K, T)
    wavelength_nm=550.0,
    Rp_Rearth=jnp.array([1.0]),                 # (K,)
)
from_physical_model currently supports ExoJaxPhysicalModel (the
classmethod calls the model’s internal _reflectivity_all_planets).
The cache can also be persisted to disk via .save(path) and reloaded
via PrecomputedPhysicalModel.load(path).
Use for: any time you need the same physical model in a hot inner loop (forward simulation, MCMC sampling, multi-target evaluation).
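The precompute-then-reuse pattern behind PrecomputedPhysicalModel can be sketched as follows. Everything here is hypothetical. expensive_reflectivity stands in for the RT call and is just a scaled Lambert curve. Tabulating over phase alone is a simplification and does not reflect how from_physical_model lays out its cache.

```python
import numpy as np


def expensive_reflectivity(phase_rad):
    """Hypothetical stand-in for a costly RT call (here a scaled Lambert curve)."""
    phase_rad = np.asarray(phase_rad, dtype=float)
    return 0.3 * (np.sin(phase_rad) + (np.pi - phase_rad) * np.cos(phase_rad)) / np.pi


class PhaseLookup:
    """Evaluate a function once on a fixed phase grid, then answer queries by 1-D interpolation."""

    def __init__(self, fn, n=181):
        self.grid = np.linspace(0.0, np.pi, n)
        self.values = fn(self.grid)  # the one-time, expensive part

    def __call__(self, phase_rad):
        # Cheap per-call path: linear interpolation into the cached table.
        return np.interp(phase_rad, self.grid, self.values)


lut = PhaseLookup(expensive_reflectivity)
vals = lut(np.array([[np.pi / 3]]))  # (K, T) query, same shape out
```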
Composition with Planet#
A Planet carries its intrinsic params (Rp_Rearth, Mp_Mearth)
plus an orbit and a physical model:
import jax.numpy as jnp
from orbix.system.orbit import KeplerianOrbit
from skyscapes.scene import Planet
from skyscapes.physical_model import LambertianPhysicalModel
orbit = KeplerianOrbit(
    a_AU=jnp.array([1.0]),
    e=jnp.array([0.0]),
    W_rad=jnp.array([0.0]),
    i_rad=jnp.array([jnp.pi / 3]),
    w_rad=jnp.array([0.0]),
    M0_rad=jnp.array([0.0]),
    t0_d=jnp.array([0.0]),
)
physical_model = LambertianPhysicalModel(Ag=jnp.array([0.3]))
planet = Planet(
    Rp_Rearth=jnp.array([1.0]),
    Mp_Mearth=jnp.array([1.0]),
    orbit=orbit,
    physical_model=physical_model,
)
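The orbit turns epochs into the phase_angle_rad and dist_AU inputs that a physical model consumes. Below is a minimal sketch for the circular case (e = 0). It assumes the z axis points away from the observer, so that \(\cos\alpha = \sin i\,\sin(\omega + \nu)\). P_d is a hypothetical period argument. The KeplerianOrbit above does not take one; physically the period follows from a_AU and the stellar mass via Kepler's third law. None of this is orbix's implementation.

```python
import numpy as np


def circular_orbit_geometry(a_AU, i_rad, w_rad, M0_rad, t0_d, P_d, t_d):
    """Phase angle (rad) and star-planet distance (AU), each (K, T), for e = 0.

    Per-planet inputs are (K,) arrays; t_d is a (T,) array of epochs in days.
    Assumed convention: z points away from the observer, so
    cos(alpha) = sin(i) * sin(w + nu). The node W does not enter the phase angle.
    """
    t = np.asarray(t_d, dtype=float)[None, :]                                # (1, T)
    M = M0_rad[:, None] + 2.0 * np.pi * (t - t0_d[:, None]) / P_d[:, None]  # (K, T)
    nu = M  # on a circular orbit the true anomaly equals the mean anomaly
    cos_alpha = np.sin(i_rad)[:, None] * np.sin(w_rad[:, None] + nu)
    alpha = np.arccos(np.clip(cos_alpha, -1.0, 1.0))  # clip guards against round-off
    dist = np.broadcast_to(a_AU[:, None], alpha.shape)  # r = a when e = 0
    return alpha, dist


# Planet 0 is edge-on, planet 1 face-on; epochs at t = 0 and a quarter period later.
alpha, dist = circular_orbit_geometry(
    a_AU=np.array([1.0, 1.0]),
    i_rad=np.array([np.pi / 2, 0.0]),
    w_rad=np.array([0.0, 0.0]),
    M0_rad=np.array([np.pi / 2, 0.0]),
    t0_d=np.array([0.0, 0.0]),
    P_d=np.array([365.25, 365.25]),
    t_d=np.array([0.0, 365.25 / 4]),
)
```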
Heterogeneous physical models compose at the System.planets level,
with one Planet per model type, each batching its own K planets:
# System, star, lambertian_orbit and cached_orbit are defined elsewhere (not shown)
lambertian_planet = Planet(
    Rp_Rearth=..., Mp_Mearth=...,
    orbit=lambertian_orbit,
    physical_model=LambertianPhysicalModel(...),
)
cached_planet = Planet(
    Rp_Rearth=..., Mp_Mearth=...,
    orbit=cached_orbit,
    physical_model=PrecomputedPhysicalModel.from_physical_model(...),
)
system = System(star=star, planets=(lambertian_planet, cached_planet))
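Because each Planet batches a single model class, groups only need to be combined at the level of their outputs. The hypothetical sketch below shows that final step: it concatenates per-group (K_i, T) contrasts along the planet axis. stack_planet_groups is an invented helper and is not System's actual code.

```python
import numpy as np


def stack_planet_groups(group_contrasts):
    """Concatenate per-group contrasts of shape (K_i, T) into one (sum K_i, T) array."""
    arrays = [np.asarray(c, dtype=float) for c in group_contrasts]
    if any(a.ndim != 2 for a in arrays):
        raise ValueError("each group must return a (K_i, T) array")
    T = arrays[0].shape[1]
    if any(a.shape[1] != T for a in arrays):
        raise ValueError("all groups must share the same time axis T")
    return np.concatenate(arrays, axis=0)


lambertian_group = np.full((2, 4), 1e-10)  # K=2 planets from one model class
cached_group = np.full((3, 4), 5e-11)      # K=3 planets from another class
all_contrasts = stack_planet_groups([lambertian_group, cached_group])
```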
Performance budget#
Approximate cost per planet per wavelength on GPU:
| Physical model | Per-call cost |
|---|---|
| LambertianPhysicalModel | ~0.01 ms |
| GridPhysicalModel | ~0.05 ms (interp lookup) |
| PrecomputedPhysicalModel | ~0.15 ms (1-D interp) |
| ExoJaxPhysicalModel | ~13 ms (full 2-stream RT) |
PrecomputedPhysicalModel is the right tool for inner loops where you
need ExoJax-quality physics but cannot afford 13 ms per evaluation.
Pre-compute once at simulation setup, then use the cached version for
all subsequent calls.
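The trade-off is easy to quantify. This back-of-the-envelope sketch uses the ~13 ms full-RT figure and, as an assumed example, 0.15 ms for a cached lookup. The cached model's one-time build cost is not stated here, so the sketch reports the break-even budget for it instead.

```python
def total_seconds(n_calls, per_call_ms, setup_s=0.0):
    """Wall time for n_calls evaluations at per_call_ms each, plus a one-time setup cost."""
    return setup_s + n_calls * per_call_ms / 1000.0


def break_even_setup_s(n_calls, slow_ms, fast_ms):
    """Largest one-time setup (s) for which the cached path is still no slower overall."""
    return n_calls * (slow_ms - fast_ms) / 1000.0


n = 100_000  # e.g. likelihood evaluations in an MCMC run (illustrative)
direct = total_seconds(n, 13.0)          # full RT on every call
cached_calls = total_seconds(n, 0.15)    # lookups only, build cost excluded
budget = break_even_setup_s(n, 13.0, 0.15)
```

At 100,000 calls the direct path costs about 1300 s, against about 15 s for lookups. Any one-time precompute under roughly 1285 s (about 21 minutes) therefore pays for itself.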
Wavelength dependence#
contrast accepts a scalar wavelength_nm or a 1-D array of
wavelengths. For an IFS-style spectral cube, pass the wavelength axis
directly:
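import jax
import jax.numpy as jnp
# `model` is any physical model constructed as above
wavelengths = jnp.linspace(400, 1000, 100)
phase = jnp.array([[jnp.pi / 3]])  # (K=1, T=1)
dist = jnp.array([[1.0]])          # (K=1, T=1)
Rp = jnp.array([1.0])              # (K=1,)
# Pass the 1-D wavelength axis straight to contrast
contrast_cube = model.contrast(phase, dist, wavelengths, Rp)
# Per-wavelength alternative: jax.vmap stacks results along a new leading wavelength axis
contrast_spectrum = jax.vmap(
    lambda wl: model.contrast(phase, dist, wl, Rp),
)(wavelengths)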
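A broadcast call over the wavelength axis and an explicit per-wavelength loop should agree up to axis order. jax.vmap's default out_axes=0 puts wavelength first, while broadcasting a trailing axis puts it last. The NumPy sketch below checks that equivalence using a hypothetical toy_contrast, a Lambert curve multiplied by an invented linear albedo slope. It is not a skyscapes model and says nothing about the library's output layout.

```python
import numpy as np


def toy_contrast(phase, dist_AU, wl_nm, Rp_AU):
    """Hypothetical contrast with a linear albedo slope in wavelength.

    phase, dist_AU: (K, T); Rp_AU: (K,); wl_nm: scalar or (n_wl,).
    Returns (K, T) for a scalar wavelength and (K, T, n_wl) for a 1-D array.
    """
    wl = np.asarray(wl_nm, dtype=float)
    Ag = 0.2 + 0.1 * (wl - 400.0) / 600.0  # toy albedo: 0.2 at 400 nm, 0.3 at 1000 nm
    phi = (np.sin(phase) + (np.pi - phase) * np.cos(phase)) / np.pi  # Lambert phase curve
    base = phi * (Rp_AU[:, None] / dist_AU) ** 2                     # (K, T)
    return base * Ag if wl.ndim == 0 else base[..., None] * Ag


K, T = 2, 3
phase = np.full((K, T), np.pi / 3)
dist = np.ones((K, T))
Rp = np.array([4.26e-5, 8.52e-5])  # roughly 1 and 2 Earth radii in AU
wavelengths = np.linspace(400.0, 1000.0, 100)

cube = toy_contrast(phase, dist, wavelengths, Rp)  # one broadcast call: (K, T, n_wl)
looped = np.stack([toy_contrast(phase, dist, w, Rp) for w in wavelengths], axis=-1)
```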
For ExoJaxPhysicalModel specifically, the RT routine is vectorised
internally, so contrast_cube runs at ~13 ms regardless of wavelength
count (up to ~100 wl bins).
See also#
Source models – Planet and the rest of the scene hierarchy