# Source code for skyscapes.physical_model.cached

"""Pre-computed reflectivity physical model.

Holds a fixed ``(K, n_nu)`` plane-parallel-reflectivity array that was
produced by running some other :class:`AbstractPhysicalModel`'s heavy
radiative transfer once. ``contrast`` becomes a cubic interpolation +
Lambert phase + ``(Rp/d)^2`` lookup -- microseconds per call vs.
seconds for an ExoJAX 2-stream RT.

Use cases:
  - Coronagraphoto simulations where the physics is fixed across
    many evaluations (orbits, wavelengths, time samples).
  - ETC studies where you want fast forward-modeling against a
    canonical Earth-like atmosphere.

Not for HMC retrievals where parameters vary -- the underlying
expensive model is the right tool there.

The cache file format is a NumPy ``.npz`` with two keyword arrays
(``reflectivity``, ``nu_grid``) and a small JSON-style header tag for
version-tracking.
"""

from __future__ import annotations

from pathlib import Path

import equinox as eqx
import interpax
import jax
import jax.numpy as jnp
import numpy as np
from hwoutils.constants import Rearth2AU
from hwoutils.conversions import (
    nm_to_wavenumber_cm,
    spherical_to_geometric_albedo,
)
from jaxtyping import Array

from .._repr import indent
from .base import AbstractPhysicalModel
from .lambertian import _lambert_phase

# On-disk layout revision stamped into every cache file written by
# ``PrecomputedPhysicalModel.save``. Increment it whenever the layout
# changes incompatibly: ``load`` rejects any file carrying a different
# value instead of silently returning wrong spectra.
CACHE_FORMAT_VERSION: int = 2


