Source code for conninfpy.nbs_score

"""
Network-Based Statistics (NBS) scoring module.

Provides classical NBS cluster scoring (extent or intensity) and constrained
cNBS scoring in a single module, mirroring the role of tfnbs_score.py.
"""

from __future__ import annotations

from typing import Dict, Optional, Tuple

import numpy as np
import numpy.typing as npt
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components

from ._rng import RngLike, resolve_seed, warn_legacy_random_state


# Public names exported by this module (the star-import surface): the two
# default constants re-exported from the package defaults plus the three
# scoring entry points defined below.
__all__ = [
    "DEFAULT_NBS_THRESHOLD",
    "DEFAULT_NBS_STAT",
    "get_nbs_score",
    "get_cnbs_score",
    "nbs_bct",
]


# =============================================================================
# Constants — imported from unified defaults
# =============================================================================

from .defaults import DEFAULT_NBS_THRESHOLD, DEFAULT_NBS_STAT


# =============================================================================
# Helper utilities
# =============================================================================

# Shorthand type alias for float64 NumPy arrays.
ArrayF = npt.NDArray[np.float64]


def _is_symmetric(matrix: npt.NDArray[np.floating], rtol: float = 1e-10) -> bool:
    """Return True when ``matrix`` matches its own transpose.

    The comparison is purely relative: ``atol`` is fixed at 0, so an entry
    that is exactly zero only matches another exact zero.

    Parameters
    ----------
    matrix : ndarray
        Matrix to test.
    rtol : float
        Relative tolerance forwarded to ``np.allclose``.
    """
    mirrored = matrix.T
    return np.allclose(matrix, mirrored, rtol=rtol, atol=0)


def _get_edges(
    t_stats: npt.NDArray[np.floating],
    min_threshold: float,
    symmetric: bool,
) -> Tuple[npt.NDArray[np.intp], npt.NDArray[np.intp], npt.NDArray[np.floating]]:
    """Collect the entries of ``t_stats`` that lie strictly above a threshold.

    Parameters
    ----------
    t_stats : ndarray of shape (N, N)
        Matrix to scan.
    min_threshold : float
        An entry is kept only when it is strictly greater than this value.
    symmetric : bool
        When True only the strict upper triangle is scanned, so each
        unordered pair appears once. When False every off-diagonal entry is
        scanned in row-major order.

    Returns
    -------
    (rows, cols, weights)
        Row indices, column indices and values of the retained entries.
    """
    if symmetric:
        cand_rows, cand_cols = np.triu_indices_from(t_stats, k=1)
    else:
        # Everything except the diagonal is a candidate edge.
        off_diagonal = ~np.eye(t_stats.shape[0], dtype=bool)
        cand_rows, cand_cols = np.nonzero(off_diagonal)

    values = t_stats[cand_rows, cand_cols]
    keep = values > min_threshold
    return cand_rows[keep], cand_cols[keep], values[keep]


def _check_square(t_stat: npt.NDArray[np.floating]) -> None:
    """Validate that ``t_stat`` is a square matrix with a zero diagonal.

    Raises
    ------
    ValueError
        If the array is not two-dimensional and square, or if any diagonal
        entry is not close to zero under ``np.allclose``'s default tolerances.
    """
    is_square = t_stat.ndim == 2 and t_stat.shape[0] == t_stat.shape[1]
    if not is_square:
        raise ValueError("t_stat must be a square (N, N) matrix.")

    diagonal = np.diag(t_stat)
    if not np.allclose(diagonal, 0.0):
        raise ValueError("Diagonal elements must be zero (no self-connections).")


