"""GraterDisk: Augereau 1999 scattered-light disk as an eqx.Module.
The Augereau et al. 1999 model is a two-power-law dust density distribution
with a flared vertical profile; the dust scatters light according to a single
Henyey-Greenstein phase function. This module exposes the model in skyscapes'
eqx.Module idiom so its parameters are PyTree leaves (fittable through the
standard scene-as-PyTree pattern).
References:
Augereau, J.-C., Lagrange, A.-M., Mouillet, D., Papaloizou, J. C. B.,
& Grorod, P. A. 1999, A&A, 348, 557
grater-jax (Kondapalli, Lewis, Ashcraft, Millar-Blanchaer):
https://github.com/UCSB-Exoplanet-Polarimetry-Lab/GRaTeR-JAX
Notes on convention:
Disk orientation (``incl_deg`` / ``pa_deg``) is passed in at
``surface_brightness`` time rather than stored on the disk -- the
System's ``midplane_inc_deg`` / ``midplane_pa_deg`` drives every
component consistently and there is one source of truth.
Wavelength dependence enters via the HG asymmetry parameter ``g_HG``
and the albedo-like scaling ``Ag``, both stored as 1-D grids over
``wavelengths_nm`` and interpolated at call time with a cubic
spline. The geometric kernel itself is wavelength-independent, so
``surface_brightness`` keeps its scalar-in / ``(ny, nx)``-out shape
contract; ``vmap`` over wavelength gives a cube without changing
the interface.
The LOS-around-midplane sampling diverges at edge-on geometry
(``cos(incl) -> 0``); ``surface_brightness`` raises via
``eqx.error_if`` if the requested inclination is closer to edge-on
than the disk's geometry allows. The largest supported inclination
(measured from face-on) is ``arctan(rmax_AU / zmax_AU)``, where
``zmax_AU`` is the LOS half-extent evaluated at the outer disk radius
after flaring; a smaller ``zmax_AU`` relative to ``rmax_AU`` admits
inclinations closer to edge-on.
"""
from __future__ import annotations
import equinox as eqx
import interpax
import jax.numpy as jnp
from jaxtyping import Array
from .._repr import fmt_scalar_or_array
from ._los import los_integrate_scattered
from .base import AbstractDisk
[docs]
def henyey_greenstein(cos_phi: Array, g: Array) -> Array:
    """Evaluate the Henyey-Greenstein single-scattering phase function.

    Computes ``(1 - g^2) / (4 pi (1 + g^2 - 2 g cos_phi)^(3/2))``.

    Args:
        cos_phi: Cosine of the scattering angle; any shape broadcastable
            against ``g``.
        g: Asymmetry parameter in (-1, 1). Positive values favour forward
            scattering, negative values backward scattering, and zero gives
            isotropic scattering.

    Returns:
        Phase-function values with the broadcast shape of ``cos_phi`` and
        ``g``.
    """
    g_sq = g * g
    base = 1.0 + g_sq - 2.0 * g * cos_phi
    # base ** 1.5 spelled as base * sqrt(base): cheaper than a fractional
    # pow on CPU.
    return (1.0 - g_sq) / (4.0 * jnp.pi * (base * jnp.sqrt(base)))
[docs]
def two_power_law_density(
    r: Array,
    z: Array,
    sma_AU: Array,
    alpha_in: Array,
    alpha_out: Array,
    ksi0_AU: Array,
    gamma: Array,
    beta: Array,
) -> Array:
    """Evaluate the Augereau 1999 two-power-law density with flared height.

    The radial part is ``((r/sma)^(-2 alpha_in) + (r/sma)^(-2 alpha_out))^(-1/2)``
    and the vertical part is ``exp(-(|z| / h(r))^gamma)`` with
    ``h(r) = ksi0_AU * (r/sma)^beta``.

    Args:
        r: Cylindrical radius in the disk frame [AU]. Must be strictly
            positive. Callers must mask radii too small for the radial power
            law to evaluate cleanly in the chosen dtype (for example,
            ``r_ratio^(-2 alpha_in)`` overflows float32 below
            ``r/sma ~ 1e-3`` at ``alpha_in = 5``).
        z: Height above the disk midplane [AU].
        sma_AU: Reference radius where the radial profile peaks [AU].
        alpha_in: Inner power-law slope (positive convention).
        alpha_out: Outer power-law slope (negative convention).
        ksi0_AU: Vertical scale height at ``r = sma_AU`` [AU].
        gamma: Vertical profile exponent (2 = Gaussian, 1 = exponential).
        beta: Flaring index (0 = none, 1 = linear).

    Returns:
        Relative density (arbitrary units) with the broadcast shape of ``r``
        and ``z``.
    """
    scaled_r = r / sma_AU
    inner_term = scaled_r ** (-2.0 * alpha_in)
    outer_term = scaled_r ** (-2.0 * alpha_out)
    radial_profile = (inner_term + outer_term) ** (-0.5)
    # Scale height grows (beta > 0) or stays fixed (beta = 0) with radius.
    scale_height = ksi0_AU * scaled_r**beta
    vertical_profile = jnp.exp(-((jnp.abs(z) / scale_height) ** gamma))
    return radial_profile * vertical_profile
