# Source code for skyscapes.physical_model.exojax.physical_model

"""ExoJaxPhysicalModel: composition-based reflected-light planet model.

Each piece of physics is a swappable component; each molecule is a
self-contained :class:`MolecularSpecies` record. To explore different
atmospheres, swap a component or a species and rebuild.

What's currently included:

  - **Species** (:class:`MolecularSpecies` tuple): variable-length set
    of molecules, each carrying its own per-planet mass-mixing-ratio
    profile (a constant log mmr or an altitude-resolved profile) plus
    the shared static data (molar mass, opacity engine, Rayleigh
    cross-section). Recipes are preregistered for the biosignature five
    (H2O, CO2, CH4, O2, O3) plus CO, N2O and SO2 for richer atmospheres
    (e.g. the Star+ "Earth through time" set).
  - **Bulk gas** (:class:`BulkGasResidual`): implicit residual gas
    (N2 default, H2/He available) filling 1 - sum(tracked mmrs).
    Contributes only Rayleigh scattering.
  - **Absorption** (:class:`Absorption` component): iterates species,
    sums line-list / cross-section contributions.
  - **Scattering** (:class:`RayleighScattering` or :class:`NullScattering`):
    Rayleigh from tracked species + bulk.
  - **Clouds** (:class:`GrayCloud` / :class:`NoCloud`): single-layer
    gray cloud with softmax-Gaussian vertical distribution.
  - **Surface** (:class:`WavelengthDependentSurface` / :class:`FlatSurface`):
    per-planet albedo scalar x wavelength-dependent reflectivity profile.
  - **T-P profile** (:class:`PowerLawTPProfile`): ``T(P) = T_eq * P^alpha``.

Construction patterns:

- **One-shot**: ``ExoJaxPhysicalModel.from_default_setup(...)`` accepts
  ``log_mmrs`` as a dict mapping molecule name to ``(K,)`` array and
  builds engines + default components in one call.
- **Build-once, fit-many** (retrieval-friendly): call
  :func:`build_exojax_engines` (with ``molecules=...``) once for the
  heavy per-molecule construction (HITRAN downloads + opacity
  precompute), then reuse the returned engines to assemble a new
  atmosphere for each sample.

References:
    Kawahara et al. 2022, ApJS 258, 31 (ExoJAX Paper I)
"""

from __future__ import annotations

import hashlib
from pathlib import Path
from typing import Any

import equinox as eqx
import interpax
import jax
import jax.numpy as jnp
import numpy as np
from exojax.rt import ArtReflectPure
from exojax.utils.grids import wavenumber_grid
from hwoutils.constants import Rearth2AU
from hwoutils.conversions import (
    nm_to_wavenumber_cm,
    spherical_to_geometric_albedo,
)
from jaxtyping import Array

from ..._repr import indent
from ..base import AbstractPhysicalModel
from ..cached import PrecomputedPhysicalModel
from ..lambertian import _lambert_phase
from .components import (
    Absorption,
    AbstractAbsorption,
    AbstractClouds,
    AbstractMmrProfile,
    AbstractScattering,
    AbstractSurface,
    AbstractTPProfile,
    BulkGasResidual,
    ConstantMmr,
    GrayCloud,
    MolecularSpecies,
    PowerLawTPProfile,
    RayleighScattering,
    WavelengthDependentSurface,
    build_bulk_prebuilt,
    build_species_prebuilt,
)

# Molecules built by default: the five biosignature species common to most
# HWO reflected-light studies. Pass ``molecules=...`` to build engines for
# another set instead (e.g. the 8-molecule Star+ "Earth through time"
# composition).
DEFAULT_MOLECULES: tuple[str, ...] = (
    "H2O",
    "CO2",
    "CH4",
    "O2",
    "O3",
)


