Source code for skyscapes.physical_model.exojax.components.mie_cloud

"""Mie-scattering cloud component.

A drop-in replacement for :class:`GrayCloud` that uses ExoJAX's
``OpaMie`` to pre-compute wavelength-dependent single-scattering
albedo and asymmetry parameter for a chosen condensate (e.g. water,
water-ice, NH3) and particle size distribution. The cloud's vertical
distribution and total optical depth remain fittable PyTree leaves;
the wavelength dependence of ssa(lambda) and g(lambda) is fixed at
engine-build time.

ExoJAX's ``PdbCloud`` triggers a one-time download of the chosen
condensate's refractive-index data (cached under
``./.database/particulates/virga/``) on first use, similar to the
HITRAN line-list downloads.
"""

from __future__ import annotations

import equinox as eqx
import jax
import jax.numpy as jnp
import scipy.integrate
from jaxtyping import Array

# Compatibility shim: PyMieScatt (imported lazily by ExoJAX from within
# ``generate_miegrid``) still calls ``scipy.integrate.trapz``. SciPy 1.14
# dropped that name in favour of ``trapezoid``, so re-create the alias
# when it is missing. This has to happen before any ExoJAX code that
# touches PyMieScatt runs, hence its position ahead of the ExoJAX imports.
try:
    scipy.integrate.trapz
except AttributeError:
    scipy.integrate.trapz = scipy.integrate.trapezoid

from exojax.database.pardb import PdbCloud
from exojax.opacity.opacont import OpaMie

from .base import AbstractClouds, Contribution
from .clouds import DEFAULT_CLOUD_LOG_SIGMA


class MieCloud(AbstractClouds):
    """Single cloud deck with Mie-theory scattering properties.

    The vertically-integrated extinction optical depth is shared out
    among the atmospheric layers with a softmax-normalised Gaussian in
    log10(pressure), the same vertical scheme as :class:`GrayCloud`.
    The difference is spectral: instead of scalar constants, the
    single-scattering albedo and asymmetry parameter are per-wavenumber
    arrays pre-computed from Mie cross-sections, so they follow the
    condensate's refractive index n(lambda) + k(lambda).

    Attributes (PyTree leaves, fittable):
        log_pressure_bar: Log10 cloud-deck pressure [bar], shape
            ``(K,)``.
        log_opt_depth: Log10 of the vertically-integrated cloud
            extinction optical depth, shape ``(K,)``.

    Pre-computed Mie quantities (built at engine time, shared across
    planets):
        ssa_grid: Single-scattering albedo
            (``sigma_scattering / sigma_extinction``) on the wavenumber
            grid, shape ``(n_nu,)``.
        g_grid: Asymmetry parameter on the wavenumber grid, shape
            ``(n_nu,)``.

    Static config:
        condensate: Condensate name (``"H2O"``, ``"H2O_ice"``,
            ``"MgSiO3"``, etc.). Only informational (e.g. for the
            repr); the Mie parameters themselves are already baked into
            ``ssa_grid`` and ``g_grid``.
        rg_um: Geometric mean particle radius [um].
        sigmag: Geometric standard deviation of the lognormal size
            distribution.
        log_sigma: Vertical Gaussian half-width of the cloud deck in
            log10(pressure) [dex].
    """

    log_pressure_bar: Array
    log_opt_depth: Array
    ssa_grid: Array
    g_grid: Array
    condensate: str = eqx.field(static=True, default="H2O")
    rg_um: float = eqx.field(static=True, default=10.0)
    sigmag: float = eqx.field(static=True, default=2.0)
    log_sigma: float = eqx.field(static=True, default=DEFAULT_CLOUD_LOG_SIGMA)

    def compute(
        self,
        log_pressure_bar_scalar: Array,
        log_opt_depth_scalar: Array,
        pressure: Array,
        n_nu: int,
    ) -> Contribution:
        """Return the cloud's per-layer, per-wavenumber optical depths.

        The extinction optical depth is gray (identical at every
        wavenumber); only the albedo and asymmetry parameter carry
        spectral structure.

        Args:
            log_pressure_bar_scalar: Log10 pressure [bar] at which the
                Gaussian deck is centred.
            log_opt_depth_scalar: Log10 of the total extinction optical
                depth to distribute over the layers.
            pressure: Layer pressures; the first axis is the layer axis.
            n_nu: Number of wavenumber bins in the output.

        Returns:
            A :class:`Contribution` whose three arrays all have shape
            ``(n_layers, n_nu)``: total extinction, scattering
            extinction (``ssa * dtau``), and the asymmetry-weighted
            scattering term (``g * ssa * dtau``).
        """
        grid_shape = (pressure.shape[0], n_nu)

        # Gaussian in log10(P) centred on the deck. The softmax turns the
        # log-weights into fractions that sum to one over the layers.
        offset_dex = (jnp.log10(pressure) - log_pressure_bar_scalar) / self.log_sigma
        layer_fraction = jax.nn.softmax(-0.5 * offset_dex**2)

        # Share of the total optical depth that lands in each layer.
        tau_per_layer = 10.0**log_opt_depth_scalar * layer_fraction
        extinction = jnp.broadcast_to(tau_per_layer[:, None], grid_shape)

        # ssa(lambda) and g(lambda) are the same in every layer, so they
        # are tiled along the layer axis (GrayCloud uses scalars here).
        albedo = jnp.broadcast_to(self.ssa_grid[None, :], grid_shape)
        asymmetry = jnp.broadcast_to(self.g_grid[None, :], grid_shape)

        scattering = albedo * extinction
        return Contribution(
            dtau_total=extinction,
            dtau_scatter=scattering,
            g_weighted_num=asymmetry * scattering,
        )
