"""
Permutation acceleration methods for faster inference.
Implements tail approximation via Generalized Pareto Distribution (GPD)
and gamma approximation, following Winkler et al. (2016) "Faster
permutation inference in brain imaging", NeuroImage.
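These methods can reduce the number of permutations needed for accurate
p-value estimation from ~5000 to ~200, yielding a 10-100x speedup,
while approximately preserving the nominal error rates of FWER-corrected
inference (the tail fits are approximations, not exact finite-permutation
procedures).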
References
----------
Winkler et al. (2016). Faster permutation inference in brain imaging.
NeuroImage, doi:10.1016/j.neuroimage.2016.05.068
"""
from __future__ import annotations
import logging
from typing import Dict, Optional
import numpy as np
import numpy.typing as npt
# Module logger; used below for DEBUG-level diagnostics of the GPD fit.
logger = logging.getLogger(__name__)

# Public API exported by ``from <this module> import *``.
__all__ = ["fit_gpd_tail", "fit_gamma_tail", "compute_p_values_accelerated"]
def _fit_gpd_mom(
exceedances: npt.NDArray[np.float64],
) -> tuple[float, float, bool]:
"""
Fit GPD parameters by method of moments.
Parameters
----------
exceedances : ndarray
Positive values y = T - u for T > u (exceedances above threshold u).
Returns
-------
sigma : float
Scale parameter.
xi : float
Shape parameter.
valid : bool
Whether the fit is valid (enough exceedances, finite parameters).
"""
n = len(exceedances)
if n < 10:
return 0.0, 0.0, False
y_bar = np.mean(exceedances)
s2 = np.var(exceedances, ddof=1)
if y_bar <= 0 or s2 <= 0:
return 0.0, 0.0, False
# Method of moments (Hosking & Wallis, 1987)
ratio = y_bar ** 2 / s2
xi = (ratio - 1) / 2
sigma = y_bar * (ratio + 1) / 2
if sigma <= 0 or not np.isfinite(xi) or not np.isfinite(sigma):
return 0.0, 0.0, False
return sigma, xi, True
def _gpd_sf(
x: npt.NDArray[np.float64],
sigma: float,
xi: float,
) -> npt.NDArray[np.float64]:
"""
GPD survival function P(Y > x).
Parameters
----------
x : ndarray
Values at which to evaluate.
sigma : float
Scale parameter.
xi : float
Shape parameter.
Returns
-------
ndarray
Survival function values.
"""
x = np.asarray(x, dtype=np.float64)
if abs(xi) < 1e-10:
# Exponential case (xi → 0)
return np.exp(-x / sigma)
z = 1 + xi * x / sigma
# Ensure z > 0 for valid domain
z = np.maximum(z, 0.0)
return np.power(z, -1.0 / xi)
def _anderson_darling_gpd(
    exceedances: npt.NDArray[np.float64],
    sigma: float,
    xi: float,
) -> float:
    """
    Anderson-Darling test statistic for a GPD fit.

    Lower values indicate a better fit. A rough threshold of 2.5
    indicates acceptable fit (Choulakian & Stephens, 2001).

    Parameters
    ----------
    exceedances : ndarray
        Exceedances above the threshold. They are sorted internally, so
        callers do not need to pre-sort them.
    sigma, xi : float
        GPD parameters.

    Returns
    -------
    float
        Anderson-Darling statistic ``A^2``, or ``inf`` when fewer than 5
        exceedances are available.
    """
    # The A^2 formula pairs the i-th smallest with the i-th largest value,
    # so it is only correct on ascending data; sort defensively.
    y = np.sort(np.asarray(exceedances, dtype=np.float64).ravel())
    n = y.size
    if n < 5:
        return np.inf
    # Fitted CDF F(y) = 1 - SF(y), clipped away from 0 and 1 so that the
    # logarithms below stay finite.
    F = np.clip(1.0 - _gpd_sf(y, sigma, xi), 1e-15, 1 - 1e-15)
    i = np.arange(1, n + 1, dtype=np.float64)
    # A^2 = -n - (1/n) * sum_i (2i - 1) * [ln F(y_(i)) + ln(1 - F(y_(n+1-i)))]
    A2 = -n - np.sum((2 * i - 1) * (np.log(F) + np.log(1 - F[::-1]))) / n
    return float(A2)
def fit_gpd_tail(
    null_dist: npt.NDArray[np.float64],
    observed: npt.NDArray[np.float64],
    n_thresholds: int = 5,
    ad_threshold: float = 2.5,
) -> npt.NDArray[np.float64]:
    """
    Compute p-values using GPD tail approximation.

    Fits a GPD to the upper tail of the null distribution and uses
    the fitted distribution to compute p-values for the observed
    statistics. Falls back to empirical p-values when GPD fit is poor.
    This is a tail approximation, not an exact finite-permutation
    procedure.

    Algorithm (Winkler et al., 2016, Section 2.2.3):

    1. Candidate thresholds ``u`` are the quantiles of ``null_dist`` at
       ``n_thresholds`` levels evenly spaced from 0.50 to 0.90, tried from
       lowest to highest.
    2. Fit a GPD to the exceedances ``y = T_star_j - u`` by method of moments.
    3. Test the fit with Anderson-Darling; if it is poor (or fewer than 10
       exceedances remain), move on to the next, higher threshold.
    4. At the first acceptable fit, compute ``p = (k / J) * GPD_sf(T - u)``
       for observed ``T > u``, where ``k`` = #exceedances.

    Parameters
    ----------
    null_dist : ndarray of shape (J,)
        Max-statistic null distribution from permutations.
    observed : ndarray of shape ``(*spatial_dims,)``
        Observed statistics (e.g., shape (N, N) for connectivity).
    n_thresholds : int, default=5
        Number of candidate thresholds to try for a good GPD fit.
    ad_threshold : float, default=2.5
        Anderson-Darling threshold for acceptable fit.

    Returns
    -------
    p_values : ndarray of same shape as observed
        P-values. Observed values above the selected threshold get GPD tail
        p-values clipped to ``[1/J, 1]``. All other values, and all values
        when no threshold gives an acceptable fit, get empirical p-values.
        The empirical p-values use tie-inclusive max-statistic counting and
        the Phipson-Smyth +1 correction; NaN entries of ``null_dist`` are
        never counted.
    """
    observed = np.asarray(observed)
    J = len(null_dist)
    null_sorted = np.sort(np.asarray(null_dist, dtype=np.float64))

    # Empirical p-values as fallback (with +1 correction — Phipson & Smyth
    # 2010): count = #{j : null_j >= observed}. A binary search on the sorted
    # null replaces an (observed.size x J) comparison array. np.sort puts
    # NaNs last and NaN never satisfies ">=", so only the first n_valid
    # entries can be counted. A NaN observation lands at n_valid (count 0).
    n_valid = J - int(np.count_nonzero(np.isnan(null_sorted)))
    count = n_valid - np.searchsorted(null_sorted[:n_valid], observed, side="left")
    p_empirical = (count + 1.0) / (J + 1.0)

    # A fit needs at least 10 exceedances, which fewer than 10 null samples
    # can never provide. This also avoids np.quantile on an empty null.
    if J < 10:
        return p_empirical

    # Try increasing thresholds for GPD fit
    for q in np.linspace(0.50, 0.90, n_thresholds):
        u = np.quantile(null_sorted, q)
        exceedances = null_sorted[null_sorted > u] - u
        if len(exceedances) < 10:
            continue
        sigma, xi, valid = _fit_gpd_mom(exceedances)
        if not valid:
            continue
        # Check fit quality
        ad_stat = _anderson_darling_gpd(exceedances, sigma, xi)
        if ad_stat > ad_threshold:
            continue

        # Good fit found — GPD p-values above u, empirical at or below u.
        k = len(exceedances)
        tail_fraction = k / J  # estimate of P(T > u)
        above_u = observed > u
        p_gpd = np.copy(p_empirical)
        if np.any(above_u):
            gpd_tail_p = tail_fraction * _gpd_sf(observed[above_u] - u, sigma, xi)
            # Clip to [1/J, 1] (kept from the original design: no p-value
            # finer than the permutation resolution is reported).
            p_gpd[above_u] = np.clip(gpd_tail_p, 1.0 / J, 1.0)
        logger.debug(
            "GPD fit: u=%.3f, k=%d, xi=%.4f, sigma=%.4f, AD=%.3f",
            u, k, xi, sigma, ad_stat,
        )
        return p_gpd

    # No good fit found — fall back to empirical
    logger.debug("GPD fit failed at all thresholds, using empirical p-values")
    return p_empirical
def fit_gamma_tail(
    null_dist: npt.NDArray[np.float64],
    observed: npt.NDArray[np.float64],
) -> npt.NDArray[np.float64]:
    """
    Compute p-values using gamma (Pearson type III) approximation.

    Fits a Pearson type III (shifted gamma) distribution to the null
    distribution by the method of moments (mean, variance, skewness) and
    computes upper-tail p-values from the fitted distribution. For negatively
    skewed nulls the reflected gamma is used, so the fitted skewness keeps
    the sign of the sample skewness. Simpler and more robust than GPD, but
    still an approximation to the finite-permutation tail.

    Parameters
    ----------
    null_dist : ndarray of shape (J,)
        Max-statistic null distribution.
    observed : ndarray of shape ``(*spatial_dims,)``
        Observed statistics.

    Returns
    -------
    p_values : ndarray of same shape as observed
        P-values from the fitted distribution, clipped to ``[1/J, 1]``. When
        gamma fitting is not possible, tie-inclusive empirical p-values with
        the Phipson-Smyth +1 correction are returned instead. This happens
        with fewer than 3 samples, non-finite samples, zero variance or
        ``|skewness| < 1e-10``.
    """
    from scipy import stats

    null_dist = np.asarray(null_dist, dtype=np.float64)
    observed = np.asarray(observed, dtype=np.float64)
    J = len(null_dist)

    # Moment estimates need >= 3 finite samples; otherwise they are NaN or
    # degenerate and would propagate NaN into every p-value.
    can_fit = False
    if J >= 3 and np.all(np.isfinite(null_dist)):
        mean = float(np.mean(null_dist))
        var = float(np.var(null_dist, ddof=1))
        if var > 0 and np.isfinite(var):
            skew_val = float(stats.skew(null_dist))
            can_fit = bool(np.isfinite(skew_val) and abs(skew_val) >= 1e-10)

    if not can_fit:
        # Can't fit gamma, fall back to empirical (with +1 correction):
        # count = #{j : null_j >= observed} via binary search. NaNs sort last
        # and never satisfy ">=", so only the first n_valid entries count.
        null_sorted = np.sort(null_dist)
        n_valid = J - int(np.count_nonzero(np.isnan(null_sorted)))
        count = n_valid - np.searchsorted(null_sorted[:n_valid], observed, side="left")
        return (count + 1.0) / (J + 1.0)

    # Gamma parameters from moments: shape a = 4/skew^2, scale b = sqrt(var/a).
    a = 4.0 / (skew_val ** 2)
    b = np.sqrt(var / a)
    if skew_val > 0:
        # X ~ loc + b * Gamma(a) with loc = mean - a*b; upper tail = sf.
        loc = mean - a * b
        p_values = stats.gamma.sf(observed, a, loc=loc, scale=b)
    else:
        # Negative skew: -X has mean -mean, the same variance and skewness
        # -skew > 0, so fit the gamma to -X and use
        # P(X > x) = P(-X < -x) = gamma CDF at -x.
        loc = -mean - a * b
        p_values = stats.gamma.cdf(-observed, a, loc=loc, scale=b)
    return np.clip(p_values, 1.0 / J, 1.0)
def compute_p_values_accelerated(
    emp_stat_dict: Dict[str, npt.NDArray],
    max_null_dict: Dict[str, npt.NDArray],
    method: str = "gpd",
) -> Dict[str, npt.NDArray]:
    """
    Compute p-values using accelerated tail approximation.

    Drop-in replacement for ``_compute_p_values_from_null`` that uses
    GPD or gamma fitting instead of pure empirical counting. Tail
    estimates are approximate; fallback empirical p-values use the same
    tie-inclusive max-statistic counting as the standard path.

    Parameters
    ----------
    emp_stat_dict : dict
        Dictionary with empirical statistic arrays (any keys).
    max_null_dict : dict
        Dictionary with null distribution arrays, keyed like
        ``emp_stat_dict``. A 1-D null of shape (J,) is shared by every
        element of the matching statistic (e.g. an (N, N) array). A 2-D
        null of shape (J, n_params) is paired column-wise with
        ``emp[..., p]`` (e.g. an (N, N, n_params) array).
    method : {'gpd', 'gamma'}, default='gpd'
        Acceleration method.

    Returns
    -------
    dict
        Dictionary with p-value arrays for each key, each shaped like the
        corresponding empirical statistic.

    Raises
    ------
    ValueError
        If ``method`` is not ``'gpd'`` or ``'gamma'``, or if a 2-D null's
        parameter count does not match the statistic's last axis.
    KeyError
        If a key of ``emp_stat_dict`` is missing from ``max_null_dict``.
    """
    # Validate before anything else: previously any unknown string
    # silently selected the gamma method.
    if method == "gpd":
        fit_func = fit_gpd_tail
    elif method == "gamma":
        fit_func = fit_gamma_tail
    else:
        raise ValueError(
            f"Unknown acceleration method {method!r}; expected 'gpd' or 'gamma'."
        )

    p_values: Dict[str, npt.NDArray] = {}
    for key, emp in emp_stat_dict.items():
        emp = np.asarray(emp)
        null = np.asarray(max_null_dict[key])
        if null.ndim == 1:
            # Single null distribution (J,) for all elements of emp.
            p_values[key] = fit_func(null, emp)
            continue
        # Multi-param: (..., n_params) statistics vs (J, n_params) null.
        if emp.ndim == 0 or null.ndim != 2 or null.shape[1] != emp.shape[-1]:
            raise ValueError(
                f"Null distribution for {key!r} has shape {null.shape}, "
                f"incompatible with statistic shape {emp.shape}."
            )
        p_values[key] = np.stack(
            [fit_func(null[:, p_idx], emp[..., p_idx]) for p_idx in range(emp.shape[-1])],
            axis=-1,
        )
    return p_values