"""
Shared enhancement wrappers for both t-test and GLM pipelines.
Each wrapper is a pure transformation: ``stat_dict → score_dict``. It takes a
per-direction statistic dict (``{'positive': ..., 'negative': ...}``) and
applies the corresponding network-based enhancement to each direction
independently.
Wrappers accept the legacy keys ``'g2>g1'`` / ``'g1>g2'`` from the v1.x
t-test pipeline (silently remapped via :func:`conninfpy._compat.normalize_keys`).
When the input holds exactly the two tails, they return a
:class:`~conninfpy._compat.TailResult` with canonical ``'positive'`` /
``'negative'`` keys; the legacy keys remain readable on that result with a
:class:`DeprecationWarning` until v2.1. Any other key set (e.g. the F-stat
omnibus path, ``{'omnibus': ...}``) is scored key by key and returned as a
plain ``dict``.
Wrappers do NOT compute statistics internally — the caller is responsible for
providing the raw statistic (t-stat, β, etc.). This matches the design
pattern documented in CLAUDE.md.
"""
from __future__ import annotations

from typing import Callable, Dict, List, Union

import numpy as np
import numpy.typing as npt

from ._compat import TailResult, make_tail_result, normalize_keys
from .defaults import (
    DEFAULT_EXTENT_EXPONENT,
    DEFAULT_HEIGHT_EXPONENT,
    DEFAULT_MIN_CLUSTER_SIZE,
    DEFAULT_NBS_STAT,
    DEFAULT_NBS_THRESHOLD,
    DEFAULT_N_THRESHOLDS_PERMUTATION as DEFAULT_N_THRESHOLDS,
    DEFAULT_START_THRESHOLD,
)
from .nbs_score import get_cnbs_score, get_nbs_score
from .tfnbs_score import (
    get_fbc_tfnbs_score,
    get_network_informed_tfnbs_score,
    get_tfnbs_score,
)
# Public API of this module: one wrapper per enhancement method.
# The per-tail dispatcher ``_wrap`` is intentionally left out.
__all__ = [
    "apply_tfnbs", "apply_nbs", "apply_cnbs",
    "apply_ni_tfnbs", "apply_fbc_tfnbs",
]
def _wrap(
    stat_dict: Dict[str, npt.NDArray[np.float64]],
    score_fn: Callable[[npt.NDArray[np.float64]], npt.NDArray[np.float64]],
) -> Union[TailResult, Dict[str, npt.NDArray[np.float64]]]:
    """Apply ``score_fn`` to every entry of ``stat_dict``.

    Keys are first passed through :func:`normalize_keys`, so the legacy
    ``'g2>g1'`` / ``'g1>g2'`` keys are accepted alongside the canonical
    ``'positive'`` / ``'negative'`` ones.

    Parameters
    ----------
    stat_dict
        Mapping of direction (or other label) to a statistic array.
    score_fn
        Enhancement applied independently to each array.

    Returns
    -------
    TailResult or dict
        A :class:`TailResult` built by :func:`make_tail_result` when the
        normalized keys are exactly ``{'positive', 'negative'}``. Any other
        key set (e.g. the F-stat omnibus path, ``{'omnibus': ...}``) comes
        back as a plain ``dict`` mapping each normalized key to its score.
    """
    norm = normalize_keys(stat_dict)
    if set(norm.keys()) == {"positive", "negative"}:
        return make_tail_result(
            score_fn(norm["positive"]),
            score_fn(norm["negative"]),
        )
    # Non two-tailed input: score each entry and keep the keys as given.
    return {key: score_fn(arr) for key, arr in norm.items()}