def _validate_net_labels(
    net_labels: npt.NDArray[np.integer],
    n_nodes: int,
) -> npt.NDArray[np.integer]:
    """Check the label count and remap labels to consecutive integers.

    Parameters
    ----------
    net_labels : ndarray
        One label per node; the first dimension must equal ``n_nodes``.
    n_nodes : int
        Expected number of nodes.

    Returns
    -------
    ndarray
        For each node, the index of its label among the sorted unique labels,
        i.e. values in ``0..K-1`` where ``K`` is the number of distinct labels.

    Raises
    ------
    ValueError
        If the first dimension of ``net_labels`` differs from ``n_nodes``.
    """
    n_labels = net_labels.shape[0]
    if n_labels != n_nodes:
        raise ValueError(
            f"net_labels shape {net_labels.shape} does not match number of nodes {n_nodes}."
        )

    # Only the inverse mapping is needed; the unique values are discarded.
    compact_ids = np.unique(net_labels, return_inverse=True)[1]
    return compact_ids


def _compute_canonical_block_ids(
    edge_rows: npt.NDArray[np.intp],
    edge_cols: npt.NDArray[np.intp],
    node_labels: npt.NDArray[np.integer],
    n_networks: int,
) -> npt.NDArray[np.intp]:
    """Map each edge to an id for the unordered pair of its endpoint labels.

    The id is ``min(label_i, label_j) * n_networks + max(label_i, label_j)``,
    so an edge gets the same id regardless of which endpoint is listed first.

    Parameters
    ----------
    edge_rows, edge_cols : ndarray
        Endpoint node indices of each edge.
    node_labels : ndarray
        Label of each node, indexed by node.
    n_networks : int
        Multiplier used to combine the two labels into a single id.
    """
    source_labels = node_labels[edge_rows]
    target_labels = node_labels[edge_cols]
    smaller = np.minimum(source_labels, target_labels)
    larger = np.maximum(source_labels, target_labels)
    return smaller * n_networks + larger


# =============================================================================
# NBS scoring
# =============================================================================

def get_nbs_score(
    t_stats: npt.NDArray[np.floating],
    threshold: float = DEFAULT_NBS_THRESHOLD,
    stat_type: str = DEFAULT_NBS_STAT,
) -> npt.NDArray[np.floating]:
    """
    Score every suprathreshold edge with the statistic of its cluster.

    Edges whose (rounded) t-value is strictly greater than ``threshold`` are
    grouped into connected components; each such edge is then assigned the
    statistic of the component it belongs to.

    Parameters
    ----------
    t_stats : ndarray of shape (N, N)
        Non-negative t-statistic matrix (one tail, already separated) with a
        zero diagonal.
    threshold : float
        T-statistic threshold for edge inclusion.
    stat_type : {'extent', 'intensity'}
        - 'extent': cluster size (number of edges).
        - 'intensity': sum of t-values within the cluster.

    Returns
    -------
    ndarray of shape (N, N)
        float64 matrix holding the cluster statistic at every suprathreshold
        edge and 0 elsewhere.

    Raises
    ------
    ValueError
        If ``stat_type`` is not recognised, or if ``_check_square`` rejects
        ``t_stats``.

    Notes
    -----
    When the rounded matrix is symmetric, only the upper triangle is
    enumerated (each undirected edge counts once) and the result is mirrored
    into the lower triangle. Otherwise every off-diagonal entry is counted as
    its own edge, while components are still found on the undirected graph.
    """
    if stat_type != "extent" and stat_type != "intensity":
        raise ValueError("stat_type must be 'extent' or 'intensity'.")
    _check_square(t_stats)

    # Round to avoid float precision issues at threshold boundaries
    rounded = np.round(t_stats, decimals=10)
    n_nodes = rounded.shape[0]
    symmetric = _is_symmetric(rounded)
    rows, cols, weights = _get_edges(rounded, threshold, symmetric=symmetric)

    cluster_map = np.zeros((n_nodes, n_nodes), dtype=np.float64)
    if rows.size == 0:
        return cluster_map

    # Undirected adjacency of the retained edges; OR-ing with the transpose
    # covers both the symmetric and the non-symmetric case.
    adjacency = np.zeros((n_nodes, n_nodes), dtype=bool)
    adjacency[rows, cols] = True
    undirected = adjacency | adjacency.T

    n_clusters, node_cluster = connected_components(
        csr_matrix(undirected), directed=False
    )

    # Both endpoints of an edge share a component, so the row endpoint suffices.
    edge_cluster = node_cluster[rows]

    if stat_type == "intensity":
        cluster_stat = np.bincount(
            edge_cluster,
            weights=weights.astype(np.float64),
            minlength=n_clusters,
        )
    else:
        edge_counts = np.bincount(edge_cluster, minlength=n_clusters)
        cluster_stat = edge_counts.astype(np.float64)

    per_edge = cluster_stat[edge_cluster]
    cluster_map[rows, cols] = per_edge
    if symmetric:
        cluster_map[cols, rows] = per_edge
    return cluster_map
