# Source code for skyscapes.physical_model.exojax.components.mmr_profile
"""Mass-mixing-ratio profiles for atmospheric species.
Each species in a :class:`MolecularSpecies` tuple carries one of these
profile components, replacing the previous "scalar mmr per planet"
representation. Different molecules need different altitude variations:
- **Well-mixed gases** (N2, O2, CO2, CH4 in our pressure range): use
:class:`ConstantMmr` -- same mmr at every layer.
- **O3** is concentrated in the stratosphere with column ~10x larger
than a uniformly-mixed model would assume: use
:class:`StratosphericPeakMmr` (Gaussian in log-pressure).
- **H2O** has a sharp cold-trap drop at the tropopause (~0.1 bar)
where temperatures fall below ~200 K: use :class:`TroposphericMmr`
(sigmoid-smoothed step in log-pressure).
The profile contract is just ``evaluate(pressure) -> (n_layers,) mmr``.
Layer pressures come from the rt_engine; each component returns the
per-layer mass-mixing ratio.
"""
from __future__ import annotations
from abc import abstractmethod
import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array
class AbstractMmrProfile(eqx.Module, strict=True):
    """Interface for a single species' mass-mixing-ratio profile.

    Concrete subclasses turn an array of layer pressures into the
    mass-mixing ratio of their species in each of those layers.
    """

    @abstractmethod
    def evaluate(self, pressure: Array) -> Array:
        """Compute the mass-mixing ratio at each layer pressure.

        Args:
            pressure: Layer pressures. Presumably in bar, since the
                concrete profiles compare ``log10(pressure)`` against
                parameters expressed in log10 bar.

        Returns:
            Per-layer mmr values, shape ``(n_layers,)``.
        """
class ConstantMmr(AbstractMmrProfile):
    """Well-mixed gas whose mass-mixing ratio is the same in every layer.

    Suitable for gases whose chemical lifetimes exceed the mixing
    timescales over the modelled pressure range, e.g. CO2, CH4, O2 and
    N2 throughout the troposphere and lower stratosphere.

    Attributes:
        log_mmr: Log10 mass-mixing ratio, shape ``(K,)`` (per planet).
    """

    log_mmr: Array

    def evaluate(self, pressure: Array) -> Array:
        """Spread ``10**log_mmr`` across all layers.

        Args:
            pressure: Layer pressures. Only their shape and dtype matter,
                and the values themselves are ignored.

        Returns:
            ``10**log_mmr`` multiplied into an array of ones shaped like
            ``pressure``.
        """
        mmr = 10.0**self.log_mmr
        layer_ones = jnp.ones_like(pressure)
        return mmr * layer_ones
class StratosphericPeakMmr(AbstractMmrProfile):
    """Mass-mixing ratio with a Gaussian peak in log-pressure; canonical for O3.

    The profile is::

        mmr(P) = 10**log_peak_mmr
                 * exp(-0.5 * ((log10(P) - log_peak_pressure_bar)
                               / log_sigma_decades)**2)

    so the mmr equals ``10**log_peak_mmr`` at ``P = 10**log_peak_pressure_bar``.
    It falls to ``exp(-0.5)`` of that value one ``log_sigma_decades`` away
    in log10-pressure.

    For Earth's O3: ``log_peak_pressure_bar = log10(0.01) = -2``
    (10 mbar, ~30 km altitude), ``log_sigma_decades = 0.5``.

    Attributes:
        log_peak_mmr: Log10 of the peak mmr, shape ``(K,)``.
        log_peak_pressure_bar: Log10 pressure [bar] at peak, ``(K,)``.
        log_sigma_decades: Gaussian width in log10-pressure, ``(K,)``.
    """

    log_peak_mmr: Array
    log_peak_pressure_bar: Array
    log_sigma_decades: Array

    def evaluate(self, pressure: Array) -> Array:
        """Evaluate the Gaussian log-pressure peak at each layer.

        Args:
            pressure: Layer pressures in bar.

        Returns:
            Per-layer mmr following the Gaussian profile above.
        """
        # Standardized distance from the peak, in units of the width.
        z = (jnp.log10(pressure) - self.log_peak_pressure_bar) / self.log_sigma_decades
        envelope = jnp.exp(-0.5 * z**2)
        peak_mmr = 10.0**self.log_peak_mmr
        return peak_mmr * envelope
class TroposphericMmr(AbstractMmrProfile):
    """Two-level mmr with a smooth step in log-pressure; canonical for H2O.

    At pressures well above the step pressure (deep, below the cold
    trap) the mmr tends to ``10**log_mmr_below``. At pressures well under
    it (aloft, above the cold trap) the mmr tends to ``10**log_mmr_above``.
    A sigmoid in log10-pressure blends the two, which keeps the profile
    differentiable for HMC retrievals. The blend is linear in mmr (not in
    log mmr), so at the step pressure the result is the arithmetic mean
    of the two levels.

    For Earth's H2O: ~3e-3 in the troposphere (P > 0.1 bar) and ~3e-6
    above the tropopause (P < 0.1 bar).

    Attributes:
        log_mmr_below: Log10 mmr at high pressure (below cold trap), ``(K,)``.
        log_mmr_above: Log10 mmr at low pressure (above cold trap), ``(K,)``.
        log_pressure_step_bar: Log10 transition pressure [bar], ``(K,)``.
        log_transition_width_decades: Width of the sigmoid transition
            in log10-pressure, ``(K,)``. Smaller = sharper step.
    """

    log_mmr_below: Array
    log_mmr_above: Array
    log_pressure_step_bar: Array
    log_transition_width_decades: Array

    def evaluate(self, pressure: Array) -> Array:
        """Blend the two mmr levels with a sigmoid in log-pressure.

        Args:
            pressure: Layer pressures in bar.

        Returns:
            Per-layer mmr: ``mmr_below * w + mmr_above * (1 - w)``, where
            ``w = sigmoid((log10(P) - log_pressure_step_bar) / width)``.
        """
        scaled_offset = (
            jnp.log10(pressure) - self.log_pressure_step_bar
        ) / self.log_transition_width_decades
        # w_deep -> 1 where P >> P_step (below the cold trap) and -> 0
        # where P << P_step (above it). w_aloft is its complement.
        w_deep = jax.nn.sigmoid(scaled_offset)
        w_aloft = 1.0 - w_deep
        mmr_deep = 10.0**self.log_mmr_below
        mmr_aloft = 10.0**self.log_mmr_above
        return mmr_deep * w_deep + mmr_aloft * w_aloft