def build_exojax_engines(
    *,
    molecules: tuple[str, ...] = DEFAULT_MOLECULES,
    bulk_gas: str | None = "N2",
    wavelength_min_nm: float = 400.0,
    wavelength_max_nm: float = 1000.0,
    n_wavenumbers: int = 2000,
    n_layers: int = 100,
    pressure_top_bar: float = 1.0e-5,
    pressure_btm_bar: float = 1.0e0,
    databases_dir: str | None = None,
    crit: float = 0.0,
) -> dict:
    """Build the heavy shared engines and pre-built per-species data.

    Each molecule's opacity engine (line-list LSD tables or cross-section
    interpolant) is built once here. One call can then serve many
    atmosphere instantiations that differ only in per-planet parameters
    such as the ``log_mmrs``. This is the expensive step (line-list
    downloads and opacity precomputation), so do it once per molecule set
    and spectral configuration, not once per sample.

    Args:
        molecules: Tuple of molecule names to include. Must all appear in
            :data:`MOLECULE_RECIPES`. Default is the 5-molecule
            biosignature set (:data:`DEFAULT_MOLECULES`).
        bulk_gas: Implicit residual gas name (``"N2"``, ``"H2"``, ``"He"``),
            or ``None`` to disable the bulk-gas Rayleigh contribution.
            Default ``"N2"`` for terrestrial atmospheres.
        wavelength_min_nm: Short-wavelength end of the spectral range [nm].
        wavelength_max_nm: Long-wavelength end of the spectral range [nm].
        n_wavenumbers: Number of points in the wavenumber grid.
        n_layers: Number of layers in the plane-parallel RT solver.
        pressure_top_bar: Pressure at the top of the model [bar].
        pressure_btm_bar: Pressure at the bottom of the model [bar].
        databases_dir: ExoJAX line-list cache directory. It is created if
            missing. Defaults to ``~/.cache/skyscapes/exojax``.
        crit: Line-strength cutoff forwarded to ``build_species_prebuilt``
            (``MdbHitran`` line filtering), in HITRAN line-intensity units
            [cm^-1/(molecule cm^-2)]. The default ``0.0`` disables
            filtering. ``crit=1e-26`` is suggested for a ~25% speedup at
            HWO contrast levels.

    Returns:
        Dict with the following keys:

        - ``rt_engine``: the ``ArtReflectPure`` solver.
        - ``nu_grid``: the wavenumber grid [cm^-1].
        - ``n_nu``: ``len(nu_grid)`` as a Python ``int``.
        - ``species_prebuilt``: dict mapping each molecule name to
          ``{molmass, opa, rayleigh_xs}``.
        - ``bulk_prebuilt``: ``{name, molmass, rayleigh_xs}``, or ``None``
          when ``bulk_gas`` is ``None``.

        ``rt_engine``, ``nu_grid`` and ``n_nu`` correspond to the same-named
        :class:`ExoJaxPhysicalModel` fields. The two ``*_prebuilt`` entries
        must still be combined with per-planet mixing ratios into
        :class:`MolecularSpecies` / :class:`BulkGasResidual` objects before
        a model can be constructed.

    Raises:
        ValueError: If ``wavelength_min_nm`` is not smaller than
            ``wavelength_max_nm``.
    """
    # Fail fast with a clear message: an inverted range would otherwise
    # surface as an obscure error from the wavenumber-grid builder.
    if wavelength_min_nm >= wavelength_max_nm:
        raise ValueError(
            f"wavelength_min_nm ({wavelength_min_nm}) must be smaller than "
            f"wavelength_max_nm ({wavelength_max_nm})."
        )

    if databases_dir is None:
        databases_dir = str(Path.home() / ".cache" / "skyscapes" / "exojax")
    Path(databases_dir).mkdir(parents=True, exist_ok=True)

    # Short wavelengths map to high wavenumbers, so the bounds swap roles.
    nu_max = nm_to_wavenumber_cm(wavelength_min_nm)
    nu_min = nm_to_wavenumber_cm(wavelength_max_nm)
    nu_grid, _, _ = wavenumber_grid(
        nu_min, nu_max, n_wavenumbers, xsmode="premodit", unit="cm-1"
    )

    rt_engine = ArtReflectPure(
        nu_grid=nu_grid,
        pressure_top=pressure_top_bar,
        pressure_btm=pressure_btm_bar,
        nlayer=n_layers,
    )

    species_prebuilt: dict[str, dict[str, Any]] = {}
    for name in molecules:
        molmass, opa, rayleigh_xs = build_species_prebuilt(
            name=name,
            nu_grid=nu_grid,
            nu_min=float(nu_min),
            nu_max=float(nu_max),
            databases_dir=databases_dir,
            crit=crit,
        )
        species_prebuilt[name] = {
            "molmass": molmass,
            "opa": opa,
            "rayleigh_xs": rayleigh_xs,
        }

    bulk_prebuilt: dict[str, Any] | None = None
    if bulk_gas is not None:
        bulk_molmass, bulk_rayleigh_xs = build_bulk_prebuilt(
            name=bulk_gas, nu_grid=nu_grid
        )
        bulk_prebuilt = {
            "name": bulk_gas,
            "molmass": bulk_molmass,
            "rayleigh_xs": bulk_rayleigh_xs,
        }

    return {
        "rt_engine": rt_engine,
        "nu_grid": nu_grid,
        "n_nu": int(nu_grid.shape[0]),
        "species_prebuilt": species_prebuilt,
        "bulk_prebuilt": bulk_prebuilt,
    }