# =============================================================================
# cNBS scoring
# =============================================================================
def get_cnbs_score(
    t_stat: npt.NDArray[np.float64],
    net_labels: npt.NDArray[np.int_],
) -> npt.NDArray[np.float64]:
    """
    Compute the cNBS score: mean t-stat per subnetwork block.

    Every edge with a strictly positive (rounded) t-value is assigned to the
    block formed by the unordered pair of its endpoint networks. Each such
    edge then receives the mean of the positive t-values in its block.

    Parameters
    ----------
    t_stat : ndarray of shape (N, N)
        Non-negative t-statistic matrix (one tail, already separated) with a
        zero diagonal.
    net_labels : array-like of shape (N,)
        Network label for each node. Labels need not be contiguous; they are
        remapped internally to 0..K-1.

    Returns
    -------
    ndarray of shape (N, N)
        float64 matrix where each positive edge holds its block's mean
        t-stat. Entries that are not greater than 0 get score 0.

    Raises
    ------
    ValueError
        If ``_check_square`` rejects ``t_stat`` or ``net_labels`` is not a
        one-dimensional sequence of length N.
    """
    _check_square(t_stat)

    # Coerce before inspecting ndim/shape so that plain lists or tuples of
    # labels are accepted instead of failing with AttributeError.
    net_labels = np.asarray(net_labels)
    if net_labels.ndim != 1 or net_labels.shape[0] != t_stat.shape[0]:
        raise ValueError("net_labels must have shape (N,) matching t_stat.")

    # Round to avoid float precision issues at threshold boundaries
    t_stat = np.round(t_stat, decimals=10)
    scores = np.zeros_like(t_stat, dtype=np.float64)

    is_symm = _is_symmetric(t_stat)
    edge_rows, edge_cols, edge_weights = _get_edges(t_stat, 0.0, symmetric=is_symm)
    if edge_rows.size == 0:
        # Nothing to score; returning here also keeps a 0x0 input from
        # reaching the max() reduction below.
        return scores

    normalized_labels = _validate_net_labels(net_labels, t_stat.shape[0])
    n_networks = int(np.max(normalized_labels) + 1)

    block_ids = _compute_canonical_block_ids(
        edge_rows, edge_cols, normalized_labels, n_networks
    )
    n_blocks = n_networks * n_networks

    sums = np.bincount(
        block_ids, weights=edge_weights.astype(np.float64), minlength=n_blocks
    )
    counts = np.bincount(block_ids, minlength=n_blocks)
    # Blocks without any positive edge keep a mean of 0 (no division by zero).
    means = np.divide(sums, counts, out=np.zeros_like(sums), where=counts > 0)

    edge_scores = means[block_ids]
    scores[edge_rows, edge_cols] = edge_scores
    if is_symm:
        scores[edge_cols, edge_rows] = edge_scores
    return scores
# =============================================================================
# Reference NBS (comparison/testing only)
# =============================================================================

