from __future__ import annotations
import warnings
# Suppress known deprecation warnings from SpatialData dependencies BEFORE importing anything else.
warnings.filterwarnings("ignore", category=FutureWarning, message=".*legacy Dask DataFrame.*")
warnings.filterwarnings("ignore", category=UserWarning, message=".*pkg_resources is deprecated.*")
import logging
from typing import Any
import numpy as np
import pandas as pd
import scipy.fft
import spatialdata as sd
from joblib import Parallel, delayed
from scipy.stats import norm
from tqdm import tqdm
from quadsv.detectors.base import Detector
from quadsv.kernels.fft import FFTKernel
from quadsv.statistics import spatial_q_test
from quadsv.utils import _apply_bh_correction
__all__ = ["DetectorGrid"]
logger = logging.getLogger(__name__)
def _qstat_worker_fft(
raster_layer, feature_batch: list[str], kernel: FFTKernel, return_pval: bool
) -> list[dict]:
"""
Worker function for parallel Q-statistic computation with FFT kernels.
Parameters
----------
raster_layer : xarray.DataArray
Rasterized data layer (lazy Dask array).
feature_batch : List[str]
Feature names to process in this batch.
kernel : FFTKernel
Pre-constructed FFT kernel object.
return_pval : bool
Whether to compute p-values.
Returns
-------
List[dict]
Results for each feature: {'Feature': str, 'Q': float, 'P_value': float, 'Z_score': float}
"""
results = []
# Load data to memory for batch: shape (M, ny, nx)
data_chunk = raster_layer.sel(c=feature_batch).values
# Transpose to (ny, nx, M) for kernel
data_chunk_transposed = np.moveaxis(data_chunk, 0, -1)
# Compute statistics
if return_pval:
stats, pvals = spatial_q_test(data_chunk_transposed, kernel, return_pval=True)
else:
stats = spatial_q_test(data_chunk_transposed, kernel, return_pval=False)
pvals = None
# Ensure array semantics for iteration (handle 0-d arrays)
stats = np.atleast_1d(np.asarray(stats))
if pvals is not None:
pvals = np.atleast_1d(np.asarray(pvals, dtype=object))
else:
pvals = np.array([None] * len(feature_batch), dtype=object)
# Null parameters
mu = kernel.trace()
sigma = np.sqrt(2.0 * kernel.square_trace())
z_scores = (stats - mu) / sigma if sigma > 1e-12 else np.zeros_like(stats)
# Format batch results
for j, gene in enumerate(feature_batch):
results.append(
{
"Feature": gene,
"Q": float(stats[j]),
"P_value": float(pvals[j]) if pvals[j] is not None else None,
"Z_score": float(z_scores[j]),
}
)
return results
[docs]
class DetectorGrid(Detector):
r"""
Detect spatial patterns on **regular grids** (SpatialData bins) with
FFT-accelerated kernel tests.
Univariate (Q-test) and bivariate (R-test) kernel-based spatial statistics
on rasterized :class:`spatialdata.SpatialData` bins.
Workflow
--------
1. **Construct** with kernel method + kernel hyperparameters / grid controls.
2. **Setup** with :meth:`setup_data` passing the :class:`spatialdata.SpatialData`
plus the bin / table / col / row keys. Setup rasterizes the table and
builds the :class:`~quadsv.FFTKernel` at the resulting grid shape.
3. **Compute** with :meth:`compute_qstat` / :meth:`compute_rstat`.
Parameters
----------
kernel_method : str, default ``'car'``
One of ``'gaussian'``, ``'matern'``, ``'moran'``, ``'graph_laplacian'``,
``'car'``.
**kernel_params
Kernel hyperparameters plus grid controls (``spacing``, ``topology``,
``fft_solver``, ``workers``). See :class:`~quadsv.FFTKernel`.
Attributes
----------
sdata : :class:`spatialdata.SpatialData` or None
Input container set by :meth:`setup_data`.
min_count : int or None
Feature count threshold; set by :meth:`setup_data`.
kernel\_ : :class:`~quadsv.FFTKernel` or None
Built in :meth:`setup_data` once the grid shape is known.
kernel_method\_, kernel_params\_, n
See :class:`Detector`.
Examples
--------
>>> det = DetectorGrid(kernel_method='car', rho=0.8)
>>> det.setup_data(sdata, bins='grid', table_name='table',
... col_key='col_idx', row_key='row_idx') # doctest: +SKIP
>>> q = det.compute_qstat(features=['Gene_1', 'Gene_2']) # doctest: +SKIP
"""
def __init__(self, kernel_method: str = "car", **kernel_params: Any) -> None:
super().__init__(kernel_method, **kernel_params)
# Data-state attrs (populated by setup_data):
[docs]
self.sdata: sd.SpatialData | None = None
"""Reference to the input :class:`spatialdata.SpatialData`, set by :meth:`setup_data`."""
[docs]
self.min_count: int | None = None
"""Minimum total count per feature applied in :meth:`setup_data`."""
# Rasterization keys (populated by setup_data):
self._img_key: str | None = None
self._table_name: str | None = None
self._bins: str | None = None
self._col_key: str | None = None
self._row_key: str | None = None
def _merge_kernel_defaults(self, method: str, user_params: dict) -> dict:
"""Merge grid-level + per-method FFTKernel defaults with user overrides."""
general_defaults = {
"spacing": (1.0, 1.0),
"topology": "square",
"fft_solver": "fft2",
"workers": None,
}
method_defaults = {
"gaussian": {"bandwidth": 2.0},
"matern": {"nu": 1.5, "bandwidth": 2.0},
"moran": {"neighbor_degree": 1},
"graph_laplacian": {"neighbor_degree": 1},
"car": {"rho": 0.9, "neighbor_degree": 1},
}
defaults = {**general_defaults, **method_defaults.get(method, {})}
for key, value in user_params.items():
if key not in defaults:
raise ValueError(
f"Unknown parameter {key!r} for method {method!r}. "
f"Allowed: {sorted(defaults)}."
)
defaults[key] = value
return defaults
[docs]
def setup_data(
self,
sdata: sd.SpatialData,
*,
bins: str,
table_name: str,
col_key: str,
row_key: str,
value_key: str | None = None,
min_count: int | None = None,
) -> DetectorGrid:
"""
Attach ``sdata``, rasterize the chosen bins table, and build the FFTKernel.
Parameters
----------
sdata : :class:`spatialdata.SpatialData`
Input container.
bins : str
Name of the SpatialElement (Shape) defining the grid-like bins.
table_name : str
Name of the table annotating the SpatialElement in ``sdata.tables``.
col_key, row_key : str
``.obs`` columns holding integer column / row indices for the bins.
value_key : str, optional
Value column in ``.obs`` to rasterize. ``None`` uses counts / presence.
min_count : int, optional
Minimum total count for a feature to pass filtering. ``None`` disables.
Returns
-------
self : DetectorGrid
"""
self.sdata = sdata
self.min_count = min_count
self._bins = bins
self._table_name = table_name
self._col_key = col_key
self._row_key = row_key
# Rasterize once, store the resulting image key.
self._img_key = self._rasterize_bins(
bins=bins,
table_name=table_name,
col_key=col_key,
row_key=row_key,
value_key=value_key,
)
raster_layer = self.sdata[self._img_key]
_, ny, nx = raster_layer.shape
self.n = ny * nx
logger.info(
"Building FFTKernel (%s) for grid shape (%d, %d)...",
self.kernel_method_,
ny,
nx,
)
self.kernel_ = FFTKernel(shape=(ny, nx), method=self.kernel_method_, **self.kernel_params_)
self._data_ready = True
return self
def _rasterize_bins(
self,
bins: str,
table_name: str,
col_key: str,
row_key: str,
value_key: str | None = None,
return_region_as_labels: bool = False,
) -> str:
"""
Wrapper for spatialdata.rasterize_bins with format validation.
Converts sparse table into a rasterized (grid) image. Ensures CSC sparse
format for efficient processing and stores result in sdata.images.
Parameters
----------
bins : str
Name of the SpatialElement (Shape) which defines the grid-like bins.
table_name : str
Name of the table annotating the SpatialElement in sdata.tables.
col_key : str
Column in sdata[table_name].obs containing column indices (integers) for bins.
row_key : str
Column in sdata[table_name].obs containing row indices (integers) for bins.
value_key : str, optional
Column in sdata[table_name].obs to use as pixel values. If None, uses counts/presence.
return_region_as_labels : bool, default False
If True, returns bin region masks as integer labels. If False, returns aggregated values.
Returns
-------
img_key : str
Key under which the rasterized image is stored in sdata.images.
Format: 'rasterized_{table_name}'.
Notes
-----
This method ensures the underlying matrix is in CSC sparse format for efficient
column-wise operations required by rasterize_bins.
"""
from quadsv._rasterize import rasterize_table
img_key = f"rasterized_{table_name}"
logger.info("Rasterizing %s into %s...", table_name, img_key)
rasterized = rasterize_table(
self.sdata,
bins=bins,
table_name=table_name,
col_key=col_key,
row_key=row_key,
value_key=value_key,
return_region_as_labels=return_region_as_labels,
)
self.sdata[img_key] = rasterized
return img_key
def _filter_features(self, features, table_name):
"""Helper to validate and filter features based on min_count."""
# Extract all features from table
all_features = self.sdata.tables[table_name].var_names.to_list()
if features is None:
# Use all features
valid = all_features
else:
valid = [f for f in features if f in all_features]
if not valid:
raise ValueError("No valid features found.")
valid = np.array(valid)
# Apply min_count filter
if self.min_count is not None:
counts = np.asarray(self.sdata.tables[table_name][:, valid].X.sum(axis=0)).ravel()
valid = valid[counts >= self.min_count]
if len(valid) == 0:
raise ValueError(f"No features passed min_count={self.min_count}")
return valid
# ------------------------------------------------------------------
# Auto-tuning helpers
# ------------------------------------------------------------------
def _auto_chunk_size(self, budget_bytes: int = 2 * (1 << 30)) -> int:
"""Thin wrapper around :func:`quadsv.statistics.auto_chunk_size`.
Delegates to the shared helper so the FFT chunk-size policy
(cache sweet spot of 32, per-feature ``~24·n`` bytes) is kept
in one place — see :func:`~quadsv.statistics.auto_chunk_size`
for the full model.
"""
from quadsv.statistics import auto_chunk_size
return auto_chunk_size(self.kernel_, budget_bytes=budget_bytes)
def _auto_schedule(
self, n_batches: int, n_jobs: int | str, workers: int | str | None
) -> tuple[int, int | None]:
"""Balance joblib ``n_jobs`` and scipy.fft ``workers`` to the CPU count.
Both ``n_jobs`` and ``workers`` parallelize, and stacking them thrashes
cores. ``'auto'`` policy:
- If ``n_batches >= cpu_count``: parallelize across batches
(``n_jobs=cpu_count``), let each FFT call be single-threaded (``workers=1``).
- Otherwise (few batches, big grids): cap ``n_jobs`` at ``n_batches`` and
give each worker ``cpu_count / n_jobs`` FFT threads.
Concrete integers passed by the caller are respected.
"""
import os
cpu = os.cpu_count() or 1
if n_jobs == "auto" or n_jobs == -1:
if n_batches >= cpu:
n_jobs_resolved = cpu
workers_resolved = 1 if workers == "auto" else workers
else:
n_jobs_resolved = max(1, n_batches)
workers_resolved = (
max(1, cpu // max(1, n_batches)) if workers == "auto" else workers
)
else:
n_jobs_resolved = int(n_jobs)
workers_resolved = (
max(1, cpu // max(1, n_jobs_resolved)) if workers == "auto" else workers
)
return n_jobs_resolved, workers_resolved
[docs]
def compute_qstat(
self,
features: list[str] | None = None,
n_jobs: int | str = "auto",
workers: int | str | None = "auto",
return_pval: bool = True,
chunk_size: int | str = "auto",
show_progress: bool = True,
) -> pd.DataFrame:
"""
Compute the spatial Q-statistic across features in parallel.
Requires :meth:`setup_data` to have been called; rasterization and
kernel construction happen there. This method pulls the rasterized
feature tensor from :attr:`sdata` and runs per-feature FFT Q-tests.
Parameters
----------
features : list of str, optional
Feature names to analyze. ``None`` uses all features that pass
the ``min_count`` filter from :meth:`setup_data`.
n_jobs : int or ``'auto'``, default ``'auto'``
Joblib workers over feature batches. ``'auto'`` balances against
``workers`` — see :meth:`_auto_schedule`. ``-1`` is also accepted
and behaves like ``'auto'``.
workers : int, ``'auto'``, or None, default ``'auto'``
Threads for scipy.fft inside each worker. ``'auto'`` co-balances with
``n_jobs``; ``None`` defers to scipy's default.
return_pval : bool, default True
Whether to compute p-values + Benjamini–Hochberg–adjusted p-values.
chunk_size : int or ``'auto'``, default ``'auto'``
Features per worker batch. ``'auto'`` resolves to ``~256 MB / (ny·nx·24)``
via :meth:`_auto_chunk_size` and clips to ``[16, 1024]``.
show_progress : bool, default True
Show a tqdm progress bar over worker chunks.
Returns
-------
pandas.DataFrame
Indexed by feature. Columns: ``Q``, ``Z_score``, and (if
``return_pval=True``) ``P_value``, ``P_adj``. Sorted by ``Q`` desc.
"""
self._require_setup()
raster_layer = self.sdata[self._img_key]
features = self._filter_features(features, self._table_name)
if isinstance(chunk_size, str):
if chunk_size != "auto":
raise ValueError(f"chunk_size must be 'auto' or int, got {chunk_size!r}.")
chunk_size = self._auto_chunk_size()
feature_batches = [
features[i : i + chunk_size] for i in range(0, len(features), chunk_size)
]
n_jobs, workers = self._auto_schedule(len(feature_batches), n_jobs, workers)
# Let the FFT path pick up the balanced workers setting.
self.kernel_.workers = workers
logger.info(
"Q-test on %d features — %d batches, n_jobs=%d, workers=%s, chunk_size=%d",
len(features),
len(feature_batches),
n_jobs,
workers,
chunk_size,
)
batch_iter = feature_batches
if show_progress:
batch_iter = tqdm(
feature_batches,
desc=f"Q ({self.kernel_method_})",
bar_format="{l_bar}{bar:30}{r_bar}{bar:-30b}",
)
results_list = Parallel(n_jobs=n_jobs, prefer="threads")(
delayed(_qstat_worker_fft)(raster_layer, batch, self.kernel_, return_pval)
for batch in batch_iter
)
# 7. Flatten results
results = [item for sublist in results_list for item in sublist]
# 5. Compile results
df = pd.DataFrame(results).set_index("Feature")
if not return_pval:
df = df.drop(columns=["P_value"])
# 6. Multiple testing correction (Benjamini-Hochberg) in place
if return_pval:
_apply_bh_correction(df)
return df.sort_values(by="Q", ascending=False)
def _compute_batch_spectral_embeddings(self, raster_layer, feature_names):
"""
Helper: Loads data, standardizes, and computes weighted spectral components.
Returns matrix of shape (n_features, n_spectral_components).
"""
# 1. Load Data (IO Bound)
# Shape: (N_features, Y, X)
data = raster_layer.sel(c=feature_names).values
n_feats, ny, nx = data.shape
# 2. Standardize (In-place to save memory)
# Mean/Std per feature
means = np.mean(data, axis=(1, 2), keepdims=True)
stds = np.std(data, axis=(1, 2), keepdims=True, ddof=1)
# Avoid div by zero
stds[stds < 1e-12] = 1.0
data = (data - means) / stds
# 3. FFT and Spectral Weighting
# Use selected FFT solver
if self.kernel_.fft_solver == "fft2":
freq_data = scipy.fft.fft2(data, axes=(1, 2), workers=self.kernel_.workers)
rfft_spectrum = self.kernel_.eigenvalues().reshape(ny, nx)
else:
freq_data = scipy.fft.rfft2(data, axes=(1, 2), workers=self.kernel_.workers)
rfft_spectrum = self.kernel_.eigenvalues().reshape(ny, nx // 2 + 1)
weights = np.sqrt(np.abs(rfft_spectrum))
# Broadcast multiply
weighted_freq = freq_data * weights[None, :, :]
# 4. Flatten spatial dimensions for matrix multiplication
# Result: (N_features, n_freq_bins)
return weighted_freq.reshape(n_feats, -1)
[docs]
def compute_rstat( # noqa: C901
self,
features_x: list[str] | None = None,
features_y: list[str] | None = None,
return_pval: bool = True,
chunk_size: int | str = "auto",
workers: int | str | None = "auto",
show_progress: bool = True,
) -> pd.DataFrame:
"""
Compute the bivariate spatial R-statistic across feature pairs.
Requires :meth:`setup_data` to have been called.
Parameters
----------
features_x : list of str, optional
Features for the X variable. If ``None`` and ``features_y`` is
``None``, uses all features (symmetric pairwise mode).
features_y : list of str, optional
Features for the Y variable. If ``None``, pairs are drawn from
``features_x`` (symmetric, upper-triangular). If provided, returns
all X × Y pairs (bipartite).
return_pval : bool, default True
Whether to compute p-values + Benjamini–Hochberg–adjusted p-values.
chunk_size : int or ``'auto'``, default ``'auto'``
Y-features per batch (reuses the pre-computed ``K @ Y`` block).
``'auto'`` targets ~256 MB per embedding batch via
:meth:`_auto_chunk_size`.
workers : int, ``'auto'``, or None, default ``'auto'``
Threads for scipy.fft inside the embedding pass. ``'auto'`` gives
every FFT all CPU cores (the R-test loop is sequential over X/Y
chunk pairs so there is no joblib contention).
show_progress : bool, default True
Show a tqdm progress bar over X chunks.
Returns
-------
pandas.DataFrame
Columns ``Feature_1``, ``Feature_2``, ``R``, ``Z_score`` and (if
``return_pval=True``) ``P_value``, ``P_adj``. Sorted by ``R`` desc.
"""
import gc # Garbage collector
self._require_setup()
raster_layer = self.sdata[self._img_key]
table_name = self._table_name
_, ny, nx = raster_layer.shape
if isinstance(chunk_size, str):
if chunk_size != "auto":
raise ValueError(f"chunk_size must be 'auto' or int, got {chunk_size!r}.")
chunk_size = self._auto_chunk_size()
# compute_rstat is sequential across X-chunks, so give every FFT the
# full CPU budget by default.
if workers == "auto":
import os
workers = os.cpu_count() or 1
self.kernel_.workers = workers
# 2. Resolve Features
all_features = raster_layer.coords["c"].values
if features_x is None and features_y is None:
features_x = all_features
features_y = None
mode = "symmetric"
elif features_x is not None and features_y is None:
mode = "symmetric"
else:
mode = "bipartite"
features_x = self._filter_features(features_x, table_name)
if mode == "bipartite":
features_y = self._filter_features(features_y, table_name)
# 3. Prepare Batches
chunks_x = np.array_split(features_x, np.ceil(len(features_x) / chunk_size))
if mode == "bipartite":
chunks_y = np.array_split(features_y, np.ceil(len(features_y) / chunk_size))
else:
chunks_y = chunks_x
logger.info(
"Computing R-stats: %d x %d matrix.",
len(features_x),
len(features_y) if features_y is not None else len(features_x),
)
logger.info("Processing in %d chunks of size ~%d...", len(chunks_x), chunk_size)
sigma = np.sqrt(self.kernel_.square_trace())
results_list = []
# 5. Block Iteration
x_iter = tqdm(chunks_x, desc="Processing X chunks") if show_progress else chunks_x
for i, batch_x_names in enumerate(x_iter):
# Load Embeddings X (High Memory Usage)
embeddings_x = self._compute_batch_spectral_embeddings(raster_layer, batch_x_names)
start_j = i if mode == "symmetric" else 0
for j in range(start_j, len(chunks_y)):
batch_y_names = chunks_y[j]
# Load Embeddings Y
if mode == "symmetric" and i == j:
embeddings_y = embeddings_x # Reference, no copy
else:
embeddings_y = self._compute_batch_spectral_embeddings(
raster_layer, batch_y_names
)
# --- CROSS-BATCH CORRELATION ---
# R_block shape: (chunk_size, chunk_size) -> Very small
# This step reduces millions of pixels down to a simple correlation number
R_block = np.matmul(embeddings_x, embeddings_y.conj().T).real
# Normalize by grid size (rfft2 is unnormalized)
R_block /= nx * ny
# --- Format Results ---
# Create meshgrid of indices for this block
n_x = len(batch_x_names)
n_y = len(batch_y_names)
# Create coordinate grids
# If symmetric and diagonal block, we only want upper triangle
if mode == "symmetric" and i == j:
# Get upper triangle indices
r_idx, c_idx = np.triu_indices(n_x)
# Extract values
r_vals = R_block[r_idx, c_idx]
feat_1 = batch_x_names[r_idx]
feat_2 = batch_y_names[c_idx]
else:
# Full block
# Flatten the block
r_vals = R_block.ravel()
# Repeat X names for rows, Tile Y names for cols
feat_1 = np.repeat(batch_x_names, n_y)
feat_2 = np.tile(batch_y_names, n_x)
# Store block results
batch_df = pd.DataFrame({"Feature_1": feat_1, "Feature_2": feat_2, "R": r_vals})
if return_pval:
if sigma > 1e-12:
z_scores = r_vals / sigma
p_vals = 2 * norm.sf(np.abs(z_scores))
else:
z_scores = np.zeros_like(r_vals)
p_vals = np.ones_like(r_vals)
batch_df["Z_score"] = z_scores
batch_df["P_value"] = p_vals
results_list.append(batch_df)
# Explicit cleanup for Y
if not (mode == "symmetric" and i == j):
del embeddings_y
# Explicit cleanup for X
del embeddings_x
gc.collect() # Force memory release before next big load
# 6. Finalize
if not results_list:
return pd.DataFrame(columns=["Feature_1", "Feature_2", "R", "P_value", "P_adj"])
final_df = pd.concat(results_list, ignore_index=True)
if return_pval and not final_df.empty:
_apply_bh_correction(final_df)
return final_df.sort_values(by="R", key=abs, ascending=False)