# Source code for conninfpy.pairwise_stats

"""
Pairwise statistical testing module for network analysis.

This module provides functions for computing t-statistics and p-values
using permutation testing, with optional enhancement (TFNBS, NBS, cNBS,
NI-TFNBS, FBC-TFNBS) applied via shared ``apply_*`` wrappers from
``conninfpy._enhancement``.

Main Functions
--------------
compute_p_val : Compute p-values using permutation testing (with enhancement)
compute_t_stat : Compute t-statistics for paired, one-sample, or two-sample tests
compute_null_dist : Compute null distribution via permutation testing
"""

from __future__ import annotations
import os
import logging
import multiprocessing
from enum import Enum
from functools import partial
from multiprocessing import Pool
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import numpy.typing as npt
from scipy import stats

from .defaults import (
    DEFAULT_EXTENT_EXPONENT,
    DEFAULT_HEIGHT_EXPONENT,
    DEFAULT_N_THRESHOLDS_PERMUTATION as DEFAULT_N_THRESHOLDS,
    DEFAULT_N_PERMUTATIONS,
    DEFAULT_START_THRESHOLD,
    DEFAULT_MIN_CLUSTER_SIZE,
    DEFAULT_NBS_THRESHOLD,
    DEFAULT_NBS_STAT,
)
from ._enhancement import (
    apply_cnbs,
    apply_fbc_tfnbs,
    apply_nbs,
    apply_ni_tfnbs,
    apply_tfnbs,
)
import time

from .acceleration import compute_p_values_accelerated
from ._compat import TailResult, make_tail_result, normalize_keys
from ._progress import run_permutations
from ._result import InferenceResult, make_inference_result
from ._rng import RngLike, resolve_seed, warn_legacy_random_state


# Names exported by ``from conninfpy.pairwise_stats import *``. Listing a name
# in ``__all__`` overrides the default "skip leading underscore" rule, so the
# private helper ``_is_worker_process`` is exported explicitly (presumably for
# use by sibling modules — verify against callers).
__all__ = [
    # Enums and constants
    "TestType",
    "StatMethod",
    "DEFAULT_N_PERMUTATIONS",
    "DEFAULT_EXTENT_EXPONENT",
    "DEFAULT_HEIGHT_EXPONENT",
    # Alias of DEFAULT_N_THRESHOLDS_PERMUTATION imported from .defaults.
    "DEFAULT_N_THRESHOLDS",
    # Main functions
    "compute_p_val",
    "compute_t_stat",
    "compute_null_dist",
    # T-statistic primitive (accepts pre-computed diffs)
    "compute_t_stat_diff",
    "get_available_cores",
    "_is_worker_process",
]

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


# =============================================================================
# Enums and Constants
# =============================================================================

class TestType(str, Enum):
    """Kind of t-test used when building permutation nulls.

    Because the enum mixes in ``str``, every member compares equal to its
    string value, so ``"paired"`` and ``TestType.PAIRED`` are interchangeable
    wherever a plain string is accepted.
    """

    # Paired-samples t-test (within-subjects design).
    PAIRED = "paired"
    # One-sample t-test against zero.
    ONE_SAMPLE = "one-sample"
    # Independent-samples t-test (between-subjects design).
    TWO_SAMPLE = "two-sample"
class StatMethod(str, Enum):
    """Inference method for network-level statistics.

    Members are ``str`` subclasses, so each one compares equal to its value
    (e.g. ``StatMethod.NBS == "nbs"``). The methods fall into three families:
    permutation-based enhancement (TFNBS, NBS and their constrained variants),
    raw t-statistics, and multiple-comparison corrections (parametric
    Bonferroni / BH-FDR and permutation-based BH-FDR).
    """

    # Threshold-Free Network-Based Statistics (TFCE-style).
    TFNBS = "tfnbs"
    # Raw t-statistics without enhancement.
    TSTAT = "tstat"
    # Classical Network-Based Statistics with a fixed threshold.
    NBS = "nbs"
    # Constrained NBS with predefined network partitions.
    CNBS = "cnbs"
    # Network-Informed TFNBS with functional block density weighting.
    NI_TFNBS = "ni_tfnbs"
    # Functional Block Clustering TFNBS (block-defined clustering).
    FBC_TFNBS = "fbc_tfnbs"
    # Parametric Bonferroni correction (no permutation testing).
    BONFERRONI = "bonferroni"
    # Parametric Benjamini-Hochberg FDR correction (no permutation testing).
    BH_FDR = "bh_fdr"
    # Permutation-based BH-FDR correction (per-edge permutation p-values).
    BH_FDR_PERM = "bh_fdr_perm"
# Methods that restrict enhancement with a predefined network/block partition
# (see the CNBS / NI_TFNBS / FBC_TFNBS descriptions on StatMethod).
CONSTRAINED_METHODS = {StatMethod.CNBS, StatMethod.NI_TFNBS, StatMethod.FBC_TFNBS}

# Enhancement-function registry. Maps StatMethod -> pure ``stat_dict -> stat_dict``
# wrapper from ._enhancement (shared with glm_stats.compute_p_val_glm). Methods
# without enhancement (raw t-stat, parametric baselines, bh_fdr_perm) map to None.
_ENHANCE_MAP = {
    StatMethod.TSTAT: None,
    StatMethod.TFNBS: apply_tfnbs,
    StatMethod.NBS: apply_nbs,
    StatMethod.CNBS: apply_cnbs,
    StatMethod.NI_TFNBS: apply_ni_tfnbs,
    StatMethod.FBC_TFNBS: apply_fbc_tfnbs,
    StatMethod.BONFERRONI: None,
    StatMethod.BH_FDR: None,
    StatMethod.BH_FDR_PERM: None,
}

# Parametric corrections that skip permutation testing.
PARAMETRIC_METHODS = {StatMethod.BONFERRONI, StatMethod.BH_FDR}

# TFCE-style methods; only these get an (E, H) grid from _grid_from_kwargs.
TFNBS_FAMILY = {StatMethod.TFNBS, StatMethod.NI_TFNBS, StatMethod.FBC_TFNBS}


def _grid_from_kwargs(
    method_enum: "StatMethod",
    enhance_kwargs: Dict[str, Any],
) -> Tuple[Optional[npt.NDArray[np.float64]], Optional[npt.NDArray[np.float64]]]:
    """Return ``(e_grid, h_grid)`` arrays iff ``e``/``h`` were arrays.

    Used to label the parameter axis of multi-(E, H) :class:`InferenceResult`
    p-maps. Returns ``(None, None)`` for non-TFNBS-family methods, when either
    ``e`` or ``h`` is missing/None, or when both are scalars. If only one of
    them is an array, both are returned (the scalar as a 0-d array).
    """
    if method_enum not in TFNBS_FAMILY:
        return None, None
    e = enhance_kwargs.get("e")
    h = enhance_kwargs.get("h")
    if e is None or h is None:
        return None, None
    if np.isscalar(e) and np.isscalar(h):
        return None, None
    return (
        np.asarray(e, dtype=np.float64),
        np.asarray(h, dtype=np.float64),
    )


# =============================================================================
# Helper functions for permutation testing
# =============================================================================


def _encode_strata(strata: Any) -> npt.NDArray[np.int_]:
    """Map an arbitrary stratum-label sequence to contiguous integer codes.

    Codes follow the sorted order of the unique labels (``np.unique``), so
    they run from 0 to ``n_unique_labels - 1``.
    """
    arr = np.asarray(strata)
    _, codes = np.unique(arr, return_inverse=True)
    return codes.astype(np.int_)


def _stratified_perm(
    strata: npt.NDArray[np.int_], rng: np.random.RandomState
) -> npt.NDArray[np.int_]:
    """Return a permutation index honoring exchangeability blocks.

    Within each unique stratum value, indices are independently permuted;
    across strata, indices do not move. This implements within-block
    exchangeability (as used e.g. by Freedman-Lane permutation on multi-site
    or otherwise blocked designs).

    Parameters
    ----------
    strata : ndarray of shape (n,)
        Integer stratum codes (use :func:`_encode_strata` first).
    rng : numpy.random.RandomState
        Already-seeded RNG.

    Returns
    -------
    perm : ndarray of shape (n,)
        Gather index: ``perm[i]`` is the original index of the element placed
        at position ``i``, so for any array ``a``, ``a[perm]`` is the
        stratum-respecting permutation of ``a``. Every position keeps its
        stratum: ``strata[perm] == strata`` element-wise.
    """
    n = strata.shape[0]
    perm = np.arange(n)
    for s in np.unique(strata):
        idx = np.where(strata == s)[0]
        if idx.size > 1:
            perm[idx] = rng.permutation(idx)
    return perm