class PrecomputedPhysicalModel(AbstractPhysicalModel):
    """Physical model whose reflectivity spectrum has been pre-computed.

    Attributes:
        reflectivity: Per-planet plane-parallel reflectivity ``R(nu)`` on
            the wavenumber grid, shape ``(K, n_nu)``. Already includes all
            the physics of the original physical model (absorption,
            Rayleigh, clouds, surface).
        nu_grid: Wavenumber grid [cm^-1], shape ``(n_nu,)``.
        n_nu: Length of ``nu_grid`` (static for JIT).
    """

    reflectivity: Array
    nu_grid: Array
    n_nu: int = eqx.field(static=True)

    @classmethod
    def from_physical_model(
        cls, model: AbstractPhysicalModel
    ) -> PrecomputedPhysicalModel:
        """Pre-compute the reflectivity from an existing physical model.

        Calls the model's internal ``_reflectivity_all_planets`` once and
        packages the result together with the model's ``nu_grid`` and
        ``n_nu``. Presently :class:`ExoJaxPhysicalModel` is the supported
        source.

        Args:
            model: Physical model exposing ``_reflectivity_all_planets``,
                ``nu_grid`` and ``n_nu``.

        Returns:
            A ``PrecomputedPhysicalModel`` holding the computed spectrum.

        Raises:
            TypeError: if ``model`` has no callable
                ``_reflectivity_all_planets``.
        """
        # Resolve the method *before* calling it. Guarding the call itself
        # with ``except AttributeError`` would also catch AttributeErrors
        # raised inside the radiative-transfer code and misreport a genuine
        # bug there as an unsupported model type.
        compute = getattr(model, "_reflectivity_all_planets", None)
        if not callable(compute):  # pragma: no cover
            raise TypeError(
                f"{type(model).__name__} does not expose "
                "_reflectivity_all_planets; PrecomputedPhysicalModel only "
                "supports physical models that compute a per-planet R(nu) "
                "spectrum (currently ExoJaxPhysicalModel)."
            )
        return cls(
            reflectivity=compute(),
            nu_grid=model.nu_grid,
            n_nu=model.n_nu,
        )

    @classmethod
    def load(cls, path: str | Path) -> PrecomputedPhysicalModel:
        """Load a previously-saved cache file.

        Args:
            path: Path to a ``.npz`` file produced by :meth:`save`. The exact
                path that was handed to :meth:`save` is also accepted when it
                lacked the ``.npz`` suffix (``np.savez`` appends one).

        Returns:
            A ``PrecomputedPhysicalModel`` ready to evaluate.

        Raises:
            ValueError: if the file has no ``cache_format_version`` entry,
                if that version differs from the current code's version
                (would silently produce wrong spectra otherwise), or if the
                stored arrays are not shaped ``(K, n_nu)`` / ``(n_nu,)``.
        """
        path = Path(path)
        if not path.exists() and path.suffix != ".npz":
            # np.savez appends ".npz" to names lacking it, so the path a
            # caller passed to save() may live on disk under the longer name.
            candidate = path.with_name(path.name + ".npz")
            if candidate.exists():
                path = candidate

        with np.load(path) as data:
            if "cache_format_version" not in data.files:
                raise ValueError(
                    f"Cache file {path} has no cache_format_version entry; "
                    f"current code version is {CACHE_FORMAT_VERSION}. "
                    f"Delete the file to rebuild."
                )
            version = int(data["cache_format_version"])
            if version != CACHE_FORMAT_VERSION:
                raise ValueError(
                    f"Cache file {path} has format version {version}; "
                    f"current code version is {CACHE_FORMAT_VERSION}. "
                    f"Delete the file to rebuild."
                )
            reflectivity = jnp.asarray(data["reflectivity"])
            nu_grid = jnp.asarray(data["nu_grid"])

        # Fail here with a clear message rather than deep inside interp1d
        # when the archive is malformed. The ndim checks short-circuit
        # before any shape indexing.
        if (
            nu_grid.ndim != 1
            or reflectivity.ndim != 2
            or reflectivity.shape[1] != nu_grid.shape[0]
        ):
            raise ValueError(
                f"Cache file {path} is malformed: reflectivity has shape "
                f"{tuple(reflectivity.shape)} and nu_grid has shape "
                f"{tuple(nu_grid.shape)}; expected (K, n_nu) and (n_nu,)."
            )

        return cls(
            reflectivity=reflectivity,
            nu_grid=nu_grid,
            n_nu=int(nu_grid.shape[0]),
        )

    def save(self, path: str | Path) -> None:
        """Save the cached reflectivity to a ``.npz`` file.

        Missing parent directories are created and an existing file is
        overwritten. This is plain NumPy I/O: the arrays are converted with
        ``np.asarray``, so they must be concrete values. Do not call this
        inside a ``jax.jit``-traced function.

        ``np.savez`` appends ``.npz`` when ``path`` lacks that suffix;
        :meth:`load` accepts the original ``path`` either way.

        Args:
            path: Destination file path.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        np.savez(
            path,
            cache_format_version=np.asarray(CACHE_FORMAT_VERSION),
            reflectivity=np.asarray(self.reflectivity),
            nu_grid=np.asarray(self.nu_grid),
        )

    def _interp_reflectivity(self, target_nu: Array) -> Array:
        """Cubic-interpolate every planet's cached ``R(nu)`` at ``target_nu``.

        Shared by :meth:`contrast` (scalar ``target_nu`` -> result ``(K,)``)
        and :meth:`contrast_cube` (``(W,)`` grid -> result ``(K, W)``).
        Points outside ``nu_grid`` are extrapolated (``extrap=True``).
        """

        def interp_one(spectrum):
            return interpax.interp1d(
                target_nu,
                self.nu_grid,
                spectrum,
                method="cubic",
                extrap=True,
            )

        # vmap over the leading planet axis of the (K, n_nu) array.
        return jax.vmap(interp_one)(self.reflectivity)

    def contrast(
        self,
        phase_angle_rad: Array,
        dist_AU: Array,
        wavelength_nm: Array,
        Rp_Rearth: Array,
    ) -> Array:
        """Per-planet, per-time geometric-albedo contrast at one wavelength.

        Args:
            phase_angle_rad: Star-planet-observer phase angle, shape
                ``(K, T)``.
            dist_AU: Star-planet distance [AU], shape ``(K, T)``.
            wavelength_nm: Scalar wavelength [nm].
            Rp_Rearth: Planet radius [Earth radii], shape ``(K,)``.

        Returns:
            Contrast = ``A_g(lambda) * Lambert_phase(beta) * (Rp/d)^2``,
            shape ``(K, T)``. The cached array stores the underlying
            model's plane-parallel (spherical) reflectivity; we convert to
            geometric albedo via the Lambertian-sphere factor 2/3
            (Seager 2010, eq 3.36) at call time -- same convention as
            :class:`ExoJaxPhysicalModel.contrast`.
        """
        target_nu = nm_to_wavenumber_cm(wavelength_nm)
        R_at_wl = self._interp_reflectivity(target_nu)  # (K,)
        Rp_AU = (Rp_Rearth * Rearth2AU)[:, None]  # (K, 1)
        Ag = spherical_to_geometric_albedo(R_at_wl)[:, None]  # (K, 1)
        phase = _lambert_phase(phase_angle_rad)
        return Ag * phase * (Rp_AU / dist_AU) ** 2

    def contrast_cube(
        self,
        phase_angle_rad: Array,
        dist_AU: Array,
        wavelengths_nm: Array,
        Rp_Rearth: Array,
    ) -> Array:
        """Per-planet, per-time geometric-albedo contrast across wavelengths.

        Args:
            phase_angle_rad: Star-planet-observer phase angle, shape
                ``(K, T)``.
            dist_AU: Star-planet distance [AU], shape ``(K, T)``.
            wavelengths_nm: Wavelengths [nm], shape ``(W,)``.
            Rp_Rearth: Planet radius [Earth radii], shape ``(K,)``.

        Returns:
            Contrast of shape ``(W, K, T)``. The interpolation is vectorised
            over all wavelengths at once, and the spherical-to-geometric
            albedo conversion is applied the same way as in :meth:`contrast`.
        """
        target_nu = nm_to_wavenumber_cm(wavelengths_nm)
        Ag = spherical_to_geometric_albedo(
            self._interp_reflectivity(target_nu)
        )  # (K, W)
        Rp_AU = (Rp_Rearth * Rearth2AU)[:, None]
        phase = _lambert_phase(phase_angle_rad)
        geom = phase * (Rp_AU / dist_AU) ** 2  # (K, T)
        # Outer product over the wavelength and time axes, per planet.
        return jnp.einsum("kw,kt->wkt", Ag, geom)

    def __repr__(self) -> str:
        """Compact summary of the cached spectrum."""
        K = int(self.reflectivity.shape[0])
        # nu [cm^-1] -> lambda [nm] is 1e7 / nu, so the max wavenumber gives
        # the shortest wavelength.
        wl_min = float(1.0e7 / self.nu_grid.max())
        wl_max = float(1.0e7 / self.nu_grid.min())
        peak_mean = float(jnp.mean(self.reflectivity.max(axis=-1)))
        lines = [
            f"PrecomputedPhysicalModel(K={K}, n_nu={self.n_nu})",
            (
                f"  Wavelength: {wl_min:.0f}-{wl_max:.0f} nm, "
                f"mean peak R = {peak_mean:.3f}"
            ),
            indent(
                "(no RT cost; reflectivity sampled at call time)",
                prefix="  ",
            ),
        ]
        return "\n".join(lines)