def _score_nbs_from_diffs(
    diffs: npt.NDArray[np.float64],
    threshold: float,
    stat_type: str,
) -> Dict[str, npt.NDArray[np.float64]]:
    """NBS-score both tails of the t-statistics derived from ``diffs``.

    ``compute_t_stat_diff`` is called on ``diffs``; its "positive" and
    "negative" entries are each passed through ``get_nbs_score`` and the two
    results are packaged with ``make_tail_result``.
    """
    from ._compat import make_tail_result
    from .pairwise_stats import compute_t_stat_diff

    tails = compute_t_stat_diff(diffs)
    positive_scores = get_nbs_score(
        tails["positive"], threshold=threshold, stat_type=stat_type
    )
    negative_scores = get_nbs_score(
        tails["negative"], threshold=threshold, stat_type=stat_type
    )
    return make_tail_result(positive_scores, negative_scores)


def _score_nbs_two_sample(
    group1: npt.NDArray[np.float64],
    group2: npt.NDArray[np.float64],
    test_type: str,
    threshold: float,
    stat_type: str,
) -> Dict[str, npt.NDArray[np.float64]]:
    """NBS-score both tails of the two-group t-statistics.

    ``compute_t_stat`` is called with both groups and ``test_type``; its
    "positive" and "negative" entries are each passed through
    ``get_nbs_score`` and the two results are packaged with
    ``make_tail_result``.
    """
    from ._compat import make_tail_result
    from .pairwise_stats import compute_t_stat

    tails = compute_t_stat(group1, group2, test_type=test_type)
    positive_scores = get_nbs_score(
        tails["positive"], threshold=threshold, stat_type=stat_type
    )
    negative_scores = get_nbs_score(
        tails["negative"], threshold=threshold, stat_type=stat_type
    )
    return make_tail_result(positive_scores, negative_scores)