def _stratified_choice_n1(
    strata: npt.NDArray[np.int_],
    n1_per_stratum: npt.NDArray[np.int_],
    rng: np.random.RandomState,
) -> npt.NDArray[np.float64]:
    """Length-n binary mask honoring per-stratum group-1 counts.

    For the two-sample fast permutation path under exchangeability blocking:
    within each stratum ``s``, sample exactly ``n1_per_stratum[s]`` indices to
    label as group 1 (the rest are group 2). Across strata, the per-stratum
    group totals are held fixed.

    Parameters
    ----------
    strata : ndarray of shape (n,)
        Integer stratum codes.
    n1_per_stratum : ndarray of shape (n_unique_strata,)
        Group-1 count per stratum (computed once from the observed group
        labels and reused across permutations).
    rng : numpy.random.RandomState
        Already-seeded RNG.

    Returns
    -------
    g : ndarray of shape (n,), float64
        1.0 at positions drawn into group 1, 0.0 elsewhere.
    """
    n = strata.shape[0]
    g = np.zeros(n, dtype=np.float64)
    for s_code, n1_s in enumerate(n1_per_stratum):
        if n1_s <= 0:
            continue
        idx = np.where(strata == s_code)[0]
        chosen = rng.choice(idx, size=int(n1_s), replace=False)
        g[chosen] = 1.0
    return g


def _extract_max_stats(
    stat_dict: Dict[str, npt.NDArray], reference_shape: Tuple[int, ...]
) -> Dict[str, np.float64]:
    """
    Extract maximum statistics from a stat dictionary.

    Arrays whose shape equals ``reference_shape`` are reduced to a single
    maximum. Any other array is treated as multi-parameter: the maximum is
    taken over all axes except the last, which is kept as the parameter axis.
    Works with any string keys.

    Parameters
    ----------
    stat_dict : dict
        Dictionary with statistic arrays (any string keys).
    reference_shape : tuple
        Shape of a single sample, used to detect the multi-parameter case.

    Returns
    -------
    dict
        Maximum statistic(s) per key, as float64. Returned as a TailResult
        (via ``make_tail_result``) when the keys are exactly
        ``{'positive', 'negative'}``.
    """
    result: Dict[str, Any] = {}
    for key, arr in stat_dict.items():
        if arr.shape == reference_shape:
            result[key] = np.max(arr).astype(np.float64)
        else:
            # Multi-parameter case: max over spatial dims, keep param dim
            spatial_axes = tuple(range(arr.ndim - 1))
            result[key] = np.max(arr, axis=spatial_axes).astype(np.float64)
    if set(result.keys()) == {"positive", "negative"}:
        return make_tail_result(result["positive"], result["negative"])
    return result


def _permutation_task_ind(
    full_group: npt.NDArray[np.float64],
    func: Callable[..., Any],
    n1: int,
    seed: int,
    **func_kwargs
) -> Dict[str, Union[float, npt.NDArray[np.float64]]]:
    """
    Compute the maximum statistic(s) for one two-sample permutation.

    The concatenated samples are shuffled and the first ``n1`` rows become
    group 1, the rest group 2.

    Parameters
    ----------
    full_group : ndarray
        Concatenated data array of shape (n_samples_1 + n_samples_2, *dims).
    func : callable
        Statistic function, called as
        ``func(g1, g2, test_type='two-sample', **func_kwargs)``.
    n1 : int
        Number of samples in group 1.
    seed : int
        Random seed for this permutation.
    **func_kwargs
        Additional keyword arguments passed to func.

    Returns
    -------
    dict
        Maximum statistics per key (see :func:`_extract_max_stats`).
    """
    rng = np.random.RandomState(seed)
    idx = rng.permutation(full_group.shape[0])
    new_group1 = full_group[idx[:n1]]
    new_group2 = full_group[idx[n1:]]
    perm_stat_dict = func(new_group1, new_group2, test_type='two-sample', **func_kwargs)
    return _extract_max_stats(perm_stat_dict, full_group[0].shape)


def _permutation_task_paired(
    diffs: npt.NDArray[np.float64],
    func: Callable[..., Any],
    seed: Optional[int] = None,
    **func_kwargs
) -> Dict[str, Union[float, npt.NDArray[np.float64]]]:
    """
    Compute the maximum statistic(s) for one paired/one-sample permutation.

    Each sample is multiplied by an independent random sign (+1/-1).

    Parameters
    ----------
    diffs : ndarray
        Array of shape (n_samples, *dims) (paired differences or the single
        sample).
    func : callable
        Statistic function, called as ``func(flipped, **func_kwargs)``.
    seed : int, optional
        Random seed for this permutation.
    **func_kwargs
        Additional keyword arguments passed to func.

    Returns
    -------
    dict
        Maximum statistics per key (see :func:`_extract_max_stats`).
    """
    n_dims = len(diffs.shape) - 1
    # Broadcast one sign per sample over all trailing dims.
    faked_dims = [1] * n_dims
    rng = np.random.RandomState(seed)
    signs = rng.choice([1, -1], diffs.shape[0]).reshape(-1, *faked_dims)
    new_diffs = signs * diffs
    perm_stat_dict = func(new_diffs, **func_kwargs)
    return _extract_max_stats(perm_stat_dict, diffs[0].shape)


def _collect_results_to_arrays(
    results: List[Dict[str, Any]], n_permutations: int
) -> Dict[str, npt.NDArray[np.float64]]:
    """
    Stack per-permutation results into one float64 array per key.

    Output arrays are pre-allocated from the keys and value shape of
    ``results[0]`` and then filled row by row, so every result must share
    those keys and that shape.

    Parameters
    ----------
    results : list of dict
        Permutation result mappings.
    n_permutations : int
        Expected number of results; must equal ``len(results)``.

    Returns
    -------
    dict
        Arrays of shape (n_permutations,) or (n_permutations, n_params) per
        key; a TailResult when the keys are exactly 'positive'/'negative'.

    Raises
    ------
    ValueError
        If ``results`` is empty or its length differs from
        ``n_permutations`` (which would otherwise leave uninitialized rows in
        the pre-allocated arrays, or overflow them).
    """
    if len(results) != n_permutations:
        raise ValueError(
            f"Expected {n_permutations} permutation results, got {len(results)}."
        )
    if not results:
        raise ValueError("No permutation results to collect.")

    group_keys = list(results[0].keys())
    first_val = results[0][group_keys[0]]
    output_shape = first_val.shape if hasattr(first_val, 'shape') else ()

    t_maxes_dict = {
        key: np.empty((n_permutations, *output_shape), dtype=np.float64)
        for key in group_keys
    }
    for i, perm_dict in enumerate(results):
        for key in group_keys:
            t_maxes_dict[key][i] = perm_dict[key]

    if set(t_maxes_dict.keys()) == {"positive", "negative"}:
        return make_tail_result(t_maxes_dict["positive"], t_maxes_dict["negative"])
    return t_maxes_dict