[docs] def _assemble_species( log_mmrs: dict[str, Array | AbstractMmrProfile], species_prebuilt: dict[str, dict[str, Any]], ) -> tuple[MolecularSpecies, ...]: """Combine per-planet log_mmrs with prebuilt opacity data into species. Each dict entry can be either a per-planet ``(K,)`` array (treated as a constant-mmr profile, the back-compat path) or an :class:`AbstractMmrProfile` instance (use as-is for altitude- resolved profiles like O3 stratospheric peak or H2O cold trap). """ species = [] for name, value in log_mmrs.items(): if name not in species_prebuilt: raise KeyError( f"Species {name!r} not in prebuilt set " f"{sorted(species_prebuilt.keys())}. " f"Pass molecules=(...) to build_exojax_engines to include it." ) pre = species_prebuilt[name] if isinstance(value, AbstractMmrProfile): profile = value else: profile = ConstantMmr(log_mmr=value) species.append( MolecularSpecies( profile=profile, name=name, molmass=pre["molmass"], opa=pre["opa"], rayleigh_xs=pre["rayleigh_xs"], ) ) return tuple(species)
[docs] def _assemble_bulk( bulk_prebuilt: dict[str, Any] | None, ) -> BulkGasResidual | None: """Construct the bulk-gas residual from prebuilt data, or return None.""" if bulk_prebuilt is None: return None return BulkGasResidual( name=bulk_prebuilt["name"], molmass=bulk_prebuilt["molmass"], rayleigh_xs=bulk_prebuilt["rayleigh_xs"], )
# Bump when the set of from_default_setup kwargs changes meaning in a # way that would silently invalidate previously-saved caches. _CACHE_KEY_VERSION = 1
[docs] def _hash_value(value: Any) -> bytes: """Stable byte representation of an arbitrary kwarg value. Handles JAX/NumPy arrays, dicts (sorted keys), tuples/lists, :class:`equinox.Module` instances (via the PyTree protocol), and plain scalars. Used by :func:`_cache_key` to build a deterministic hash of ``from_default_setup_cached``'s arguments. """ if isinstance(value, (jnp.ndarray, np.ndarray)): arr = np.asarray(value) return arr.tobytes() + str(arr.shape).encode() + str(arr.dtype).encode() if isinstance(value, dict): return b"||".join( k.encode() + b":" + _hash_value(value[k]) for k in sorted(value.keys()) ) if isinstance(value, (tuple, list)): return b"||".join(_hash_value(item) for item in value) if isinstance(value, eqx.Module): leaves, treedef = jax.tree.flatten(value) pieces = [ type(value).__name__.encode(), repr(treedef).encode(), ] for leaf in leaves: pieces.append(_hash_value(leaf)) return b"||".join(pieces) if value is None: return b"None" return repr(value).encode()
def _cache_key(**kwargs: Any) -> str:
    """Return a short deterministic identifier for a set of keyword arguments.

    The SHA-256 digest is seeded with ``v<_CACHE_KEY_VERSION>``. Each
    argument name and its :func:`_hash_value` encoding are then fed in
    sorted-name order, so the result does not depend on keyword order.
    Only the first 16 hex characters are returned.
    """
    digest = hashlib.sha256(f"v{_CACHE_KEY_VERSION}".encode())
    for arg_name in sorted(kwargs):
        digest.update(arg_name.encode())
        digest.update(_hash_value(kwargs[arg_name]))
    return digest.hexdigest()[:16]