def nbs_bct(
    group1: npt.NDArray[np.float64],
    group2: Optional[npt.NDArray[np.float64]] = None,
    threshold: float = DEFAULT_NBS_THRESHOLD,
    stat_type: str = DEFAULT_NBS_STAT,
    n_permutations: int = 100,
    test_type: str = "paired",
    use_mp: bool = True,
    random_state: Optional[int] = None,  # deprecated alias for `rng`
    rng: RngLike = None,
    n_processes: Optional[int] = None,
    **kwargs,
) -> Tuple[
    Dict[str, npt.NDArray[np.float64]],
    Dict[str, npt.NDArray[np.uint8]],
    Dict[str, npt.NDArray[np.float64]],
]:
    """
    Compute classical NBS with permutation testing.

    Returns p-values, adjacency matrices (thresholded t-stats), and null maxima.

    P-values use tie-inclusive max-statistic counting and the Phipson-Smyth +1
    correction, matching the main ``compute_p_val`` path.

    Parameters
    ----------
    group1 : ndarray
        First group's data. For one-sample tests it must be three-dimensional
        (subjects, N, N).
    group2 : ndarray, optional
        Second group's data. Required for 'paired' and 'two-sample' tests and
        must be None for 'one-sample' tests.
    threshold : float
        T-statistic threshold for edge inclusion.
    stat_type : {'extent', 'intensity'}
        Cluster statistic passed on to ``get_nbs_score``.
    n_permutations : int
        Forwarded to ``compute_null_dist``.
    test_type : {'paired', 'one-sample', 'two-sample'}
        Which test to run. For 'paired' the statistic is computed on
        ``group2 - group1``.
    use_mp, n_processes
        Forwarded to ``compute_null_dist``.
    random_state : int, optional
        Deprecated alias for ``rng``; a legacy warning is issued when it is
        given without ``rng``.
    rng : RngLike
        Seed source, resolved through ``resolve_seed``.
    **kwargs
        Forwarded only to ``compute_t_stat`` when building the two-sample
        adjacency masks.
        NOTE(review): they are not forwarded to the empirical scoring or the
        null-distribution calls - confirm this is intended.

    Returns
    -------
    p_values : dict
        Per tail, ``(count + 1) / (n_null + 1)`` where ``count`` is the number
        of null maxima greater than or equal to the edge's cluster statistic.
    adj_matrices : dict
        Per tail, a uint8 mask of raw t-stats strictly above ``threshold``,
        symmetrised from the upper triangle.
    max_null_dict : dict
        Per tail, the null maxima returned by ``compute_null_dist``.

    Raises
    ------
    ValueError
        If ``test_type`` is unknown, if ``group2`` is missing for a paired or
        two-sample test, if ``group2`` is given for a one-sample test, or if
        ``group1`` is not three-dimensional for a one-sample test.
    """
    from .pairwise_stats import compute_null_dist, compute_t_stat, compute_t_stat_diff

    if random_state is not None and rng is None:
        warn_legacy_random_state("random_state")
    random_state = resolve_seed(rng, legacy_random_state=random_state)

    # Per test type: empirical cluster scores, raw t-stats used for the
    # reported adjacency masks, and the scoring function for the null runs.
    if test_type == "paired":
        if group2 is None:
            # Fail early with a clear message instead of a TypeError from
            # ``None - group1``.
            raise ValueError("group2 is required for paired tests.")
        diffs = group2 - group1
        emp_t_dict = _score_nbs_from_diffs(diffs, threshold=threshold, stat_type=stat_type)
        raw_t = compute_t_stat_diff(diffs)
        t_func = _score_nbs_from_diffs
    elif test_type == "two-sample":
        if group2 is None:
            raise ValueError("group2 is required for two-sample tests.")
        emp_t_dict = _score_nbs_two_sample(
            group1,
            group2,
            test_type="two-sample",
            threshold=threshold,
            stat_type=stat_type,
        )
        raw_t = compute_t_stat(group1, group2, test_type="two-sample", **kwargs)
        t_func = _score_nbs_two_sample
    elif test_type == "one-sample":
        if group2 is not None:
            raise ValueError("group2 must be None for one-sample tests.")
        if group1.ndim != 3:
            raise ValueError("Dimensions of group 1 data should be: (subjects, N, N).")
        emp_t_dict = _score_nbs_from_diffs(group1, threshold=threshold, stat_type=stat_type)
        raw_t = compute_t_stat_diff(group1)
        t_func = _score_nbs_from_diffs
    else:
        raise ValueError("test_type must be one of {'paired', 'one-sample', 'two-sample'}.")

    # Build adjacency masks for reporting (raw thresholding on t-stats)
    adj_matrices: Dict[str, npt.NDArray[np.uint8]] = {}
    for key in raw_t:
        adj = (raw_t[key] > threshold).astype(np.uint8)
        if adj.shape[-1] == adj.shape[-2]:
            # Keep the upper triangle only, then mirror it.
            adj = np.triu(adj, 1)
            adj = adj + adj.T
        adj_matrices[key] = adj

    # Null distribution for max cluster statistic
    group2_for_null = group2 if test_type != "one-sample" else None
    max_null_dict = compute_null_dist(
        group1,
        group2_for_null,
        t_func,
        n_permutations=n_permutations,
        test_type=test_type,
        use_mp=use_mp,
        random_state=random_state,
        n_processes=n_processes,
        threshold=threshold,
        stat_type=stat_type,
    )

    # P-values via max-stat (cluster extent or intensity)
    p_values: Dict[str, npt.NDArray[np.float64]] = {}
    for key in emp_t_dict.keys():
        # Trailing axis lets each edge broadcast against all null maxima.
        emp_t = emp_t_dict[key][..., np.newaxis]
        null_dist = max_null_dict[key]
        count = np.sum(emp_t <= null_dist, axis=-1)
        p_values[key] = (count + 1.0) / (null_dist.shape[0] + 1.0)

    return p_values, adj_matrices, max_null_dict