def apply_tfnbs(
    stat_dict: Dict[str, npt.NDArray[np.float64]],
    e: Union[float, List[float]] = DEFAULT_EXTENT_EXPONENT,
    h: Union[float, List[float]] = DEFAULT_HEIGHT_EXPONENT,
    n: int = DEFAULT_N_THRESHOLDS,
    start_thres: float = DEFAULT_START_THRESHOLD,
    **kwargs,
) -> Union[TailResult, Dict[str, npt.NDArray[np.float64]]]:
    """Apply TFNBS (threshold-free cluster enhancement) to each direction.

    Parameters
    ----------
    stat_dict
        Per-direction statistic arrays keyed ``'positive'`` / ``'negative'``
        (legacy ``'g2>g1'`` / ``'g1>g2'`` keys are also accepted).
    e, h
        Extent and height exponents, forwarded to :func:`get_tfnbs_score`.
    n
        Number of thresholds, forwarded to :func:`get_tfnbs_score`. The
        default is ``DEFAULT_N_THRESHOLDS_PERMUTATION`` from ``.defaults``.
    start_thres
        Starting threshold, forwarded to :func:`get_tfnbs_score`.
    **kwargs
        Accepted and ignored. Presumably this lets one dispatcher pass a
        shared option set to every ``apply_*`` wrapper; confirm with callers.

    Returns
    -------
    TailResult or dict
        A :class:`TailResult` of per-direction scores. For any key set other
        than the two tails (e.g. ``{'omnibus': ...}``), a plain dict.
    """

    def _score(arr: npt.NDArray[np.float64]):
        return get_tfnbs_score(arr, e, h, n, start_thres=start_thres)

    return _wrap(stat_dict, _score)
def apply_nbs(
    stat_dict: Dict[str, npt.NDArray[np.float64]],
    threshold: float = DEFAULT_NBS_THRESHOLD,
    nbs_stat: str = DEFAULT_NBS_STAT,
    **kwargs,
) -> Union[TailResult, Dict[str, npt.NDArray[np.float64]]]:
    """Apply classical NBS (fixed threshold) to each direction.

    Parameters
    ----------
    stat_dict
        Per-direction statistic arrays keyed ``'positive'`` / ``'negative'``
        (legacy ``'g2>g1'`` / ``'g1>g2'`` keys are also accepted).
    threshold
        Fixed threshold, forwarded to :func:`get_nbs_score`.
    nbs_stat
        Component statistic, forwarded to :func:`get_nbs_score` as
        ``stat_type``.
    **kwargs
        Accepted and ignored. Presumably this lets one dispatcher pass a
        shared option set to every ``apply_*`` wrapper; confirm with callers.

    Returns
    -------
    TailResult or dict
        A :class:`TailResult` of per-direction scores. For any key set other
        than the two tails (e.g. ``{'omnibus': ...}``), a plain dict.
    """

    def _score(arr: npt.NDArray[np.float64]):
        return get_nbs_score(arr, threshold=threshold, stat_type=nbs_stat)

    return _wrap(stat_dict, _score)
def apply_cnbs(
    stat_dict: Dict[str, npt.NDArray[np.float64]],
    net_labels: npt.NDArray[np.int_],
    **kwargs,
) -> Union[TailResult, Dict[str, npt.NDArray[np.float64]]]:
    """Apply constrained NBS (block-constrained scoring) to each direction.

    Parameters
    ----------
    stat_dict
        Per-direction statistic arrays keyed ``'positive'`` / ``'negative'``
        (legacy ``'g2>g1'`` / ``'g1>g2'`` keys are also accepted).
    net_labels
        Integer network labels, forwarded to :func:`get_cnbs_score`.
        Presumably one label per node; confirm against ``get_cnbs_score``.
    **kwargs
        Accepted and ignored. Presumably this lets one dispatcher pass a
        shared option set to every ``apply_*`` wrapper; confirm with callers.

    Returns
    -------
    TailResult or dict
        A :class:`TailResult` of per-direction scores. For any key set other
        than the two tails (e.g. ``{'omnibus': ...}``), a plain dict.
    """

    def _score(arr: npt.NDArray[np.float64]):
        return get_cnbs_score(arr, net_labels)

    return _wrap(stat_dict, _score)