def build_mie_cloud(
    *,
    nu_grid: Array,
    log_pressure_bar: Array,
    log_opt_depth: Array,
    condensate: str = "H2O",
    rg_um: float = 10.0,
    sigmag: float = 2.0,
    log_sigma: float = DEFAULT_CLOUD_LOG_SIGMA,
) -> MieCloud:
    """Build a Mie cloud component by pre-computing Mie params.

    On first use for a given condensate this triggers two downloads /
    computations cached under ``./.database/particulates/virga/``:

    - **Refractive-index file** (small, fetched from Zenodo).
    - **Mie-grid lookup table** (built locally via PyMieScatt; takes a
      couple of minutes per condensate).

    Subsequent calls reuse the cache.

    Args:
        nu_grid: Wavenumber grid [cm^-1] from
            :func:`build_exojax_engines`.
        log_pressure_bar: Per-planet log10 cloud pressure [bar].
        log_opt_depth: Per-planet log10 total cloud extinction tau.
        condensate: Condensate name. ExoJAX/Virga ships e.g.
            ``"H2O"``, ``"H2O_ice"``, ``"NH3"``, ``"MgSiO3"``,
            ``"Mg2SiO4"``, ``"Fe"``, ``"KCl"``, ``"Na2S"``, ``"ZnS"``,
            ``"MnS"``, ``"Cr"``, ``"Al2O3"``, ``"TiO2"``.
        rg_um: Mean particle radius [um] of the lognormal size
            distribution.
        sigmag: Geometric standard deviation of the size distribution.
        log_sigma: Vertical-distribution Gaussian half-width [dex].

    Returns:
        A :class:`MieCloud` instance ready to drop into
        :class:`ExoJaxPhysicalModel`.
    """
    pdb = PdbCloud(condensate=condensate)
    # The Mie grid is expensive to build, so only generate it when the
    # cached file is absent.
    if not pdb.miegrid_path.exists():
        pdb.generate_miegrid()
    pdb.load_miegrid()

    opa_mie = OpaMie(pdb=pdb, nu_grid=nu_grid)
    rg_cm = rg_um * 1.0e-4  # micron -> cm
    sig_ext, sig_sca, g_arr = opa_mie.mieparams_vector(rg_cm, sigmag)

    # ssa = sigma_scattering / sigma_extinction, set to 0 wherever the
    # extinction cross-section is not positive (e.g. outside the
    # refractive-index coverage). The denominator is replaced by 1 in
    # those bins *before* dividing: ``jnp.where`` evaluates both branches,
    # so dividing by the raw ``sig_ext`` would still create NaN/inf there,
    # which trips ``jax_debug_nans`` and poisons gradients even though the
    # values are masked afterwards.
    has_extinction = sig_ext > 0
    safe_sig_ext = jnp.where(has_extinction, sig_ext, 1.0)
    ssa_grid = jnp.where(has_extinction, sig_sca / safe_sig_ext, 0.0)

    return MieCloud(
        log_pressure_bar=log_pressure_bar,
        log_opt_depth=log_opt_depth,
        ssa_grid=ssa_grid,
        g_grid=g_arr,
        condensate=condensate,
        rg_um=rg_um,
        sigmag=sigmag,
        log_sigma=log_sigma,
    )