class ExoJaxPhysicalModel(AbstractPhysicalModel):
    """Composition-based reflected-light planet model over ExoJAX's 2-stream RT.

    Per-planet state (PyTree leaves, fittable):
        log_gravity_cgs: Log10 surface gravity [cm/s^2], shape ``(K,)``.
        species: Tuple of :class:`MolecularSpecies`. Each species carries
            its own mixing-ratio ``profile``, an :class:`AbstractMmrProfile`
            (e.g. :class:`ConstantMmr`, which wraps a per-planet ``(K,)``
            ``log_mmr``). It also carries static molar mass, opacity engine
            and Rayleigh cross-section data. The number and identity of
            species is configurable via :func:`build_exojax_engines`.
        bulk: Optional :class:`BulkGasResidual` (implicit residual gas).

    Components (each owns its own per-planet PyTree leaves where
    applicable):
        tp_profile: T-P profile component (e.g. ``PowerLawTPProfile``).
        absorption: Absorption orchestrator (e.g. ``Absorption``).
        scattering: Scattering component (e.g. ``RayleighScattering``,
            ``NullScattering``).
        clouds: Cloud component (e.g. ``GrayCloud``, ``NoCloud``).
        surface: Surface component (e.g. ``WavelengthDependentSurface``).

    Shared / configuration attributes:
        rt_engine: ExoJAX ``ArtReflectPure`` instance (static field).
        nu_grid: Wavenumber grid [cm^-1].
        n_nu: Length of ``nu_grid`` (static for JIT).
    """

    # Top-level per-planet state
    log_gravity_cgs: Array

    # Atmospheric composition
    species: tuple[MolecularSpecies, ...]
    bulk: BulkGasResidual | None

    # Components
    tp_profile: AbstractTPProfile
    absorption: AbstractAbsorption
    scattering: AbstractScattering
    clouds: AbstractClouds
    surface: AbstractSurface

    # Static / shared
    rt_engine: Any = eqx.field(static=True)
    nu_grid: Array
    n_nu: int = eqx.field(static=True)

    @staticmethod
    def _slice_planet(tree: Any, k: int) -> Any:
        """Return ``tree`` with every JAX array leaf of rank >= 1 indexed at ``k``.

        Used to pull planet ``k`` out of a species' mmr profile. Scalars
        and non-array leaves are passed through unchanged.
        """

        def _take(leaf: Any) -> Any:
            if isinstance(leaf, jnp.ndarray) and leaf.ndim > 0:
                return leaf[k]
            return leaf

        return jax.tree.map(_take, tree)

    @classmethod
    def from_default_setup(
        cls,
        *,
        log_mmrs: dict[str, Array | AbstractMmrProfile],
        T_eq_K: Array,
        T_alpha: Array,
        log_surface_albedo: Array,
        log_gravity_cgs: Array,
        log_cloud_pressure_bar: Array | None = None,
        log_cloud_opt_depth: Array | None = None,
        surface_albedo_spectrum: Array | None = None,
        molecules: tuple[str, ...] | None = None,
        bulk_gas: str | None = "N2",
        wavelength_min_nm: float = 400.0,
        wavelength_max_nm: float = 1000.0,
        n_wavenumbers: int = 2000,
        n_layers: int = 100,
        pressure_top_bar: float = 1.0e-5,
        pressure_btm_bar: float = 1.0e0,
        databases_dir: str | None = None,
        crit: float = 0.0,
        rt_engine: Any = None,
        nu_grid: Array | None = None,
        n_nu: int | None = None,
        species_prebuilt: dict[str, dict[str, Any]] | None = None,
        bulk_prebuilt: dict[str, Any] | None = None,
    ) -> ExoJaxPhysicalModel:
        """One-shot convenience: build (or reuse) engines + default components.

        Defaults to Earth-like physics: ``PowerLawTPProfile``,
        ``Absorption``, ``RayleighScattering`` with N2 bulk, ``GrayCloud``
        (Earth water clouds), and a flat ``WavelengthDependentSurface``.

        Build-once, fit-many: pass the dict returned by
        :func:`build_exojax_engines` to skip the expensive engine build::

            engines = build_exojax_engines(molecules=("H2O", "CO", "O2"))
            model = ExoJaxPhysicalModel.from_default_setup(
                log_mmrs=per_planet_log_mmrs, ..., **engines
            )

        Args:
            log_mmrs: Dict mapping molecule name to either a per-planet
                log10 mass-mixing ratio of shape ``(K,)`` (wrapped in
                :class:`ConstantMmr`) or an :class:`AbstractMmrProfile`.
                The dict's molecules determine which species are built.
            T_eq_K: Per-planet temperature scale [K] of the power-law
                profile ``T(P) = T_eq * P^alpha`` (T at P = 1 bar), ``(K,)``.
            T_alpha: Per-planet T-P power-law exponent, ``(K,)``.
            log_surface_albedo: Per-planet log10 surface albedo scaling,
                ``(K,)``.
            log_gravity_cgs: Per-planet log10 gravity [cm/s^2], ``(K,)``.
                Its length defines ``K``.
            log_cloud_pressure_bar: Per-planet log10 cloud-deck pressure
                [bar], ``(K,)``. Defaults to ``log10(0.5)``.
            log_cloud_opt_depth: Per-planet log10 cloud optical depth,
                ``(K,)``. Defaults to ``log10(5)``.
            surface_albedo_spectrum: ``(n_nu,)`` spectral profile. Defaults
                to flat ones.
            molecules: Override the molecule list used to build engines
                (default: ``log_mmrs.keys()``). Ignored when reusing engines.
            bulk_gas: Implicit residual gas (default ``"N2"``). Ignored when
                reusing engines; the bulk gas then comes from
                ``bulk_prebuilt``.
            wavelength_min_nm: See :func:`build_exojax_engines`.
            wavelength_max_nm: See :func:`build_exojax_engines`.
            n_wavenumbers: See :func:`build_exojax_engines`.
            n_layers: See :func:`build_exojax_engines`.
            pressure_top_bar: See :func:`build_exojax_engines`.
            pressure_btm_bar: See :func:`build_exojax_engines`.
            databases_dir: See :func:`build_exojax_engines`.
            crit: See :func:`build_exojax_engines`. This and the six
                engine-build arguments above are ignored when reusing
                engines.
            rt_engine: Prebuilt ``ArtReflectPure`` solver. When given,
                engine construction is skipped and the prebuilt values
                below are used.
            nu_grid: Prebuilt wavenumber grid. Required with ``rt_engine``.
            n_nu: Length of ``nu_grid``. Derived from ``nu_grid`` if omitted.
            species_prebuilt: Prebuilt per-molecule data. Required with
                ``rt_engine``.
            bulk_prebuilt: Prebuilt bulk-gas data. ``None`` means no
                bulk gas when reusing engines.

        Raises:
            ValueError: If ``rt_engine`` is given without ``nu_grid`` /
                ``species_prebuilt``. Also raised if any prebuilt value is
                given without ``rt_engine``.
            KeyError: If ``log_mmrs`` names a molecule absent from the
                prebuilt species.
        """
        if rt_engine is None:
            # Partial engine sets would otherwise be silently ignored.
            if (
                nu_grid is not None
                or n_nu is not None
                or species_prebuilt is not None
                or bulk_prebuilt is not None
            ):
                raise ValueError(
                    "Prebuilt engine data (nu_grid / n_nu / species_prebuilt / "
                    "bulk_prebuilt) was given without rt_engine; pass the full "
                    "dict from build_exojax_engines via **engines."
                )
            if molecules is None:
                molecules = tuple(log_mmrs.keys())
            engines = build_exojax_engines(
                molecules=molecules,
                bulk_gas=bulk_gas,
                wavelength_min_nm=wavelength_min_nm,
                wavelength_max_nm=wavelength_max_nm,
                n_wavenumbers=n_wavenumbers,
                n_layers=n_layers,
                pressure_top_bar=pressure_top_bar,
                pressure_btm_bar=pressure_btm_bar,
                databases_dir=databases_dir,
                crit=crit,
            )
        else:
            if nu_grid is None or species_prebuilt is None:
                raise ValueError(
                    "Reusing prebuilt engines requires rt_engine, nu_grid and "
                    "species_prebuilt; pass the full dict from "
                    "build_exojax_engines via **engines."
                )
            engines = {
                "rt_engine": rt_engine,
                "nu_grid": nu_grid,
                "n_nu": int(nu_grid.shape[0]) if n_nu is None else int(n_nu),
                "species_prebuilt": species_prebuilt,
                "bulk_prebuilt": bulk_prebuilt,
            }

        K = log_gravity_cgs.shape[0]
        if log_cloud_pressure_bar is None:
            log_cloud_pressure_bar = jnp.full((K,), jnp.log10(0.5))
        if log_cloud_opt_depth is None:
            log_cloud_opt_depth = jnp.full((K,), jnp.log10(5.0))
        if surface_albedo_spectrum is None:
            surface_albedo_spectrum = jnp.ones(engines["n_nu"])

        species = _assemble_species(log_mmrs, engines["species_prebuilt"])
        bulk = _assemble_bulk(engines["bulk_prebuilt"])

        return cls(
            log_gravity_cgs=log_gravity_cgs,
            species=species,
            bulk=bulk,
            tp_profile=PowerLawTPProfile(T_eq_K=T_eq_K, T_alpha=T_alpha),
            absorption=Absorption(),
            scattering=RayleighScattering(),
            clouds=GrayCloud(
                log_pressure_bar=log_cloud_pressure_bar,
                log_opt_depth=log_cloud_opt_depth,
            ),
            surface=WavelengthDependentSurface(
                log_albedo=log_surface_albedo,
                spectrum=surface_albedo_spectrum,
            ),
            rt_engine=engines["rt_engine"],
            nu_grid=engines["nu_grid"],
            n_nu=engines["n_nu"],
        )

    @classmethod
    def from_default_setup_cached(
        cls,
        *,
        log_mmrs: dict[str, Array | AbstractMmrProfile],
        T_eq_K: Array,
        T_alpha: Array,
        log_surface_albedo: Array,
        log_gravity_cgs: Array,
        log_cloud_pressure_bar: Array | None = None,
        log_cloud_opt_depth: Array | None = None,
        surface_albedo_spectrum: Array | None = None,
        molecules: tuple[str, ...] | None = None,
        bulk_gas: str | None = "N2",
        wavelength_min_nm: float = 400.0,
        wavelength_max_nm: float = 1000.0,
        n_wavenumbers: int = 2000,
        n_layers: int = 100,
        pressure_top_bar: float = 1.0e-5,
        pressure_btm_bar: float = 1.0e0,
        databases_dir: str | None = None,
        crit: float = 0.0,
        cache_dir: str | Path | None = None,
    ) -> PrecomputedPhysicalModel:
        """Cached one-shot factory: returns a fast :class:`PrecomputedPhysicalModel`.

        Hashes every physics-relevant input and looks up a cached spectrum
        on disk. On a cache hit, returns the cached model (~10 ms). On a
        miss, it:

        1. builds the full :class:`ExoJaxPhysicalModel` via
           :meth:`from_default_setup`;
        2. runs the 2-stream RT once and precomputes the reflectivity;
        3. saves the result to disk and returns the
           :class:`PrecomputedPhysicalModel`.

        ``databases_dir`` and ``cache_dir`` are not part of the key.

        Use this when the physical-model parameters are fixed across many
        evaluations (coronagraphoto sims, ETC studies). For HMC retrievals
        where parameters vary, use :meth:`from_default_setup` directly.

        Args:
            cache_dir: Override the cache directory. Defaults to
                ``~/.cache/skyscapes/physical_models/``.
            log_mmrs: See :meth:`from_default_setup`.
            T_eq_K: See :meth:`from_default_setup`.
            T_alpha: See :meth:`from_default_setup`.
            log_surface_albedo: See :meth:`from_default_setup`.
            log_gravity_cgs: See :meth:`from_default_setup`.
            log_cloud_pressure_bar: See :meth:`from_default_setup`.
            log_cloud_opt_depth: See :meth:`from_default_setup`.
            surface_albedo_spectrum: See :meth:`from_default_setup`.
            molecules: See :meth:`from_default_setup`.
            bulk_gas: See :meth:`from_default_setup`.
            wavelength_min_nm: See :meth:`from_default_setup`.
            wavelength_max_nm: See :meth:`from_default_setup`.
            n_wavenumbers: See :meth:`from_default_setup`.
            n_layers: See :meth:`from_default_setup`.
            pressure_top_bar: See :meth:`from_default_setup`.
            pressure_btm_bar: See :meth:`from_default_setup`.
            databases_dir: See :meth:`from_default_setup`.
            crit: See :meth:`from_default_setup`.

        Returns:
            A :class:`PrecomputedPhysicalModel` with no RT cost at
            evaluation time.
        """
        kwargs = {
            "log_mmrs": log_mmrs,
            "T_eq_K": T_eq_K,
            "T_alpha": T_alpha,
            "log_surface_albedo": log_surface_albedo,
            "log_gravity_cgs": log_gravity_cgs,
            "log_cloud_pressure_bar": log_cloud_pressure_bar,
            "log_cloud_opt_depth": log_cloud_opt_depth,
            "surface_albedo_spectrum": surface_albedo_spectrum,
            "molecules": molecules,
            "bulk_gas": bulk_gas,
            "wavelength_min_nm": wavelength_min_nm,
            "wavelength_max_nm": wavelength_max_nm,
            "n_wavenumbers": n_wavenumbers,
            "n_layers": n_layers,
            "pressure_top_bar": pressure_top_bar,
            "pressure_btm_bar": pressure_btm_bar,
            "crit": crit,
        }
        if cache_dir is None:
            cache_dir = Path.home() / ".cache" / "skyscapes" / "physical_models"
        cache_dir = Path(cache_dir)
        cache_dir.mkdir(parents=True, exist_ok=True)
        cache_path = cache_dir / f"{_cache_key(**kwargs)}.npz"

        if cache_path.exists():
            return PrecomputedPhysicalModel.load(cache_path)

        model = cls.from_default_setup(databases_dir=databases_dir, **kwargs)
        cached = PrecomputedPhysicalModel.from_physical_model(model)
        cached.save(cache_path)
        return cached

    def _reflectivity_one_planet(self, k: int) -> Array:
        """Plane-parallel reflectivity for the k-th planet, shape ``(n_nu,)``.

        Indexes every K-shaped leaf at position ``k`` (cheap; JIT traces
        this as plain array indexing). Different mmr-profile types have
        different K-shaped fields, so each species' profile is sliced via a
        tree-map rather than by vmapping over an extracted array.
        """
        pressure = self.rt_engine.pressure
        n_layers = pressure.shape[0]

        gravity = 10.0 ** self.log_gravity_cgs[k]
        Tarr = self.tp_profile.compute_Tarr(
            self.rt_engine,
            self.tp_profile.T_eq_K[k],
            self.tp_profile.T_alpha[k],
        )

        # Replace each species' profile with its planet-k slice. The static
        # fields (name, molmass, opa) are left untouched.
        species_one = tuple(
            eqx.tree_at(
                lambda sp: sp.profile,
                s,
                replace=self._slice_planet(s.profile, k),
            )
            for s in self.species
        )

        abs_c = self.absorption.compute(
            species_one, Tarr, pressure, gravity, self.rt_engine
        )
        scat_c = self.scattering.compute(
            species_one, self.bulk, gravity, self.rt_engine, n_layers, self.n_nu
        )
        cloud_c = self.clouds.compute(
            self.clouds.log_pressure_bar[k],
            self.clouds.log_opt_depth[k],
            pressure,
            self.n_nu,
        )

        dtau_total = abs_c.dtau_total + scat_c.dtau_total + cloud_c.dtau_total
        dtau_scatter = (
            abs_c.dtau_scatter + scat_c.dtau_scatter + cloud_c.dtau_scatter
        )
        g_num = abs_c.g_weighted_num + scat_c.g_weighted_num + cloud_c.g_weighted_num

        # "Double where": swap in a harmless denominator where the mask is
        # False, so the untaken branch never evaluates 0/0. A single
        # jnp.where(mask, a / b, 0.0) gives the right forward value but
        # back-propagates NaN gradients from the masked entries, which
        # breaks gradient-based (HMC) retrievals.
        has_tau = dtau_total > 0
        ssa = jnp.where(
            has_tau, dtau_scatter / jnp.where(has_tau, dtau_total, 1.0), 0.0
        )
        has_scatter = dtau_scatter > 0
        g_asym = jnp.where(
            has_scatter, g_num / jnp.where(has_scatter, dtau_scatter, 1.0), 0.0
        )

        refl_surface = self.surface.compute_refl(self.surface.log_albedo[k], self.n_nu)
        F_unit = jnp.ones(self.n_nu)
        return self.rt_engine.run(dtau_total, ssa, g_asym, refl_surface, F_unit)

    def _reflectivity_all_planets(self) -> Array:
        """Plane-parallel reflectivity for every planet, shape ``(K, n_nu)``.

        Loops over the K planets in Python, so JIT unrolls the loop. For the
        common K=1 case (e.g. one planet per HMC sample) this is a single
        iteration.
        """
        K = int(self.log_gravity_cgs.shape[0])
        return jnp.stack([self._reflectivity_one_planet(k) for k in range(K)], axis=0)

    def _geometric_albedo_at(self, target_nu_cm_inv: Array) -> Array:
        """Geometric albedo of every planet at the requested wavenumbers.

        Runs the 2-stream RT once per planet. Each plane-parallel
        reflectivity spectrum is cubic-interpolated from ``nu_grid`` onto
        ``target_nu_cm_inv``, extrapolating outside the grid. The result is
        then converted with ``spherical_to_geometric_albedo``. The output
        has a leading planet axis of length K.
        """
        R_per_planet = self._reflectivity_all_planets()

        def interp_one(spectrum: Array) -> Array:
            return interpax.interp1d(
                target_nu_cm_inv,
                self.nu_grid,
                spectrum,
                method="cubic",
                extrap=True,
            )

        return spherical_to_geometric_albedo(jax.vmap(interp_one)(R_per_planet))

    def contrast(
        self,
        phase_angle_rad: Array,
        dist_AU: Array,
        wavelength_nm: Array,
        Rp_Rearth: Array,
    ) -> Array:
        """Per-planet, per-time geometric-albedo contrast at one wavelength.

        Args:
            phase_angle_rad: Star-planet-observer phase angle, ``(K, T)``.
            dist_AU: Star-planet distance [AU], ``(K, T)``.
            wavelength_nm: Scalar wavelength [nm].
            Rp_Rearth: Planet radius [Earth radii], shape ``(K,)``.

        Returns:
            Contrast = ``A_g(lambda) * Lambert_phase(beta) * (Rp/d)^2``,
            shape ``(K, T)``. The 2-stream RT produces the plane-parallel
            (spherical) reflectivity. It is converted to geometric albedo
            via the Lambertian-sphere factor 2/3 (Seager 2010, eq 3.36), so
            the output uses the same geometric-albedo convention as
            :class:`LambertianPhysicalModel`.
        """
        Ag = self._geometric_albedo_at(nm_to_wavenumber_cm(wavelength_nm))[:, None]
        Rp_AU = (Rp_Rearth * Rearth2AU)[:, None]
        phase = _lambert_phase(phase_angle_rad)
        return Ag * phase * (Rp_AU / dist_AU) ** 2

    def contrast_cube(
        self,
        phase_angle_rad: Array,
        dist_AU: Array,
        wavelengths_nm: Array,
        Rp_Rearth: Array,
    ) -> Array:
        """Per-planet, per-time contrast across many wavelengths.

        Returns a ``(W, K, T)`` geometric-albedo contrast cube. The
        underlying 2-stream RT runs exactly once per planet rather than once
        per wavelength. The spherical-to-geometric Lambertian-sphere
        conversion (Seager 2010, eq 3.36) is applied the same way as in
        :meth:`contrast`.
        """
        Ag = self._geometric_albedo_at(nm_to_wavenumber_cm(wavelengths_nm))  # (K, W)
        Rp_AU = (Rp_Rearth * Rearth2AU)[:, None]
        phase = _lambert_phase(phase_angle_rad)
        geom = phase * (Rp_AU / dist_AU) ** 2
        return jnp.einsum("kw,kt->wkt", Ag, geom)

    def __repr__(self) -> str:
        """Human-readable summary of components + species + per-planet state."""
        K = int(self.log_gravity_cgs.shape[0])
        n_layers = int(self.rt_engine.pressure.shape[0])
        wl_min_nm = float(1.0e7 / self.nu_grid.max())
        wl_max_nm = float(1.0e7 / self.nu_grid.min())
        bulk_label = self.bulk.name if self.bulk is not None else "None"
        species_label = ", ".join(s.name for s in self.species)
        lines = [
            f"ExoJaxPhysicalModel(K={K})",
            (
                f" Wavelength: {wl_min_nm:.0f}-{wl_max_nm:.0f} nm "
                f"(n_nu={self.n_nu}, n_layers={n_layers})"
            ),
            f" Composition: [{species_label}] + bulk={bulk_label}",
            " Components:",
            indent(f"tp: {type(self.tp_profile).__name__}", prefix=" "),
            indent(
                f"absorption: {type(self.absorption).__name__}",
                prefix=" ",
            ),
            indent(
                f"scattering: {type(self.scattering).__name__}",
                prefix=" ",
            ),
            indent(f"clouds: {type(self.clouds).__name__}", prefix=" "),
            indent(f"surface: {type(self.surface).__name__}", prefix=" "),
        ]

        max_show = 3
        pressure = self.rt_engine.pressure
        # Show the MMR at the bottom-of-atmosphere (highest-pressure) layer
        # as the "headline" value, matching the surface/tropospheric numbers
        # readers usually expect.
        botm_layer_idx = int(jnp.argmax(pressure))
        for k in range(min(K, max_show)):
            mmrs_k = []
            profile_labels = []
            for s in self.species:
                profile_k = self._slice_planet(s.profile, k)
                mmr_at_botm = float(profile_k.evaluate(pressure)[botm_layer_idx])
                mmrs_k.append(mmr_at_botm)
                profile_labels.append(type(s.profile).__name__)

            # The implicit bulk gas fills whatever mass fraction is left.
            mmr_bulk_k = max(0.0, 1.0 - sum(mmrs_k))
            inv_m_mean = sum(
                mmr / s.molmass for s, mmr in zip(self.species, mmrs_k, strict=True)
            )
            if self.bulk is not None:
                inv_m_mean += mmr_bulk_k / self.bulk.molmass
            m_mean = 1.0 / inv_m_mean if inv_m_mean > 0 else 28.0

            g = 10.0 ** float(self.log_gravity_cgs[k])
            t_eq = float(self.tp_profile.T_eq_K[k])
            alpha = float(self.tp_profile.T_alpha[k])
            ag = 10.0 ** float(self.surface.log_albedo[k])
            p_cl = 10.0 ** float(self.clouds.log_pressure_bar[k])
            tau_cl = 10.0 ** float(self.clouds.log_opt_depth[k])
            lines += [
                f" Planet {k}:",
                f" g = {g:.0f} cm/s^2",
                f" T(P=1bar) = {t_eq:.1f} K, alpha = {alpha:.3f}",
                f" Cloud: P = {p_cl:.2g} bar, tau = {tau_cl:.2g}",
                f" Surface albedo (scalar) = {ag:.3f}",
                (
                    f" Mixing ratios at bottom layer "
                    f"(MMR / VMR, M_mean = {m_mean:.2f} g/mol):"
                ),
            ]
            for s, mmr, lbl in zip(self.species, mmrs_k, profile_labels, strict=True):
                vmr = mmr * m_mean / s.molmass
                lines.append(f" {s.name:<5s} {mmr:.3e} / {vmr:.3e} [{lbl}]")
            if self.bulk is not None:
                vmr_bulk = mmr_bulk_k * m_mean / self.bulk.molmass
                lines.append(
                    f" {self.bulk.name}* {mmr_bulk_k:.3e} / "
                    f"{vmr_bulk:.3e} (* implicit residual)"
                )
        if K > max_show:
            lines.append(f" ... and {K - max_show} more planets")
        return "\n".join(lines)