def apply_ni_tfnbs(
    stat_dict: Dict[str, npt.NDArray[np.float64]],
    net_labels: npt.NDArray[np.int_],
    e: Union[float, List[float]] = DEFAULT_EXTENT_EXPONENT,
    h: Union[float, List[float]] = DEFAULT_HEIGHT_EXPONENT,
    n: int = DEFAULT_N_THRESHOLDS,
    start_thres: float = DEFAULT_START_THRESHOLD,
    normalization: str = "sqrt",
    **kwargs,
) -> Union[TailResult, Dict[str, npt.NDArray[np.float64]]]:
    """Apply network-informed TFNBS (block-density weighted) to each direction.

    Parameters
    ----------
    stat_dict
        Per-direction statistic arrays keyed ``'positive'`` / ``'negative'``
        (legacy ``'g2>g1'`` / ``'g1>g2'`` keys are also accepted).
    net_labels
        Integer network labels, forwarded to
        :func:`get_network_informed_tfnbs_score`. Presumably one label per
        node; confirm against that function.
    e, h
        Extent and height exponents, forwarded unchanged.
    n
        Number of thresholds, forwarded unchanged. The default is
        ``DEFAULT_N_THRESHOLDS_PERMUTATION`` from ``.defaults``.
    start_thres
        Starting threshold, forwarded unchanged.
    normalization
        Normalization mode (default ``"sqrt"``), forwarded unchanged. Valid
        values are defined by :func:`get_network_informed_tfnbs_score`.
    **kwargs
        Accepted and ignored. Presumably this lets one dispatcher pass a
        shared option set to every ``apply_*`` wrapper; confirm with callers.

    Returns
    -------
    TailResult or dict
        A :class:`TailResult` of per-direction scores. For any key set other
        than the two tails (e.g. ``{'omnibus': ...}``), a plain dict.
    """

    def _score(arr: npt.NDArray[np.float64]):
        return get_network_informed_tfnbs_score(
            arr, net_labels, e, h, n,
            start_thres=start_thres, normalization=normalization,
        )

    return _wrap(stat_dict, _score)
def apply_fbc_tfnbs(
    stat_dict: Dict[str, npt.NDArray[np.float64]],
    net_labels: npt.NDArray[np.int_],
    e: Union[float, List[float]] = DEFAULT_EXTENT_EXPONENT,
    h: Union[float, List[float]] = DEFAULT_HEIGHT_EXPONENT,
    n: int = DEFAULT_N_THRESHOLDS,
    start_thres: float = DEFAULT_START_THRESHOLD,
    min_cluster_size: int = DEFAULT_MIN_CLUSTER_SIZE,
    **kwargs,
) -> Union[TailResult, Dict[str, npt.NDArray[np.float64]]]:
    """Apply functional-block-clustering TFNBS to each direction.

    Parameters
    ----------
    stat_dict
        Per-direction statistic arrays keyed ``'positive'`` / ``'negative'``
        (legacy ``'g2>g1'`` / ``'g1>g2'`` keys are also accepted).
    net_labels
        Integer network labels, forwarded to :func:`get_fbc_tfnbs_score`.
        Presumably one label per node; confirm against that function.
    e, h
        Extent and height exponents, forwarded unchanged.
    n
        Number of thresholds, forwarded unchanged. The default is
        ``DEFAULT_N_THRESHOLDS_PERMUTATION`` from ``.defaults``.
    start_thres
        Starting threshold, forwarded unchanged.
    min_cluster_size
        Forwarded unchanged to :func:`get_fbc_tfnbs_score`.
    **kwargs
        Accepted and ignored. Presumably this lets one dispatcher pass a
        shared option set to every ``apply_*`` wrapper; confirm with callers.

    Returns
    -------
    TailResult or dict
        A :class:`TailResult` of per-direction scores. For any key set other
        than the two tails (e.g. ``{'omnibus': ...}``), a plain dict.
    """

    def _score(arr: npt.NDArray[np.float64]):
        return get_fbc_tfnbs_score(
            arr, net_labels, e, h, n,
            start_thres=start_thres, min_cluster_size=min_cluster_size,
        )

    return _wrap(stat_dict, _score)