# =============================================================================
# Pre-computed sums fast path (method='tstat' and 'bh_fdr_perm')
# =============================================================================
# Exploits the algebraic variance identity Var(x) = (Σx² − (Σx)²/n)/(n−1) with
# the sign-flip-invariant Σx² pre-computed once, so each permutation reduces to
# a matrix-vector dot product over upper-triangle edges (~6-8x faster per perm
# vs. recomputing mean/std from full (nSub, N, N) arrays).
def _onesample_tstat_from_sums(sum_vec, sumsq_vec, n): """One-sample t-statistic from pre-computed sums. Returns (t_pos, t_neg) edge vectors.""" mean = sum_vec / n var = np.maximum((sumsq_vec - sum_vec ** 2 / n) / (n - 1), 0) se = np.sqrt(var / n) with np.errstate(divide='ignore', invalid='ignore'): t = mean / se t = np.where(se == 0, 0.0, t) return np.maximum(t, 0.0), np.maximum(-t, 0.0) def _twosample_tstat_from_sums(sum1, sumsq1, n1, sum2, sumsq2, n2): """Welch two-sample t-statistic from pre-computed sums. Returns (t_pos, t_neg) edge vectors.""" mean1 = sum1 / n1 mean2 = sum2 / n2 var1 = np.maximum((sumsq1 - sum1 ** 2 / n1) / (n1 - 1), 0) var2 = np.maximum((sumsq2 - sum2 ** 2 / n2) / (n2 - 1), 0) denom = np.sqrt(var1 / n1 + var2 / n2) with np.errstate(divide='ignore', invalid='ignore'): t = (mean2 - mean1) / denom t = np.where(denom == 0, 0.0, t) return np.maximum(t, 0.0), np.maximum(-t, 0.0) def _precompute_edge_sums(data_3d): """Extract upper-triangle edges + sum-of-squares for sign-flip fast path. Returns (X, sumsq_all): X is (n_samples, n_edges) contiguous; sumsq_all is (n_edges,) and is sign-flip invariant since s² = 1 for s ∈ {+1, −1}. """ N = data_3d.shape[1] triu_idx = np.triu_indices(N, k=1) X = np.ascontiguousarray(data_3d[:, triu_idx[0], triu_idx[1]]) sumsq_all = np.sum(X ** 2, axis=0) return X, sumsq_all def _precompute_twosample_sums(Xall_3d): """Extract upper-triangle edges + pooled sums for group-label fast path. Returns (Xall, Xall2, sum_all, sumsq_all). Xall2 caches element-wise square so per-permutation cost is two matvecs + two subtractions. """ N = Xall_3d.shape[1] triu_idx = np.triu_indices(N, k=1) Xall = np.ascontiguousarray(Xall_3d[:, triu_idx[0], triu_idx[1]]) Xall2 = Xall ** 2 sum_all = np.sum(Xall, axis=0) sumsq_all = np.sum(Xall2, axis=0) return Xall, Xall2, sum_all, sumsq_all def _fast_permutation_task_paired(X, sumsq_all, seed): """Paired/one-sample permutation via sign-flip + sums. 
Returns max stats.""" rng = np.random.RandomState(seed) n_sub = X.shape[0] signs = rng.choice([1, -1], n_sub).astype(np.float64) sum_perm = signs @ X t_pos, t_neg = _onesample_tstat_from_sums(sum_perm, sumsq_all, n_sub) return { "positive": np.float64(np.max(t_pos)), "negative": np.float64(np.max(t_neg)), } def _fast_permutation_task_ind(Xall, Xall2, sum_all, sumsq_all, n1, seed, strata=None, n1_per_stratum=None): """Two-sample permutation via group-label shuffle + sums. Returns max stats. When ``strata`` is provided (with ``n1_per_stratum``), the group-label draw honors per-stratum group totals — exchangeability blocking. """ rng = np.random.RandomState(seed) n_all = Xall.shape[0] n2 = n_all - n1 if strata is None: g = np.zeros(n_all, dtype=np.float64) g[rng.choice(n_all, n1, replace=False)] = 1.0 else: g = _stratified_choice_n1(strata, n1_per_stratum, rng) sum1 = g @ Xall sumsq1 = g @ Xall2 sum2 = sum_all - sum1 sumsq2 = sumsq_all - sumsq1 t_pos, t_neg = _twosample_tstat_from_sums(sum1, sumsq1, n1, sum2, sumsq2, n2) return { "positive": np.float64(np.max(t_pos)), "negative": np.float64(np.max(t_neg)), } def _fast_permutation_task_paired_edges(X, sumsq_all, seed): """Paired/one-sample permutation returning upper-triangle t-stat vectors (for BH-FDR-perm).""" rng = np.random.RandomState(seed) n_sub = X.shape[0] signs = rng.choice([1, -1], n_sub).astype(np.float64) sum_perm = signs @ X t_pos, t_neg = _onesample_tstat_from_sums(sum_perm, sumsq_all, n_sub) return {"positive": t_pos, "negative": t_neg} def _fast_permutation_task_ind_edges(Xall, Xall2, sum_all, sumsq_all, n1, seed, strata=None, n1_per_stratum=None): """Two-sample permutation returning upper-triangle t-stat vectors (for BH-FDR-perm).""" rng = np.random.RandomState(seed) n_all = Xall.shape[0] n2 = n_all - n1 if strata is None: g = np.zeros(n_all, dtype=np.float64) g[rng.choice(n_all, n1, replace=False)] = 1.0 else: g = _stratified_choice_n1(strata, n1_per_stratum, rng) sum1 = g @ Xall sumsq1 = g @ Xall2 
sum2 = sum_all - sum1 sumsq2 = sumsq_all - sumsq1 t_pos, t_neg = _twosample_tstat_from_sums(sum1, sumsq1, n1, sum2, sumsq2, n2) return {"positive": t_pos, "negative": t_neg} # ============================================================================= # Enhancement-method fast permutation tasks (sums trick + shared enhancers) # ============================================================================= def _tstat_edges_to_matrix(t_pos, t_neg, triu_idx, N): """Expand upper-triangle t-stat vectors to symmetric (N, N) matrices.""" t_pos_mat = np.zeros((N, N)) t_pos_mat[triu_idx] = t_pos t_pos_mat = t_pos_mat + t_pos_mat.T t_neg_mat = np.zeros((N, N)) t_neg_mat[triu_idx] = t_neg t_neg_mat = t_neg_mat + t_neg_mat.T return t_pos_mat, t_neg_mat def _fast_perm_task_enhanced_paired( X, sumsq_all, triu_idx, N, enhance_func, enhance_kwargs, seed, ): """Paired/one-sample permutation: sign-flip + fast t-stat + enhancement → max.""" rng = np.random.RandomState(seed) n_sub = X.shape[0] signs = rng.choice([1, -1], n_sub).astype(np.float64) sum_perm = signs @ X t_pos, t_neg = _onesample_tstat_from_sums(sum_perm, sumsq_all, n_sub) t_pos_mat, t_neg_mat = _tstat_edges_to_matrix(t_pos, t_neg, triu_idx, N) stat_dict = {"positive": t_pos_mat, "negative": t_neg_mat} if enhance_func is not None: stat_dict = enhance_func(stat_dict, **enhance_kwargs) return _extract_max_stats(stat_dict, (N, N)) def _fast_perm_task_enhanced_ind( Xall, Xall2, sum_all, sumsq_all, n1, triu_idx, N, enhance_func, enhance_kwargs, seed, strata=None, n1_per_stratum=None, ): """Two-sample permutation: group shuffle + fast t-stat + enhancement → max. When ``strata`` is provided, the group-label draw is stratified (per-stratum group totals held fixed). 
""" rng = np.random.RandomState(seed) n_all = Xall.shape[0] n2 = n_all - n1 if strata is None: g = np.zeros(n_all, dtype=np.float64) g[rng.choice(n_all, n1, replace=False)] = 1.0 else: g = _stratified_choice_n1(strata, n1_per_stratum, rng) sum1 = g @ Xall sumsq1 = g @ Xall2 sum2 = sum_all - sum1 sumsq2 = sumsq_all - sumsq1 t_pos, t_neg = _twosample_tstat_from_sums(sum1, sumsq1, n1, sum2, sumsq2, n2) t_pos_mat, t_neg_mat = _tstat_edges_to_matrix(t_pos, t_neg, triu_idx, N) stat_dict = {"positive": t_pos_mat, "negative": t_neg_mat} if enhance_func is not None: stat_dict = enhance_func(stat_dict, **enhance_kwargs) return _extract_max_stats(stat_dict, (N, N)) # ============================================================================= # Helper functions for multiprocessing # ============================================================================= def _is_worker_process() -> bool: """Check if running inside a multiprocessing worker. Returns True when the current process was spawned by a Pool, preventing nested Pool creation which causes deadlocks and unintentional process multiplication. """ proc_name = multiprocessing.current_process().name return ( proc_name != 'MainProcess' or 'PoolWorker' in proc_name or 'SpawnPoolWorker' in proc_name or 'ForkPoolWorker' in proc_name )
def get_available_cores():
    """Return the number of CPU cores this process is allowed to run on.

    Where the platform exposes ``os.sched_getaffinity`` (e.g. Linux), the size
    of the process affinity mask is returned, so affinity restrictions such as
    ``taskset`` are respected. Elsewhere (e.g. Windows/macOS) this falls back
    to :func:`multiprocessing.cpu_count`.
    """
    getaffinity = getattr(os, "sched_getaffinity", None)
    if getaffinity is None:
        # No affinity API on this platform: count all logical CPUs.
        return multiprocessing.cpu_count()
    return len(getaffinity(0))
# ============================================================================= # Null distribution computation # =============================================================================
def compute_null_dist(
    group1: npt.NDArray[np.float64],
    group2: Optional[npt.NDArray[np.float64]] = None,
    func: Optional[Callable[..., Any]] = None,
    n_permutations: int = DEFAULT_N_PERMUTATIONS,
    test_type: Union[str, TestType] = TestType.PAIRED,
    random_state: Optional[int] = None,
    n_processes: Optional[int] = None,
    use_mp: bool = True,
    verbose: bool = False,
    strata: Optional[npt.NDArray[Any]] = None,
    **func_kwargs
) -> Dict[str, npt.NDArray[np.float64]]:
    """
    Compute the null distribution of maximum statistics via permutation testing.

    Paired tests sign-flip the differences ``group2 - group1``; one-sample
    tests sign-flip ``group1``; two-sample tests reshuffle group labels over
    the concatenated samples. Each permutation is reduced to its per-tail
    maximum.

    When ``func`` is one of the module's canonical t-statistic functions, no
    ``func_kwargs`` are given and ``group1`` is a stack of square matrices, a
    fast path works on upper-triangle edges with pre-computed sums instead of
    calling ``func`` on every permutation.

    Parameters
    ----------
    group1 : ndarray of shape ``(n_samples_1, *dims)``
        Data array for group 1 (the only sample for the one-sample test).
    group2 : ndarray of shape ``(n_samples_2, *dims)``, optional
        Data array for group 2. Required for 'paired' and 'two-sample' tests;
        for 'paired' it must have as many samples as ``group1``.
    func : callable, optional
        Statistic function evaluated on each permuted sample. Required; the
        fast path is only taken for the canonical t-statistic functions.
    n_permutations : int, default=DEFAULT_N_PERMUTATIONS
        Number of permutations (at least 1).
    test_type : {'paired', 'one-sample', 'two-sample'} or TestType
        Type of statistical test.
    random_state : int, optional
        Seed of the master RNG that draws one seed per permutation.
    n_processes : int, optional
        Number of worker processes. Defaults to the available core count and
        is capped at ``n_permutations``.
    use_mp : bool, default=True
        Whether to use multiprocessing. Automatically disabled when called
        from inside a multiprocessing worker to prevent nested pools.
    verbose : bool, default=False
        Forwarded to the permutation runner.
    strata : array-like of shape ``(n_samples_1 + n_samples_2,)``, optional
        Exchangeability-block labels for the two-sample test: group labels
        are shuffled within blocks, keeping per-block group counts fixed.
        Only supported on the fast path; ignored for paired and one-sample
        tests.
    **func_kwargs
        Additional keyword arguments passed to func (disables the fast path).

    Returns
    -------
    dict
        'negative' and 'positive' arrays of shape (n_permutations,) or
        (n_permutations, n_params) for multi-parameter statistics.

    Raises
    ------
    ValueError
        If test_type is invalid, group2 is missing or incompatible (trailing
        dims, or unequal sample counts for 'paired'), any group has fewer
        than 2 samples, n_permutations < 1, or strata has the wrong length.
    TypeError
        If func is None.
    NotImplementedError
        If strata is given for a two-sample test off the fast path.
    """
    # Normalize test_type to its string value for comparison.
    test_type_str = test_type.value if isinstance(test_type, TestType) else test_type
    valid_test_types = [t.value for t in TestType]
    if test_type_str not in valid_test_types:
        raise ValueError(
            f"Invalid test_type: '{test_type_str}'. "
            f"Must be one of: {valid_test_types}"
        )

    # Input validation
    if test_type_str in (TestType.PAIRED.value, TestType.TWO_SAMPLE.value):
        if group2 is None:
            raise ValueError(f"group2 is required for test_type='{test_type_str}'")
        if group1.shape[1:] != group2.shape[1:]:
            raise ValueError("Trailing dimensions of group1 and group2 must match.")
        n1, n2 = group1.shape[0], group2.shape[0]
        if n1 < 2 or n2 < 2:
            raise ValueError("Each group must have at least 2 samples.")
        if test_type_str == TestType.PAIRED.value and n1 != n2:
            # group2 - group1 would otherwise fail to broadcast (or broadcast
            # wrongly) instead of pairing samples.
            raise ValueError(
                f"Paired test requires equal sample counts; got {n1} and {n2}."
            )
    elif group1.shape[0] < 2:
        # One-sample: the sample variance divides by (n - 1).
        raise ValueError("group1 must have at least 2 samples.")

    if n_permutations < 1:
        raise ValueError("n_permutations must be at least 1.")

    if func is None:
        raise TypeError("func must be a callable computing the test statistic.")

    # Fast path: raw t-statistic with no enhancement — uses pre-computed sums
    # trick (upper-triangle only, sign-flip-invariant Σx²). ~6-8x faster per perm.
    use_fast_tstat = (
        func in (compute_t_stat_diff, _compute_t_stat_ind, compute_t_stat)
        and not func_kwargs
        and group1.ndim == 3
        and group1.shape[1] == group1.shape[2]
    )

    # Prepare the per-permutation task (a callable taking only the seed).
    if test_type_str == TestType.PAIRED.value:
        diffs = group2 - group1
        if use_fast_tstat:
            X, sumsq_all = _precompute_edge_sums(diffs)
            task_func = partial(_fast_permutation_task_paired, X, sumsq_all)
        else:
            task_func = partial(_permutation_task_paired, diffs, func, **func_kwargs)

    elif test_type_str == TestType.TWO_SAMPLE.value:
        strata_codes = None
        n1_per_stratum = None
        if strata is not None:
            strata_codes = _encode_strata(strata)
            if strata_codes.shape[0] != n1 + n2:
                raise ValueError(
                    f"strata length {strata_codes.shape[0]} does not match "
                    f"n1+n2={n1+n2} for the two-sample test."
                )
            # Observed group-1 counts per stratum (first n1 entries are group 1
            # by construction of the concatenated stack).
            n1_per_stratum = np.bincount(
                strata_codes[:n1], minlength=int(strata_codes.max()) + 1
            )

        if use_fast_tstat:
            Xall_3d = np.concatenate((group1, group2), axis=0)
            Xall, Xall2, sum_all, sumsq_all = _precompute_twosample_sums(Xall_3d)
            task_func = partial(
                _fast_permutation_task_ind,
                Xall, Xall2, sum_all, sumsq_all, n1,
                strata=strata_codes, n1_per_stratum=n1_per_stratum,
            )
        else:
            if strata is not None:
                raise NotImplementedError(
                    "strata= is only wired into the fast t-stat path for "
                    "two-sample tests. Pass one of the canonical t-stat "
                    "functions (compute_t_stat / _compute_t_stat_ind) to "
                    "exercise the stratified-permutation path."
                )
            array_to_permute = np.concatenate((group1, group2), axis=0)
            task_func = partial(_permutation_task_ind, array_to_permute, func, n1, **func_kwargs)

    else:  # one-sample (test_type validated above)
        if use_fast_tstat:
            X, sumsq_all = _precompute_edge_sums(group1)
            task_func = partial(_fast_permutation_task_paired, X, sumsq_all)
        else:
            task_func = partial(_permutation_task_paired, group1, func, **func_kwargs)

    # One independent seed per permutation, derived from random_state, so
    # results do not depend on how work is split across processes.
    rng = np.random.RandomState(random_state)
    seeds = rng.randint(0, 2**32 - 1, size=n_permutations, dtype=np.int64)

    _use_mp = use_mp and not _is_worker_process()
    if _use_mp and n_processes is None:
        n_processes = get_available_cores()
    if _use_mp and n_processes is not None:
        n_processes = min(n_processes, n_permutations)

    results = run_permutations(
        task_func, list(seeds),
        use_mp=_use_mp, n_processes=n_processes,
        verbose=verbose, desc="compute_null_dist",
    )
    return _collect_results_to_arrays(results, n_permutations)
def _resolve_pool_settings(
    use_mp: bool,
    n_processes: Optional[int],
    n_permutations: int,
) -> Tuple[bool, Optional[int]]:
    """Resolve the multiprocessing switch and worker count for a permutation run.

    Multiprocessing is turned off when already running inside a worker
    process (nested pools are avoided to prevent deadlocks). When it stays on
    and ``n_processes`` is ``None``, the available core count is used; the
    worker count is then capped at ``n_permutations`` so no worker is idle.

    Returns
    -------
    (use_pool, n_processes) : tuple
        Effective multiprocessing flag and worker count.
    """
    use_pool = use_mp and not _is_worker_process()
    if use_pool and n_processes is None:
        n_processes = get_available_cores()
    if use_pool and n_processes is not None:
        n_processes = min(n_processes, n_permutations)
    return use_pool, n_processes


def _prepare_two_sample_strata(
    strata: Optional[npt.NDArray[Any]],
    n1: int,
    n_all: int,
) -> Tuple[Optional[npt.NDArray[Any]], Optional[npt.NDArray[np.int_]]]:
    """Encode exchangeability-block labels for the two-sample group shuffle.

    Parameters
    ----------
    strata : array-like of shape (n_all,) or None
        Per-subject block labels, ordered like ``concatenate((group1, group2))``.
    n1 : int
        Number of group-1 subjects (the first ``n1`` entries of the stack).
    n_all : int
        Total number of subjects (``n1 + n2``).

    Returns
    -------
    (strata_codes, n1_per_stratum) : tuple
        ``(None, None)`` when ``strata`` is ``None``; otherwise the codes
        produced by ``_encode_strata`` and the observed group-1 count per
        stratum (``np.bincount`` over the first ``n1`` codes).

    Raises
    ------
    ValueError
        If the number of labels does not equal ``n_all``.
    """
    if strata is None:
        return None, None
    strata_codes = _encode_strata(strata)
    if strata_codes.shape[0] != n_all:
        raise ValueError(
            f"strata length {strata_codes.shape[0]} does not match "
            f"n1+n2={n_all} for the two-sample test."
        )
    # Observed group-1 counts per stratum (first n1 entries are group 1 by
    # construction of the concatenated stack).
    n1_per_stratum = np.bincount(
        strata_codes[:n1], minlength=int(strata_codes.max()) + 1
    )
    return strata_codes, n1_per_stratum


def _compute_enhanced_null_dist(
    group1: npt.NDArray[np.float64],
    group2: Optional[npt.NDArray[np.float64]],
    test_type_str: str,
    enhance_func: Callable,
    enhance_kwargs: Dict[str, Any],
    n_permutations: int = DEFAULT_N_PERMUTATIONS,
    random_state: Optional[int] = None,
    n_processes: Optional[int] = None,
    use_mp: bool = True,
    verbose: bool = False,
    strata: Optional[npt.NDArray[Any]] = None,
) -> Dict[str, npt.NDArray[np.float64]]:
    """Null distribution for enhancement methods via pre-computed sums fast path.

    Each permutation: sign-flip or group-shuffle -> fast t-stat via sums ->
    reconstruct (N, N) -> apply shared enhancement wrapper -> extract max.

    When ``strata`` is provided, the two-sample group-shuffle path honors
    per-stratum group totals (exchangeability blocking). Sign-flip paths are
    stratum-invariant by construction; ``strata`` is silently accepted but
    has no effect there.

    Raises
    ------
    ValueError
        If ``test_type_str`` is not a known test type, or ``strata`` has the
        wrong length for a two-sample test.
    """
    N = group1.shape[1]
    triu_idx = np.triu_indices(N, k=1)

    if test_type_str in (TestType.PAIRED.value, TestType.ONE_SAMPLE.value):
        # Sign-flip designs: paired tests permute the per-subject differences,
        # one-sample tests permute the raw observations.
        if test_type_str == TestType.PAIRED.value:
            flip_data = group2 - group1
        else:
            flip_data = group1
        X, sumsq_all = _precompute_edge_sums(flip_data)
        task_func = partial(
            _fast_perm_task_enhanced_paired,
            X, sumsq_all, triu_idx, N, enhance_func, enhance_kwargs,
        )
    elif test_type_str == TestType.TWO_SAMPLE.value:
        n1 = group1.shape[0]
        n_all = n1 + group2.shape[0]
        strata_codes, n1_per_stratum = _prepare_two_sample_strata(
            strata, n1, n_all
        )
        Xall_3d = np.concatenate((group1, group2), axis=0)
        Xall, Xall2, sum_all, sumsq_all = _precompute_twosample_sums(Xall_3d)
        task_func = partial(
            _fast_perm_task_enhanced_ind,
            Xall, Xall2, sum_all, sumsq_all, n1, triu_idx, N,
            enhance_func, enhance_kwargs,
            strata=strata_codes, n1_per_stratum=n1_per_stratum,
        )
    else:
        raise ValueError(
            f"Invalid test_type: '{test_type_str}'. "
            f"Must be one of: {[t.value for t in TestType]}"
        )

    # One independent seed per permutation keeps results reproducible
    # regardless of how permutations are distributed across workers.
    rng = np.random.RandomState(random_state)
    seeds = rng.randint(0, 2**32 - 1, size=n_permutations, dtype=np.int64)

    use_pool, n_processes = _resolve_pool_settings(
        use_mp, n_processes, n_permutations
    )
    results = run_permutations(
        task_func,
        list(seeds),
        use_mp=use_pool,
        n_processes=n_processes,
        verbose=verbose,
        desc="enhanced_null",
    )
    return _collect_results_to_arrays(results, n_permutations)


def _compute_bh_fdr_perm_p_values(
    group1: npt.NDArray[np.float64],
    group2: Optional[npt.NDArray[np.float64]],
    test_type_str: str,
    n_permutations: int,
    random_state: Optional[int],
    use_mp: bool,
    n_processes: Optional[int],
    verbose: bool = False,
    strata: Optional[npt.NDArray[Any]] = None,
) -> Tuple[Dict[str, npt.NDArray[np.float64]], Dict[str, npt.NDArray[np.float64]]]:
    """
    Compute BH-FDR corrected p-values using permutation-based per-edge nulls.

    1. Compute the observed t-statistic for every edge.
    2. For each permutation (sign flips for paired / one-sample tests,
       group-label shuffles for two-sample tests), compute per-edge t-stats.
    3. For each edge, ``p = (#{perm t >= observed t} + 1) / (B + 1)``
       (Phipson & Smyth 2010 +1 correction; ties count against the edge).
    4. Apply Benjamini-Hochberg correction across the upper-triangle edges.

    Returns
    -------
    (p_values, emp_t_dict) : tuple of dict
        ``p_values`` -- per-tail symmetric (N, N) corrected p-value matrices
        (diagonal set to 1.0). ``emp_t_dict`` -- observed per-tail t maps.

    Raises
    ------
    ValueError
        If ``test_type_str`` is unknown or ``strata`` has the wrong length.
    """
    # Observed t-stats. For sign-flip designs keep the data that will be
    # permuted so the paired difference is only computed once.
    if test_type_str == TestType.PAIRED.value:
        sign_flip_data = group2 - group1
        emp_t_dict = compute_t_stat_diff(sign_flip_data)
    elif test_type_str == TestType.ONE_SAMPLE.value:
        sign_flip_data = group1
        emp_t_dict = compute_t_stat_diff(sign_flip_data)
    elif test_type_str == TestType.TWO_SAMPLE.value:
        sign_flip_data = None
        emp_t_dict = compute_t_stat(group1, group2, test_type=test_type_str)
    else:
        raise ValueError(f"Invalid test_type: '{test_type_str}'")

    N = emp_t_dict["positive"].shape[0]
    triu_idx = np.triu_indices(N, k=1)
    n_edges = len(triu_idx[0])

    rng = np.random.RandomState(random_state)
    seeds = rng.randint(0, 2**32 - 1, size=n_permutations, dtype=np.int64)

    # Pre-computed sums fast path returning upper-triangle edge vectors, which
    # skips the (N, N) reshape and triu extraction inside the perm loop.
    if sign_flip_data is not None:
        X, sumsq_all = _precompute_edge_sums(sign_flip_data)
        task_func = partial(_fast_permutation_task_paired_edges, X, sumsq_all)
    else:
        n1 = group1.shape[0]
        Xall_3d = np.concatenate((group1, group2), axis=0)
        Xall, Xall2, sum_all, sumsq_all = _precompute_twosample_sums(Xall_3d)
        strata_codes, n1_per_stratum = _prepare_two_sample_strata(
            strata, n1, Xall_3d.shape[0]
        )
        task_func = partial(
            _fast_permutation_task_ind_edges,
            Xall, Xall2, sum_all, sumsq_all, n1,
            strata=strata_codes, n1_per_stratum=n1_per_stratum,
        )

    use_pool, n_processes = _resolve_pool_settings(
        use_mp, n_processes, n_permutations
    )
    perm_results = run_permutations(
        task_func,
        list(seeds),
        use_mp=use_pool,
        n_processes=n_processes,
        verbose=verbose,
        desc="bh_fdr_perm",
    )

    p_values = {}
    for key in ("positive", "negative"):
        emp_upper = emp_t_dict[key][triu_idx]
        # Accumulate per edge rather than stacking all permutations, which
        # keeps memory at O(n_edges) instead of O(B * n_edges).
        count_ge = np.zeros(n_edges, dtype=np.float64)
        for perm_dict in perm_results:
            # perm_dict[key] is already the upper-triangle vector (fast path)
            count_ge += (perm_dict[key] >= emp_upper).astype(np.float64)
        # +1 correction: prevents p = 0 with finite permutations
        per_edge_p = (count_ge + 1.0) / (n_permutations + 1.0)
        corrected_p = _bh_fdr_correction(per_edge_p)

        # Reconstruct the full symmetric matrix; the diagonal is untested.
        p_mat = np.ones((N, N), dtype=np.float64)
        p_mat[triu_idx] = corrected_p
        p_mat[(triu_idx[1], triu_idx[0])] = corrected_p
        p_values[key] = p_mat

    return p_values, emp_t_dict


# =============================================================================
# P-value computation
# =============================================================================


def _compute_p_values_from_null(
    emp_t_dict: Dict[str, npt.NDArray],
    max_null_dict: Dict[str, npt.NDArray],
) -> Dict[str, npt.NDArray]:
    """
    Compute max-statistic p-values from empirical statistics and a null.

    Ties are counted in the conservative direction (``null >= observed``)
    and the Phipson-Smyth +1 correction is applied:
    ``p = (count + 1) / (B + 1)``. This keeps an all-zero observed map
    against an all-zero null at ``p = 1`` rather than ``1 / (B + 1)``.

    Parameters
    ----------
    emp_t_dict : dict
        Per-tail empirical statistics, shape (N, N) or (N, N, n_params).
    max_null_dict : dict
        Per-tail max-statistic nulls, shape (B,) or (B, n_params).

    Returns
    -------
    dict
        Per-tail p-value arrays with the same shape as the empirical maps.
    """
    keys = list(emp_t_dict.keys())
    p_values = {}
    is_2d = len(emp_t_dict[keys[0]].shape) == 2

    for key in keys:
        emp_t = emp_t_dict[key]
        null_dist = max_null_dict[key]
        n_perm = null_dist.shape[0]
        emp_t_expanded = emp_t[..., np.newaxis]

        if is_2d:
            # (N, N, 1) vs (B,) -> count over permutations
            count = np.sum(emp_t_expanded <= null_dist, axis=-1)
        else:
            # (N, N, P, 1) vs (1, 1, P, B): compare each parameter cell
            # against its own column of the null.
            null_reshaped = null_dist.swapaxes(0, 1)[None, None, ...]
            count = np.sum(emp_t_expanded <= null_reshaped, axis=-1)

        # +1 correction (Phipson & Smyth 2010): prevents p = 0 with finite perms
        p_values[key] = (count + 1.0) / (n_perm + 1.0)

    return p_values


def _compute_degrees_of_freedom(
    n1: int, n2: int, test_type_str: str
) -> int:
    """Compute degrees of freedom for a t-test.

    Parameters
    ----------
    n1 : int
        Number of samples in group 1.
    n2 : int
        Number of samples in group 2 (0 for one-sample).
    test_type_str : str
        Test type string ('paired', 'one-sample', 'two-sample').

    Returns
    -------
    int
        ``n1 + n2 - 2`` for two-sample tests, otherwise ``n1 - 1``.
    """
    if test_type_str == TestType.TWO_SAMPLE.value:
        return n1 + n2 - 2
    # paired or one-sample: df = n - 1
    return n1 - 1


def _compute_welch_degrees_of_freedom(
    group1: npt.NDArray[np.float64],
    group2: npt.NDArray[np.float64],
) -> npt.NDArray[np.float64]:
    """Compute edge-wise Welch-Satterthwaite degrees of freedom.

    Where both groups have zero variance (the Satterthwaite denominator is
    0) the pooled value ``n1 + n2 - 2`` is returned instead of NaN.
    """
    n1 = group1.shape[0]
    n2 = group2.shape[0]
    se1 = np.var(group1, axis=0, ddof=1) / n1
    se2 = np.var(group2, axis=0, ddof=1) / n2
    numerator = (se1 + se2) ** 2
    denominator = (se1 ** 2) / (n1 - 1) + (se2 ** 2) / (n2 - 1)
    pooled_df = float(n1 + n2 - 2)
    with np.errstate(divide="ignore", invalid="ignore"):
        df = numerator / denominator
    return np.where(denominator > 0, df, pooled_df)


def _compute_parametric_p_values(
    group1: npt.NDArray[np.float64],
    group2: Optional[npt.NDArray[np.float64]],
    test_type_str: str,
    method_enum: "StatMethod",
    alpha: float = 0.05,
) -> Tuple[Dict[str, npt.NDArray[np.float64]], Dict[str, npt.NDArray[np.float64]]]:
    """
    Compute parametric p-values with Bonferroni or BH-FDR correction.

    Each tail gets a proper one-tailed p-value from the t-distribution:
    ``sf(t)`` for the positive tail and ``sf(-t)`` for the negative tail,
    using the signed t. An edge whose effect goes the other way therefore
    receives a p-value above 0.5, not the 0.5 that the zero-clipped tail map
    would give.

    Parameters
    ----------
    group1 : ndarray of shape (n_subjects, N, N)
        Connectivity matrices for group 1.
    group2 : ndarray of shape (n_subjects, N, N), optional
        Connectivity matrices for group 2.
    test_type_str : str
        Test type ('paired', 'one-sample', 'two-sample').
    method_enum : StatMethod
        BONFERRONI or BH_FDR.
    alpha : float, default=0.05
        Currently unused. Corrected p-values are returned and thresholding
        is left to the caller. Kept for backward compatibility.

    Returns
    -------
    (p_values, t_dict) : tuple of dict
        ``p_values`` -- per-tail corrected p-value arrays of shape (N, N)
        (diagonal set to 1.0).
        ``t_dict`` -- per-tail observed one-tail-clipped (non-negative)
        t-statistic maps, used to populate ``stat_positive``/
        ``stat_negative`` on :class:`InferenceResult`.

    Raises
    ------
    ValueError
        If ``test_type_str`` or ``method_enum`` is unsupported.
    """
    if test_type_str == TestType.PAIRED.value:
        t_dict = compute_t_stat_diff(group2 - group1)
        df = _compute_degrees_of_freedom(group1.shape[0], 0, test_type_str)
    elif test_type_str == TestType.ONE_SAMPLE.value:
        t_dict = compute_t_stat_diff(group1)
        df = _compute_degrees_of_freedom(group1.shape[0], 0, test_type_str)
    elif test_type_str == TestType.TWO_SAMPLE.value:
        t_dict = _compute_t_stat_ind(group1, group2)
        df = _compute_welch_degrees_of_freedom(group1, group2)
    else:
        raise ValueError(f"Invalid test_type: '{test_type_str}'")

    # Recover the signed t-map: per edge at most one of the clipped tails
    # is non-zero, so positive - negative == t.
    signed_t = np.asarray(t_dict["positive"]) - np.asarray(t_dict["negative"])
    N = signed_t.shape[0]
    # Upper triangle = unique edges of the symmetric matrix.
    triu_idx = np.triu_indices(N, k=1)
    m = len(triu_idx[0])

    p_values = {}
    for key, oriented_t in (("negative", -signed_t), ("positive", signed_t)):
        raw_p = stats.t.sf(oriented_t, df)
        p_upper = raw_p[triu_idx]

        if method_enum == StatMethod.BONFERRONI:
            # Multiply by number of comparisons, cap at 1.0
            p_corrected_upper = np.minimum(p_upper * m, 1.0)
        elif method_enum == StatMethod.BH_FDR:
            p_corrected_upper = _bh_fdr_correction(p_upper)
        else:
            raise ValueError(f"Unsupported parametric method: {method_enum}")

        p_corrected = np.ones((N, N), dtype=np.float64)
        p_corrected[triu_idx] = p_corrected_upper
        p_corrected[(triu_idx[1], triu_idx[0])] = p_corrected_upper
        p_values[key] = p_corrected

    return p_values, t_dict


def _bh_fdr_correction(
    p_values: npt.NDArray[np.float64],
) -> npt.NDArray[np.float64]:
    """
    Apply Benjamini-Hochberg FDR correction to a 1D array of p-values.

    Parameters
    ----------
    p_values : ndarray of shape (m,)
        Uncorrected p-values.

    Returns
    -------
    ndarray of shape (m,)
        BH-adjusted p-values in the original order; thresholding them at
        ``alpha`` controls the FDR at level ``alpha``.
    """
    m = len(p_values)
    if m == 0:
        return p_values.copy()

    sorted_idx = np.argsort(p_values)
    sorted_p = p_values[sorted_idx]

    # BH adjustment: p_adj[i] = p[i] * m / rank[i]
    ranks = np.arange(1, m + 1, dtype=np.float64)
    adjusted = sorted_p * m / ranks

    # Step-up monotonicity: each adjusted value is the running minimum of
    # all adjusted values at the same or larger rank.
    adjusted = np.minimum.accumulate(adjusted[::-1])[::-1]
    adjusted = np.minimum(adjusted, 1.0)

    result = np.empty(m, dtype=np.float64)
    result[sorted_idx] = adjusted
    return result
def compute_p_val(
    group1: npt.NDArray[np.float64],
    group2: Optional[npt.NDArray[np.float64]] = None,
    n_permutations: int = DEFAULT_N_PERMUTATIONS,
    test_type: Union[str, TestType] = TestType.PAIRED,
    method: Union[str, StatMethod] = StatMethod.TFNBS,
    use_mp: bool = True,
    rng: RngLike = None,
    random_state: Optional[int] = None,  # deprecated alias for `rng`
    n_processes: Optional[int] = None,
    acceleration: Optional[str] = None,
    verbose: bool = False,
    # Method-specific parameters
    net_labels: Optional[npt.NDArray[np.int_]] = None,
    threshold: float = DEFAULT_NBS_THRESHOLD,
    nbs_stat: str = DEFAULT_NBS_STAT,
    e: Union[float, List[float]] = DEFAULT_EXTENT_EXPONENT,
    h: Union[float, List[float]] = DEFAULT_HEIGHT_EXPONENT,
    n: int = DEFAULT_N_THRESHOLDS,
    start_thres: float = DEFAULT_START_THRESHOLD,
    min_cluster_size: int = DEFAULT_MIN_CLUSTER_SIZE,
    normalization: str = "sqrt",
    strata: Optional[npt.NDArray[Any]] = None,
    **kwargs
) -> InferenceResult:
    """
    Compute p-values using permutation testing with various network-based methods.

    .. warning::
       **Welch + group-label permutation under unequal variances**

       The two-sample path uses Welch's t (unequal-variance) unconditionally.
       Under unequal variances with unbalanced group sizes the
       exchangeability assumption underlying permutation breaks and Type I
       error inflates (Anderson & Robinson 2001; Hayes 2000). A
       pooled-variance option will be added in v2.2 with an explicit
       ``var_equal=True|False`` knob; until then, treat two-sample results
       with caution when both group sizes and variances differ substantially.

    Supports multiple statistical methods: tfnbs, tstat, nbs, cnbs,
    ni_tfnbs, fbc_tfnbs, bonferroni, bh_fdr, bh_fdr_perm.

    .. note::
       **cNBS null distribution**

       This implementation uses the **max-statistic** null distribution for
       cNBS, where each permutation contributes its global maximum cNBS score
       to the null. This is the same family-wise error rate (FWER) control
       strategy used by classical NBS (Zalesky et al. 2010) and TFNBS.

       Noble & Scheinost (2020) originally proposed computing **per-block**
       null distributions with Bonferroni correction across blocks. The
       max-statistic approach used here is more conservative (controls FWER
       globally) but provides a consistent framework across all methods in
       this package.

    .. note::
       **Bonferroni and BH-FDR methods**

       ``StatMethod.BONFERRONI`` and ``StatMethod.BH_FDR`` are parametric
       baselines that do **not** use permutation testing. They compute
       p-values from the t-distribution and apply multiple comparison
       corrections. The ``n_permutations`` parameter is ignored for these
       methods. In the two-sample path, the parametric baseline uses
       edge-wise Welch-Satterthwaite degrees of freedom to match the Welch
       statistic.

    .. note::
       **Permutation p-values**

       Empirical max-statistic p-values count ties conservatively
       (``null >= observed``) and use the Phipson-Smyth +1 correction:
       ``p = (count + 1) / (B + 1)``.

    Parameters
    ----------
    group1 : ndarray of shape (n_subjects_g1, N, N)
        Input connectivity matrices for group 1.
    group2 : ndarray of shape (n_subjects_g2, N, N), optional
        Input connectivity matrices for group 2. Required for paired and
        two-sample tests. For paired tests it must have the same shape as
        ``group1``. It is ignored for one-sample tests.
    n_permutations : int, default=1000
        Number of permutations for the null distribution. Must be at least 1
        for permutation-based methods.
    test_type : {'paired', 'one-sample', 'two-sample'} or TestType, default='paired'
        Type of statistical test.
    method : {'tfnbs', 'tstat', 'nbs', 'cnbs', 'ni_tfnbs', 'fbc_tfnbs', \
'bonferroni', 'bh_fdr', 'bh_fdr_perm'} or StatMethod, default='tfnbs'
        Statistical method to use for scoring / correction.
    use_mp : bool, default=True
        Use multiprocessing for permutation testing. Automatically disabled
        when called from inside a multiprocessing worker to prevent deadlocks.
    rng : RngLike, optional
        Seed specification for the permutations, resolved to a seed through
        ``resolve_seed``.
    random_state : int, optional
        Deprecated alias for ``rng``. Passing it without ``rng`` emits a
        deprecation warning.
    n_processes : int, optional
        Number of CPU cores for parallel computing.
    acceleration : str, optional
        If given, p-values are computed by ``compute_p_values_accelerated``
        with this method name instead of the plain empirical count.
    verbose : bool, default=False
        Forwarded to the permutation runner (progress reporting).
    net_labels : ndarray of shape (N,), optional
        Network labels for each node. Required for cnbs, ni_tfnbs, and
        fbc_tfnbs.
    threshold : float, default=2.0
        T-statistic threshold for NBS (only used when method='nbs').
    nbs_stat : {'extent', 'intensity'}, default='extent'
        Cluster statistic for NBS (only used when method='nbs').
    e : float or sequence of float, default=0.3
        Extent exponent for TFNBS-based methods. Pass a sequence with the
        same length as ``h`` to evaluate a whole ``(E, H)`` grid in one
        permutation pass. TFNBS's threshold integration runs once and the
        per-cell exponentiation is broadcast, so a K-cell grid costs about
        the same wall-clock time as a single cell. In grid mode the returned
        :class:`~conninfpy.InferenceResult` carries the parameter axis
        (``result.is_grid``, ``result.e_grid``, ``result.h_grid``). Use
        ``result.select(param_idx)`` to project to a single cell, or pass
        ``param_idx=`` to ``significant_edges`` / ``to_csv``.
    h : float or sequence of float, default=3.0
        Height exponent for TFNBS-based methods. See ``e`` for grid mode.
    n : int, default=10
        Number of threshold steps for TFNBS-based methods.
    start_thres : float, default=1.65
        Starting threshold for TFNBS integration.
    min_cluster_size : int, default=3
        Minimum cluster size for FBC-TFNBS (only used when
        method='fbc_tfnbs').
    normalization : {'sqrt', 'linear', 'none'}, default='sqrt'
        Block density normalization for NI-TFNBS (only used when
        method='ni_tfnbs').
    strata : array-like of shape (n,), optional
        Exchangeability-block labels per subject (e.g. site IDs after ComBat
        harmonization). When provided, the two-sample group-label
        permutation only swaps subjects *within* a stratum, holding the
        per-stratum group totals fixed. This prevents the shadow-of-H0 leak
        that occurs when ComBat is fit on the observed labels but the
        downstream permutation reshuffles subjects across sites. Paired and
        one-sample sign-flip paths are stratum-invariant by construction;
        ``strata`` is silently accepted but has no effect there.
    **kwargs
        Additional keyword arguments (reserved for future extensions;
        currently ignored).

    Returns
    -------
    InferenceResult
        Dict-like result with p-value arrays:

        - 'negative': P-values for group 1 > group 2.
        - 'positive': P-values for group 2 > group 1.

    Raises
    ------
    ValueError
        If ``method`` is unknown. If constrained methods (cnbs, ni_tfnbs,
        fbc_tfnbs) are used without ``net_labels``. If ``group2`` is missing
        for a paired or two-sample test, or its shape is incompatible with
        ``group1``. If ``n_permutations < 1`` for a permutation-based method.

    Examples
    --------
    >>> import numpy as np
    >>> np.random.seed(2)
    >>> group1 = np.random.rand(5, 3, 3)
    >>> for arr in group1: np.fill_diagonal(arr, 0)
    >>> group2 = np.random.rand(8, 3, 3)
    >>> for arr in group2: np.fill_diagonal(arr, 0)
    >>> # Standard t-test
    >>> p_vals = compute_p_val(group1, group2, n_permutations=10,
    ...                        test_type='two-sample', method='tstat',
    ...                        use_mp=False, random_state=0)
    >>> # TFNBS
    >>> p_vals = compute_p_val(group1, group2, n_permutations=10,
    ...                        test_type='two-sample', method='tfnbs',
    ...                        use_mp=False, random_state=0)
    >>> # cNBS with network labels
    >>> labels = np.array([0, 0, 1])
    >>> p_vals = compute_p_val(group1, group2, n_permutations=10,
    ...                        test_type='two-sample', method='cnbs',
    ...                        net_labels=labels, use_mp=False, random_state=0)
    """
    # Normalize inputs
    test_type_str = test_type.value if isinstance(test_type, TestType) else test_type
    method_str = method.value if isinstance(method, StatMethod) else method
    if random_state is not None and rng is None:
        warn_legacy_random_state("random_state")
    random_state = resolve_seed(rng, legacy_random_state=random_state)

    # Validate method
    try:
        method_enum = StatMethod(method_str)
    except ValueError:
        valid_methods = [m.value for m in StatMethod]
        raise ValueError(
            f"Invalid method: '{method_str}'. Must be one of: {valid_methods}"
        ) from None

    # Validate constrained methods require net_labels
    if method_enum in CONSTRAINED_METHODS and net_labels is None:
        raise ValueError(
            f"Method '{method_str}' requires net_labels to be provided. "
            f"Constrained methods are: {[m.value for m in CONSTRAINED_METHODS]}"
        )

    # Validate group2 up front: the parametric, BH-FDR-perm and enhanced
    # paths all dereference it for paired / two-sample tests and would
    # otherwise fail with an opaque TypeError (or silently broadcast a
    # single-subject group1 in the paired difference).
    if test_type_str in (TestType.PAIRED.value, TestType.TWO_SAMPLE.value):
        if group2 is None:
            raise ValueError(f"group2 is required for test_type='{test_type_str}'")
        if group1.shape[1:] != group2.shape[1:]:
            raise ValueError("Trailing dimensions of group1 and group2 must match.")
        if (
            test_type_str == TestType.PAIRED.value
            and group1.shape[0] != group2.shape[0]
        ):
            raise ValueError(
                "Paired test requires group1 and group2 to have the same "
                f"number of subjects (got {group1.shape[0]} and "
                f"{group2.shape[0]})."
            )

    if method_enum not in PARAMETRIC_METHODS and n_permutations < 1:
        raise ValueError("n_permutations must be at least 1.")

    _t0 = time.perf_counter()

    # Parametric methods: compute p-values directly from the t-distribution
    if method_enum in PARAMETRIC_METHODS:
        result, obs_t_dict = _compute_parametric_p_values(
            group1, group2, test_type_str, method_enum
        )
        return InferenceResult(
            result["positive"],
            result["negative"],
            method=method_str,
            n_permutations=0,
            acceleration=None,
            wall_time_s=time.perf_counter() - _t0,
            stat_positive=obs_t_dict["positive"],
            stat_negative=obs_t_dict["negative"],
            stat_type="tstat",
        )

    # BH-FDR with permutation p-values: separate code path
    if method_enum == StatMethod.BH_FDR_PERM:
        result, obs_t_dict = _compute_bh_fdr_perm_p_values(
            group1,
            group2,
            test_type_str,
            n_permutations=n_permutations,
            random_state=random_state,
            use_mp=use_mp,
            n_processes=n_processes,
            verbose=verbose,
            strata=strata,
        )
        return InferenceResult(
            result["positive"],
            result["negative"],
            method=method_str,
            n_permutations=n_permutations,
            acceleration=None,
            wall_time_s=time.perf_counter() - _t0,
            stat_positive=obs_t_dict["positive"],
            stat_negative=obs_t_dict["negative"],
            stat_type="tstat",
            strata_provided=strata is not None,
        )

    # Resolve enhancement wrapper (None for raw t-stat)
    enhance_func = _ENHANCE_MAP[method_enum]

    # Build enhancement kwargs once (method-specific)
    enhance_kwargs: Dict[str, Any] = {}
    if method_enum == StatMethod.NBS:
        enhance_kwargs = {"threshold": threshold, "nbs_stat": nbs_stat}
    elif method_enum in {StatMethod.TFNBS, StatMethod.NI_TFNBS, StatMethod.FBC_TFNBS}:
        enhance_kwargs = {"e": e, "h": h, "n": n, "start_thres": start_thres}
        if method_enum in CONSTRAINED_METHODS:
            enhance_kwargs["net_labels"] = net_labels
        if method_enum == StatMethod.FBC_TFNBS:
            enhance_kwargs["min_cluster_size"] = min_cluster_size
        if method_enum == StatMethod.NI_TFNBS:
            enhance_kwargs["normalization"] = normalization
    elif method_enum == StatMethod.CNBS:
        enhance_kwargs = {"net_labels": net_labels}

    # Compute observed t-stat once, then apply enhancement (if any)
    if test_type_str == TestType.PAIRED.value:
        emp_t_dict = compute_t_stat_diff(group2 - group1)
    elif test_type_str == TestType.ONE_SAMPLE.value:
        emp_t_dict = compute_t_stat_diff(group1)
    elif test_type_str == TestType.TWO_SAMPLE.value:
        emp_t_dict = _compute_t_stat_ind(group1, group2)
    else:
        raise ValueError(
            f"Invalid test_type: '{test_type_str}'. "
            f"Must be one of: {[t.value for t in TestType]}"
        )

    # Keep the raw (pre-enhancement) t-stat dict for
    # InferenceResult.stat_positive/stat_negative: the user-facing effect
    # map should be the t-statistic, not the enhanced score.
    raw_t_dict = emp_t_dict
    if enhance_func is not None:
        emp_t_dict = enhance_func(emp_t_dict, **enhance_kwargs)

    # Compute null distribution
    #   - method='tstat' (no enhancement): compute_null_dist fast path
    #   - enhancement methods: fast enhancement perm tasks (sums + enhancement)
    if enhance_func is None:
        group2_for_null = group2 if test_type_str != TestType.ONE_SAMPLE.value else None
        tstat_func = (
            _compute_t_stat_ind
            if test_type_str == TestType.TWO_SAMPLE.value
            else compute_t_stat_diff
        )
        max_null_dict = compute_null_dist(
            group1,
            group2_for_null,
            tstat_func,
            n_permutations=n_permutations,
            test_type=test_type,
            use_mp=use_mp,
            random_state=random_state,
            n_processes=n_processes,
            verbose=verbose,
            strata=strata,
        )
    else:
        max_null_dict = _compute_enhanced_null_dist(
            group1,
            group2,
            test_type_str,
            enhance_func=enhance_func,
            enhance_kwargs=enhance_kwargs,
            n_permutations=n_permutations,
            random_state=random_state,
            use_mp=use_mp,
            n_processes=n_processes,
            verbose=verbose,
            strata=strata,
        )

    if acceleration is not None:
        result = compute_p_values_accelerated(
            emp_t_dict, max_null_dict, method=acceleration,
        )
    else:
        result = _compute_p_values_from_null(emp_t_dict, max_null_dict)

    # Capture the (E, H) grid for TFNBS-family methods so the result carries
    # the parameter axis labels when an array was passed.
    e_grid, h_grid = _grid_from_kwargs(method_enum, enhance_kwargs)

    return InferenceResult(
        result["positive"],
        result["negative"],
        method=method_str,
        n_permutations=n_permutations,
        acceleration=acceleration,
        wall_time_s=time.perf_counter() - _t0,
        stat_positive=raw_t_dict["positive"],
        stat_negative=raw_t_dict["negative"],
        stat_type="tstat",
        strata_provided=strata is not None,
        e_grid=e_grid,
        h_grid=h_grid,
    )
# ============================================================================= # T-statistic computation functions # =============================================================================
def compute_t_stat(
    group1: npt.NDArray[np.float64],
    group2: Optional[npt.NDArray[np.float64]] = None,
    test_type: Union[str, TestType] = TestType.PAIRED
) -> Dict[str, npt.NDArray[np.float64]]:
    """
    Compute t-statistics for paired, one-sample, or two-sample tests.

    Parameters
    ----------
    group1 : ndarray of shape (n_samples_1, N, N)
        Data array for group 1.
    group2 : ndarray of shape (n_samples_2, N, N), optional
        Data array for group 2. Required for paired and two-sample tests.
        For paired tests it must have exactly the same shape as ``group1``.
        It is ignored (with a warning) for one-sample tests.
    test_type : {'paired', 'one-sample', 'two-sample'} or TestType
        Type of statistical test.

    Returns
    -------
    dict
        Dictionary with 'positive' and 'negative' t-statistic arrays.

    Raises
    ------
    ValueError
        If ``test_type`` is invalid, ``group2`` is missing for a paired or
        two-sample test, ``group1`` is not 3-D for a one-sample test,
        trailing dimensions differ (two-sample), or the shapes differ
        (paired).
    """
    # Normalize test_type to string
    test_type_str = test_type.value if isinstance(test_type, TestType) else test_type

    if test_type_str == TestType.ONE_SAMPLE.value:
        if group2 is not None:
            logger.warning("group2 is provided but test_type is 'one-sample'. It will be ignored.")
        if group1.ndim != 3:
            raise ValueError("Dimensions of group 1 data should be: (subjects, N, N).")
        return compute_t_stat_diff(group1)

    if test_type_str not in (TestType.TWO_SAMPLE.value, TestType.PAIRED.value):
        raise ValueError(
            f"Invalid test_type: '{test_type_str}'. "
            f"Must be one of: {[t.value for t in TestType]}"
        )

    # Both remaining tests need a second group; fail with a clear message
    # instead of an AttributeError / TypeError on None.
    if group2 is None:
        raise ValueError(f"group2 is required for test_type='{test_type_str}'")

    if test_type_str == TestType.TWO_SAMPLE.value:
        if group1.shape[1:] != group2.shape[1:]:
            raise ValueError("Trailing dimensions of group1 and group2 must match.")
        return _compute_t_stat_ind(group1, group2)

    # Paired: shapes must agree exactly, otherwise ``group2 - group1`` could
    # silently broadcast (e.g. a single-subject group1) into a wrong test.
    if group1.shape != group2.shape:
        raise ValueError(
            "Paired test requires group1 and group2 to have identical shapes "
            f"(got {group1.shape} and {group2.shape})."
        )
    return compute_t_stat_diff(group2 - group1)
def compute_t_stat_diff(
    diff: npt.NDArray[np.float64]
) -> Dict[str, npt.NDArray[np.float64]]:
    """
    One-sample t-statistic of ``diff`` against zero, split into two tails.

    The statistic is ``mean / (sd / sqrt(n))`` along axis 0, with
    ``ddof=1`` for the standard deviation. Cells whose standard deviation
    is exactly zero are reported as t = 0.

    Parameters
    ----------
    diff : ndarray of shape ``(n_samples, *dims)``
        Paired differences (or raw observations for a one-sample test).

    Returns
    -------
    dict
        - 'positive': t where t > 0 (group 2 > group 1), else 0.
        - 'negative': |t| where t < 0, else 0.

    Raises
    ------
    ValueError
        If fewer than 2 samples are provided.
    """
    n_samples = diff.shape[0]
    if n_samples < 2:
        raise ValueError("At least 2 samples required for t-statistic.")

    sample_mean = np.mean(diff, axis=0)
    sample_sd = np.std(diff, axis=0, ddof=1)

    with np.errstate(divide='ignore', invalid='ignore'):
        std_error = sample_sd / np.sqrt(n_samples)
        raw_t = sample_mean / std_error

    # Zero spread carries no usable evidence: force t to 0 there.
    t_map = np.where(sample_sd == 0, 0, raw_t)

    upper_tail = np.where(t_map > 0, t_map, 0)
    lower_tail = np.where(t_map < 0, -t_map, 0)
    return make_tail_result(upper_tail, lower_tail)
def _compute_t_stat_ind(
    group1: npt.NDArray[np.float64],
    group2: npt.NDArray[np.float64]
) -> Dict[str, npt.NDArray[np.float64]]:
    """
    Welch (unequal-variance) t-statistic for two independent samples.

    ``t = (mean2 - mean1) / sqrt(var1 / n1 + var2 / n2)`` along axis 0, with
    ``ddof=1`` variances. Cells where the standard error is exactly zero are
    reported as t = 0.

    Parameters
    ----------
    group1 : ndarray of shape (n_samples_1, *dims)
        Data array for group 1.
    group2 : ndarray of shape (n_samples_2, *dims)
        Data array for group 2.

    Returns
    -------
    dict
        - 'positive': t where t > 0 (group 2 > group 1), else 0.
        - 'negative': |t| where t < 0, else 0.

    Raises
    ------
    ValueError
        If either group has fewer than 2 samples.
    """
    n1 = group1.shape[0]
    n2 = group2.shape[0]
    if min(n1, n2) < 2:
        raise ValueError("Each group must have at least 2 samples.")

    mean1 = np.mean(group1, axis=0)
    mean2 = np.mean(group2, axis=0)
    sq_se1 = np.var(group1, axis=0, ddof=1) / n1
    sq_se2 = np.var(group2, axis=0, ddof=1) / n2
    welch_se = np.sqrt(sq_se1 + sq_se2)

    with np.errstate(divide='ignore', invalid='ignore'):
        raw_t = (mean2 - mean1) / welch_se

    # Degenerate (zero standard error) cells are reported as no effect.
    t_map = np.where(welch_se == 0, 0, raw_t)

    upper_tail = np.where(t_map > 0, t_map, 0)
    lower_tail = np.where(t_map < 0, -t_map, 0)
    return make_tail_result(upper_tail, lower_tail)