[docs]
class GraterDisk(AbstractDisk):
    """Augereau 1999 scattered-light disk.

    Attributes (PyTree leaves, fittable):
        sma_AU: Reference radius [AU] (radial profile peak).
        alpha_in: Inner power-law slope (positive).
        alpha_out: Outer power-law slope (negative).
        ksi0_AU: Vertical scale height at ``sma_AU`` [AU].
        gamma: Vertical profile exponent (2 = Gaussian).
        beta: Flaring index (0 = none, 1 = linear).
        rmin_AU: Inner truncation radius [AU].
        rmax_AU: Outer truncation radius [AU].
        wavelengths_nm: 1-D wavelength grid for ``g_HG_grid`` /
            ``Ag_grid``, shape ``(n_wave,)`` [nm]. Must be sorted.
            ``surface_brightness`` calls outside this range return NaN.
        g_HG_grid: Henyey-Greenstein asymmetry at each grid wavelength.
        Ag_grid: Albedo-like scaling at each grid wavelength.

    Static attributes (compilation-time constants):
        nx, ny: Output image shape.
        pixel_scale_arcsec: Pixel scale [arcsec/pixel].
        dist_pc: System distance [pc].
        n_slices_los: Number of LOS integration slices.
        n_scale_heights: LOS half-extent in units of the local scale
            height at ``rmax_AU`` (default 6.0).
    """

    sma_AU: Array
    alpha_in: Array
    alpha_out: Array
    ksi0_AU: Array
    gamma: Array
    beta: Array
    rmin_AU: Array
    rmax_AU: Array
    wavelengths_nm: Array
    g_HG_grid: Array
    Ag_grid: Array

    nx: int = eqx.field(static=True)
    ny: int = eqx.field(static=True)
    pixel_scale_arcsec: float = eqx.field(static=True)
    dist_pc: float = eqx.field(static=True)
    n_slices_los: int = eqx.field(static=True)
    n_scale_heights: float = eqx.field(static=True, default=6.0)

    def _zmax_AU(self) -> Array:
        """LOS half-extent in disk-frame z, evaluated at ``rmax_AU``.

        For a flared disk, the scale height grows with radius, so the
        LOS depth that captures the vertical tails at the outer edge
        is larger than ``n_scale_heights * ksi0_AU`` at ``sma_AU``.

        Returns:
            ``n_scale_heights * ksi0_AU * (rmax_AU / sma_AU) ** beta`` [AU].
        """
        return (
            self.n_scale_heights
            * self.ksi0_AU
            * (self.rmax_AU / self.sma_AU) ** self.beta
        )

    def surface_brightness(
        self,
        wavelength_nm: Array,
        time_jd: Array,
        incl_deg: Array,
        pa_deg: Array,
    ) -> Array:
        """Render the Augereau 1999 scattered-light disk on a sky-plane grid.

        ``wavelength_nm`` selects ``g_HG`` and ``Ag`` via cubic-spline
        interpolation over ``wavelengths_nm`` / ``g_HG_grid`` /
        ``Ag_grid``. Queries outside the grid return NaN.

        ``time_jd`` is part of the AbstractDisk interface but ignored
        (static disk).

        Args:
            wavelength_nm: Query wavelength [nm].
            time_jd: Ignored.
            incl_deg: Disk inclination [deg]; 90 (mod 180) is edge-on.
            pa_deg: Disk position angle [deg], forwarded to the LOS kernel.

        Returns:
            ``Ag`` times the output of ``los_integrate_scattered`` on the
            ``(ny, nx)`` image grid.

        Raises:
            Via ``eqx.error_if``: when ``|incl_deg mod 180 - 90|`` is
            smaller than ``90 - arctan(rmax_AU / zmax_AU)``, i.e. the
            inclination is too close to edge-on for this disk's LOS slab.
        """

        def _at_wavelength(grid: Array) -> Array:
            # NaN outside [wavelengths_nm.min(), wavelengths_nm.max()].
            return interpax.interp1d(
                wavelength_nm,
                self.wavelengths_nm,
                grid,
                method="cubic",
                extrap=False,
            )

        # Wavelength-dependent scattering coefficients.
        g_HG = _at_wavelength(self.g_HG_grid)
        Ag = _at_wavelength(self.Ag_grid)

        # Geometric edge-on check. threshold_deg = arctan(rmax / zmax) is the
        # largest inclination (from face-on) accepted: tan(incl) must not
        # exceed rmax / zmax. Shrinking zmax relative to rmax therefore
        # admits inclinations closer to 90; growing zmax does the opposite.
        zmax_AU = self._zmax_AU()
        threshold_deg = jnp.rad2deg(jnp.arctan(self.rmax_AU / zmax_AU))
        # Fold into [0, 180) so both disk faces are treated symmetrically.
        dist_from_edge_on = jnp.abs((incl_deg % 180.0) - 90.0)
        incl_deg = eqx.error_if(
            incl_deg,
            dist_from_edge_on < (90.0 - threshold_deg),
            "GraterDisk: incl_deg too close to edge-on for this disk's "
            "geometry (|incl - 90| < 90 - arctan(rmax_AU / zmax_AU)). "
            "Decrease zmax_AU = n_scale_heights * ksi0_AU * (rmax/sma)^beta "
            "relative to rmax_AU (e.g. lower n_scale_heights or ksi0_AU), "
            "or move incl_deg further from 90.",
        )

        # Density closure: substitute sma_AU outside the valid annulus
        # so the (-2 alpha_in) power stays bounded in float32; the kernel
        # masks these substituted contributions back to zero.
        def density_fn(r_AU: Array, z_AU: Array, valid: Array) -> Array:
            r_safe = jnp.where(valid, r_AU, self.sma_AU)
            z_safe = jnp.where(valid, z_AU, jnp.zeros_like(z_AU))
            return two_power_law_density(
                r_safe,
                z_safe,
                self.sma_AU,
                self.alpha_in,
                self.alpha_out,
                self.ksi0_AU,
                self.gamma,
                self.beta,
            )

        def phase_fn(cos_phi: Array) -> Array:
            return henyey_greenstein(cos_phi, g_HG)

        integral = los_integrate_scattered(
            density_fn,
            phase_fn,
            incl_deg=incl_deg,
            pa_deg=pa_deg,
            rmin_AU=self.rmin_AU,
            rmax_AU=self.rmax_AU,
            zmax_AU=zmax_AU,
            nx=self.nx,
            ny=self.ny,
            pixel_scale_arcsec=self.pixel_scale_arcsec,
            dist_pc=self.dist_pc,
            n_slices_los=self.n_slices_los,
        )
        return Ag * integral

    def spatial_extent(self) -> tuple[float, float]:
        """Return ``(width_arcsec, height_arcsec)`` = image shape x pixel scale."""
        return (
            self.nx * self.pixel_scale_arcsec,
            self.ny * self.pixel_scale_arcsec,
        )

    def __repr__(self) -> str:
        """Compact summary of radial/vertical/HG parameters + image/Ag grid.

        Reads concrete values from ``wavelengths_nm`` via ``int``/``float``,
        so it is meant for eager (non-traced) use.
        """
        n_wl = int(self.wavelengths_nm.shape[0])
        wl_min = float(self.wavelengths_nm.min())
        wl_max = float(self.wavelengths_nm.max())
        return (
            f"GraterDisk(sma={fmt_scalar_or_array(self.sma_AU)} AU, "
            f"radial: alpha_in={fmt_scalar_or_array(self.alpha_in)}, "
            f"alpha_out={fmt_scalar_or_array(self.alpha_out)}, "
            f"r=[{fmt_scalar_or_array(self.rmin_AU)}, "
            f"{fmt_scalar_or_array(self.rmax_AU)}] AU; "
            f"vertical: ksi0={fmt_scalar_or_array(self.ksi0_AU)} AU, "
            f"beta={fmt_scalar_or_array(self.beta)}, "
            f"gamma={fmt_scalar_or_array(self.gamma)}; "
            f"HG/Ag grid: {n_wl} pts in {wl_min:.0f}-{wl_max:.0f} nm; "
            f"image: {self.ny}x{self.nx} @ {self.pixel_scale_arcsec} arcsec/px, "
            f"dist={self.dist_pc} pc, n_slices_los={self.n_slices_los})"
        )