"""Shared line-of-sight (LOS) integration kernel, centred on the disk midplane, for parametric disks.
Lifted out of ``GraterDisk.surface_brightness`` so multiple disk classes
(GraterDisk, ExovistaParametricDisk, ...) can plug in their own density
profiles and phase functions while sharing the geometric integration.
"""
from __future__ import annotations
from collections.abc import Callable
import jax.numpy as jnp
from hwoutils.constants import deg2rad
from hwoutils.transforms import ccw_rotation_matrix
from jaxtyping import Array
# Density profile callback: ``(r_AU, z_AU, valid) -> rho``. ``r_AU`` is the
# cylindrical radius, ``z_AU`` the disk-frame height, and ``valid`` the boolean
# in-annulus mask computed by ``los_integrate_scattered``. The result is
# multiplied into the LOS integrand.
DensityFn = Callable[[Array, Array, Array], Array]
# Scattering phase function callback: ``cos_phi -> phase``, evaluated at every
# LOS sample (``cos_phi`` is the cosine of the scattering angle).
PhaseFn = Callable[[Array], Array]
[docs]
def los_integrate_scattered(
    density_fn: DensityFn,
    phase_fn: PhaseFn,
    *,
    incl_deg: Array,
    pa_deg: Array,
    rmin_AU: Array,
    rmax_AU: Array,
    zmax_AU: Array,
    nx: int,
    ny: int,
    pixel_scale_arcsec: float,
    dist_pc: float,
    n_slices_los: int,
) -> Array:
    """Integrate scattered light along LOS around the disk midplane.

    Sky frame: ``+x`` right, ``+y`` toward N, observer along ``+z``.
    ``pa_deg`` is measured from N toward the disk's projected major
    axis (CCW). ``incl_deg`` is the angle between the disk normal and
    the observer's line of sight (0 = pole-on).

    For every pixel, the kernel first locates the LOS depth ``l`` at which
    the line of sight crosses the disk midplane (disk-frame ``z = 0``).
    It then places ``n_slices_los`` uniformly spaced samples spanning
    disk-frame heights ``z`` in ``[-zmax_AU, +zmax_AU]`` around that
    crossing. Samples whose cylindrical radius lies outside
    ``[rmin_AU, rmax_AU]`` contribute zero. The integrand
    ``rho * phase / d**2`` (``d`` = distance to the star) is integrated
    over ``l`` with the trapezoidal rule.

    Inclinations in ``(90, 180]`` deg (``cos i < 0``) are supported. The
    LOS half-window is ``zmax_AU / |cos i|``, so the integral keeps a
    positive step and stays non-negative for non-negative integrands.

    Note:
        The LOS window and the midplane-crossing depth both scale as
        ``1 / |cos i|``. As the disk approaches edge-on (``incl_deg -> 90``),
        the samples therefore spread over an ever longer LOS. More and more
        of them fall outside ``rmax_AU``, so a fixed ``n_slices_los``
        resolves the disk poorly. Exactly edge-on (``cos i == 0``) is
        degenerate for this midplane-centred scheme.

    Args:
        density_fn: ``(r_AU, z_AU, valid) -> rho``. The kernel passes
            ``r_AU = 1`` at invalid samples (the sqrt argument is masked),
            so ``r_AU`` is finite everywhere. The caller must still keep its
            formula finite at invalid points, because NaNs there can leak
            into gradients through ``jnp.where``. With ``rmin_AU = 0``,
            ``r_AU`` may be exactly 0 at *valid* samples, so ``r^alpha``
            profiles with ``alpha < 0`` need their own guard.
        phase_fn: ``cos_phi -> phase`` evaluated at every LOS sample.
        incl_deg: Disk inclination [deg].
        pa_deg: Disk position angle [deg] from N, CCW.
        rmin_AU: Inner truncation radius [AU].
        rmax_AU: Outer truncation radius [AU].
        zmax_AU: LOS half-extent in disk-frame z [AU]. Should comfortably
            exceed the disk's largest scale height (for flared disks,
            evaluate the scale-height formula at ``rmax_AU``).
        nx: Output image width [pixels] (static).
        ny: Output image height [pixels] (static).
        pixel_scale_arcsec: Pixel scale [arcsec/pixel] (static).
        dist_pc: System distance [pc] (static).
        n_slices_los: Number of LOS integration slices (static, >= 2).

    Returns:
        LOS-integrated scattered-light map, shape ``(ny, nx)``.
        Caller multiplies by an albedo / normalization to get contrast.

    Raises:
        ValueError: If ``n_slices_los < 2``. The trapezoidal rule needs at
            least two samples; a single sample would divide by zero in the
            step ``dl`` and yield a NaN image.
    """
    if n_slices_los < 2:
        raise ValueError(
            "n_slices_los must be >= 2 for trapezoidal LOS integration, "
            f"got {n_slices_los}"
        )

    # Sky-plane pixel centres in AU (1 arcsec at 1 pc subtends 1 AU).
    px_AU = pixel_scale_arcsec * dist_pc
    x_pix = (jnp.arange(nx) - (nx - 1) / 2.0) * px_AU
    y_pix = (jnp.arange(ny) - (ny - 1) / 2.0) * px_AU
    x_sky, y_sky = jnp.meshgrid(x_pix, y_pix)

    incl = incl_deg * deg2rad
    cos_i = jnp.cos(incl)
    sin_i = jnp.sin(incl)

    # Rotate sky coordinates by -pa into the disk's projected frame. The
    # inclination tilt below acts about the x_rot axis (x_rot is left
    # unchanged by it).
    r_pa_inv = ccw_rotation_matrix(-pa_deg)
    x_rot = r_pa_inv[0, 0] * x_sky + r_pa_inv[0, 1] * y_sky
    y_rot = r_pa_inv[1, 0] * x_sky + r_pa_inv[1, 1] * y_sky

    # Per-pixel midplane crossing in LOS depth (solves z_d = 0). Around it,
    # sample a fixed disk-frame z extent: z_d = t * zmax_AU along the grid.
    # |cos i| keeps the half-window (and hence dl) positive for
    # incl in (90, 180] deg. With a signed cos i the same samples would be
    # traversed backwards and the integral would come out negative.
    l_mid = y_rot * sin_i / cos_i
    l_half = zmax_AU / jnp.abs(cos_i)
    t = jnp.linspace(-1.0, 1.0, n_slices_los)
    # Shape (n_slices_los, ny, nx): slice axis first, then the image grid.
    l_grid = l_mid[None, :, :] + t[:, None, None] * l_half

    # Disk-frame coordinates: rotate (y_rot, l) by the inclination.
    y_d = y_rot[None, :, :] * cos_i + l_grid * sin_i
    z_d = -y_rot[None, :, :] * sin_i + l_grid * cos_i

    xy_sq = x_rot * x_rot + y_d * y_d  # squared cylindrical radius
    d_star_sq = xy_sq + z_d * z_d  # squared distance to the star
    rmin_sq = rmin_AU * rmin_AU
    rmax_sq = rmax_AU * rmax_AU
    valid = (xy_sq >= rmin_sq) & (xy_sq <= rmax_sq) & (d_star_sq > 0.0)

    # Sqrt only inside the valid annulus (invalid samples get r_AU = 1).
    # The caller's density_fn handles any further safe substitution
    # before evaluating its formula.
    xy_sq_safe = jnp.where(valid, xy_sq, 1.0)
    r_AU = jnp.sqrt(xy_sq_safe)
    rho = density_fn(r_AU, z_d, valid)

    # Scattering angle: observer-frame z / |position| = l_grid / |position|.
    # l > 0 (in front of the star) corresponds to forward scattering.
    safe_d_sq = jnp.where(d_star_sq > 0.0, d_star_sq, 1.0)
    safe_d = jnp.sqrt(safe_d_sq)
    cos_phi = l_grid / safe_d
    phase = phase_fn(cos_phi)

    # Illumination falls off as 1 / d**2 from the star.
    integrand = jnp.where(valid, rho * phase / safe_d_sq, 0.0)

    # Trapezoidal rule on the uniform LOS grid (dl > 0 by construction).
    dl = 2.0 * l_half / (n_slices_los - 1)
    return 0.5 * dl * (integrand[:-1] + integrand[1:]).sum(axis=0)