"""
:class:`ComparatorIrregular` — cross-sample pattern comparison on a
list of :class:`anndata.AnnData` (irregular spots, NUFFT backend).
"""
from __future__ import annotations
import logging
from collections.abc import Sequence
from typing import Any
import anndata as _ad
import numpy as np
import scipy.sparse as sp
from joblib import Parallel, delayed
from tqdm.auto import tqdm
from quadsv.comparators.base import (
_ComparatorBase,
_validate_common,
)
from quadsv.comparators.multisample import radial_bin_spectrum
__all__ = ["ComparatorIrregular"]
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# ComparatorIrregular — AnnData / irregular spots
# ---------------------------------------------------------------------------
class ComparatorIrregular(_ComparatorBase):
    """
    Cross-sample pattern comparison on irregular spots via NUFFT.

    Accepts a list of :class:`anndata.AnnData` (one per sample). For each
    sample, the per-sample ``obsm[obsm_key]`` supplies the irregular
    ``(y, x)`` coordinates and ``.X`` (or ``.layers[layer]`` when set) is the
    expression matrix. Spectra are evaluated with a batched type-1 NUFFT
    (``finufft.nufft2d1``), densifying at most ``nufft_chunk_size``
    columns of a sparse ``.X`` at a time so the full slab is never
    materialized.

    Parameters
    ----------
    samples : sequence of :class:`anndata.AnnData`
        One AnnData per sample; must be non-empty.
    gene_names : sequence of str, optional
        If None, inferred from the first sample; every other sample must share
        the same ``var_names``.
    feature_mode : {'radial', '2d'}, default 'radial'
        ``'radial'`` radial-bins each 2-D spectrum; ``'2d'`` keeps a flattened
        low-frequency block instead.
    n_radial_bins : int, default 30
        Number of radial bins; in ``'2d'`` mode it also caps the side length
        of the retained low-frequency block.
    obsm_key : str, default 'spatial'
        Key of the ``(n, 2)`` coordinate array in each sample's ``.obsm``.
    layer : str, optional
        Layer to read expression from instead of ``.X``.
    unit_scales : sequence of float, optional
        Per-sample multiplier applied to coords before NUFFT (e.g. pixels→μm).
        Defaults to ``1.0`` for every sample.
    grid_shape, spacing : optional
        When both given, used for every sample. Otherwise each sample's
        k-grid is auto-inferred from coords via
        :func:`quadsv.kernels.nufft._infer_grid_from_coords`. Supplying only
        one of the two has no effect on the grid and logs a warning.
    freq_edges : np.ndarray, optional
        Explicit radial bin edges handed to the radial binning step.
    eps : float, default 1e-6
        NUFFT tolerance.
    presence_threshold : float, default 0.0
        Minimum fraction of non-zero spots for a gene to count as "observed"
        in a sample (feeds :attr:`presence_` and, transitively, the masked
        pattern test).
    nufft_chunk_size : int, default 64
        Number of genes per batched NUFFT call (clamped to at least 1).
        32–128 balances finufft's per-call overhead against the
        `(n_spots, chunk)` transient RAM.
    workers : int, optional
        Forwarded to per-sample FFTs used by :meth:`normalize_covariates`.

    Notes
    -----
    The comparator carries no design / contrast state — supply the
    cross-sample contrast directly to :meth:`test_diff_freq` /
    :meth:`test_diff_expr`. A single fitted comparator can therefore
    serve any number of unrelated comparisons on the same spectra.
    """

    def __init__(
        self,
        samples: Sequence[Any],
        gene_names: Sequence[str] | None = None,
        *,
        feature_mode: str = "radial",
        n_radial_bins: int = 30,
        obsm_key: str = "spatial",
        layer: str | None = None,
        unit_scales: Sequence[float] | None = None,
        grid_shape: tuple[int, int] | None = None,
        spacing: tuple[float, float] | None = None,
        freq_edges: np.ndarray | None = None,
        eps: float = 1e-6,
        presence_threshold: float = 0.0,
        nufft_chunk_size: int = 64,
        workers: int | None = None,
    ) -> None:
        """Validate the samples and cache per-sample coordinates and k-grids.

        Raises
        ------
        ValueError
            If ``samples`` is empty, ``unit_scales`` has the wrong length, a
            coordinate array is not ``(n, 2)``, or the samples' ``var_names``
            disagree with the reference gene list.
        TypeError
            If any sample is not an :class:`anndata.AnnData`.
        KeyError
            If a sample lacks ``obsm[obsm_key]`` or the requested ``layer``.
        """
        fft_solver = _validate_common(feature_mode, "fft2", presence_threshold)

        samples_list = list(samples)
        if not samples_list:
            raise ValueError("samples must be a non-empty list.")
        for i, s in enumerate(samples_list):
            if not isinstance(s, _ad.AnnData):
                raise TypeError(f"sample {i} is {type(s).__name__}, expected anndata.AnnData.")
        resolved = _resolve_anndata_gene_names(samples_list, gene_names, layer=layer)

        # Public attributes.
        # Per-sample AnnData objects, in the order supplied.
        self.samples = samples_list
        # Gene names shared by every sample.
        self.gene_names = list(resolved)
        # Feature layout selector: 'radial' or '2d'.
        self.feature_mode = feature_mode
        # Optional explicit radial bin edges (float array) or None.
        self.freq_edges = None if freq_edges is None else np.asarray(freq_edges, dtype=float)

        # Private (internal-config) state.
        self._n_radial_bins = int(n_radial_bins)
        self._fft_solver = fft_solver
        self._workers = workers
        self._presence_threshold = float(presence_threshold)
        self._nufft_chunk_size = max(1, int(nufft_chunk_size))
        # NUFFT always produces the full-2D (fft2) layout, so spectra are
        # binned with the "fft2" solver.
        self._spectrum_fft_solver = "fft2"
        self._layer = layer
        self._obsm_key = obsm_key
        self._nufft_eps = float(eps)

        # Per-sample coords / grids.
        from quadsv.kernels.nufft import _infer_grid_from_coords

        n_samples = len(samples_list)
        if unit_scales is None:
            unit_scales = [1.0] * n_samples
        if len(unit_scales) != n_samples:
            raise ValueError(
                f"unit_scales length {len(unit_scales)} does not match "
                f"n_samples={n_samples}."
            )
        self._unit_scales: list[float] = [float(s) for s in unit_scales]

        use_fixed_grid = grid_shape is not None and spacing is not None
        if not use_fixed_grid and (grid_shape is not None or spacing is not None):
            # A half-specified grid cannot be honoured; say so instead of
            # dropping the argument silently.
            logger.warning(
                "Only one of grid_shape / spacing was given; both are required "
                "to fix the k-grid, so the supplied value is ignored and each "
                "sample's grid is inferred from its coordinates."
            )

        coords_list: list[np.ndarray] = []
        grids: list[tuple[int, int]] = []
        spacings: list[tuple[float, float]] = []
        for i, ad_s in enumerate(samples_list):
            if obsm_key not in ad_s.obsm:
                raise KeyError(
                    f"sample {i} has no obsm['{obsm_key}']; "
                    f"available: {list(ad_s.obsm.keys())}."
                )
            c = np.asarray(ad_s.obsm[obsm_key], dtype=np.float64)
            if c.ndim != 2 or c.shape[1] != 2:
                raise ValueError(f"sample {i} obsm['{obsm_key}'] must be (n, 2), got {c.shape}.")
            coords_list.append(c)
            if use_fixed_grid:
                gs_i = (int(grid_shape[0]), int(grid_shape[1]))
                sp_i = (float(spacing[0]), float(spacing[1]))
            else:
                # Infer the grid from the coordinates in scaled units.
                gs_i, sp_i = _infer_grid_from_coords(c * self._unit_scales[i], oversample=2.0)
            grids.append(gs_i)
            spacings.append(sp_i)
        self._coords = coords_list
        self._grid_shapes = grids
        self._spacings = spacings

    # ------------------------------------------------------------------
    def _compute_spectra(  # noqa: C901
        self, n_jobs: int, progress: bool
    ) -> tuple[list[np.ndarray], np.ndarray, np.ndarray]:
        """Compute mean-centred NUFFT power spectra for every gene of every sample.

        Genes are processed in chunks of ``self._nufft_chunk_size`` columns;
        sparse inputs are densified one chunk at a time.

        Parameters
        ----------
        n_jobs : int
            Passed to :func:`_run_per_sample` to choose sequential vs. threaded
            execution across samples.
        progress : bool
            Whether to show a progress bar (one tick per gene chunk).

        Returns
        -------
        tuple
            ``(spectra, dc, presence)`` as assembled by :func:`_run_per_sample`:
            a per-sample list of ``(n_genes, ny, nx)`` arrays, the stacked
            per-gene means, and the stacked per-gene presence flags.
        """
        from quadsv.kernels.nufft import power_spectrum_2d_nufft

        chunk_size = self._nufft_chunk_size
        n_samples_total = len(self.samples)
        n_genes = len(self.gene_names)

        def _one(i: int, pbar: tqdm | None = None) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
            adata = self.samples[i]
            pts = self._coords[i]
            scale = self._unit_scales[i]
            grid_i = self._grid_shapes[i]
            spacing_i = self._spacings[i]
            X_src = adata.X if self._layer is None else adata.layers[self._layer]
            n_spots = X_src.shape[0]

            # Per-gene mean (DC) and non-zero counts, computed without
            # densifying a sparse matrix.
            if sp.issparse(X_src):
                dc = np.asarray(X_src.mean(axis=0)).ravel()
                nnz_per = np.asarray((X_src != 0).sum(axis=0)).ravel()
                X_csc = X_src.tocsc()  # column slicing is done on the CSC form
                X_dense = None
            else:
                X_dense = np.asarray(X_src, dtype=np.float64)
                dc = X_dense.mean(axis=0)
                nnz_per = (X_dense != 0).sum(axis=0)
                X_csc = None
            # max(..., 1) guards the division when a sample has no spots.
            presence_i = (nnz_per / max(n_spots, 1)) >= self._presence_threshold

            ny, nx = grid_i
            spec_stack = np.empty((n_genes, ny, nx), dtype=np.float64)
            for start in range(0, n_genes, chunk_size):
                stop = min(start + chunk_size, n_genes)
                cols = slice(start, stop)
                if X_csc is not None:
                    block = np.asarray(X_csc[:, cols].toarray(), dtype=np.float64)
                else:
                    # Copy so the in-place centering below never touches the input.
                    block = X_dense[:, cols].astype(np.float64, copy=True)
                # Per-gene mean centering: removes the DC bin and prevents
                # per-sample mean-shift leakage into low-frequency bins. The
                # raw DC scalars are preserved on ``self.dc_`` for the
                # complementary :meth:`test_diff_expr` path.
                block -= dc[None, cols]
                p_chunk = power_spectrum_2d_nufft(
                    pts,
                    block,
                    grid_shape=grid_i,
                    spacing=spacing_i,
                    unit_scale=scale,
                    eps=self._nufft_eps,
                    center_coords=True,
                )
                # Move the gene axis to the front before storing the chunk.
                spec_stack[start:stop] = np.moveaxis(p_chunk, -1, 0)
                if pbar is not None:
                    pbar.update(1)
            return spec_stack, dc, presence_i

        return _run_per_sample(
            _one,
            n_samples_total,
            n_chunks_per_sample=int(np.ceil(n_genes / chunk_size)),
            desc="NUFFT spectra (per-gene chunks)",
            n_jobs=n_jobs,
            progress=progress,
        )

    # ------------------------------------------------------------------
    def _covariate_features_from_keys(  # noqa: C901 — per-key obs/var dispatch + per-sample loop
        self, keys: Sequence[str]
    ) -> list[np.ndarray]:
        """Per-spot column lookup → per-sample covariate features.

        Each ``key`` is resolved against the first sample as either:

        - an ``adata.obs`` column (per-spot scalar — typical for
          deconvolution outputs, region labels, depth proxies); or
        - an entry in ``adata.var_names`` (treats that gene's
          per-spot expression as the covariate — useful for
          regressing out a housekeeping gene's spatial pattern).

        Resolution prefers ``obs`` when a name appears in both. Every
        subsequent sample must resolve each key to the same source as
        the first sample (i.e., a key is "obs everywhere" or "var
        everywhere") — anything else is treated as a schema mismatch.

        For each sample the resolved per-spot vectors are stacked into
        an ``(n_obs, n_covariates)`` block, mean-centred per column,
        and NUFFTed directly onto the sample's k-grid. In ``'radial'``
        mode the 2-D spectra are then radial-binned with the same edges
        as the gene panel; otherwise a low-frequency block is flattened.

        Returns
        -------
        list of np.ndarray
            One feature array per sample, with one row per covariate.

        Raises
        ------
        KeyError
            If a key is missing from both ``adata.obs.columns`` and
            ``adata.var_names`` in the first sample, or is missing from
            the resolved source in a later sample.
        ValueError
            If ``keys`` is empty, if an obs column cannot be cast to float
            (e.g., string categoricals — encode them first), or if a
            ``var_names`` key is duplicated within a sample.
        """
        from quadsv.kernels.nufft import power_spectrum_2d_nufft

        keys = list(keys)
        if not keys:
            # Without this guard np.column_stack fails with an opaque message.
            raise ValueError("keys must contain at least one covariate name.")

        # Classify each key once against the first sample; require all
        # later samples to agree on the source.
        first = self.samples[0]
        sources: dict[str, str] = {}
        for k in keys:
            in_obs = k in first.obs.columns
            in_var = k in first.var_names
            if not (in_obs or in_var):
                raise KeyError(
                    f"covariate key {k!r} is in neither obs.columns nor "
                    f"var_names of sample 0. Available obs (first 10): "
                    f"{list(first.obs.columns)[:10]}; available var_names "
                    f"(first 10): {list(first.var_names)[:10]}."
                )
            sources[k] = "obs" if in_obs else "var"

        out: list[np.ndarray] = []
        for i, adata in enumerate(self.samples):
            cols: list[np.ndarray] = []
            for k in keys:
                if sources[k] == "obs":
                    if k not in adata.obs.columns:
                        raise KeyError(
                            f"sample {i} resolves covariate {k!r} differently from "
                            f"sample 0 (sample 0 → obs, sample {i} → not in obs)."
                        )
                    try:
                        cols.append(np.asarray(adata.obs[k], dtype=np.float64))
                    except (TypeError, ValueError) as exc:
                        raise ValueError(
                            f"sample {i} obs[{k!r}] cannot be cast to float "
                            f"({type(exc).__name__}); encode categoricals before passing."
                        ) from exc
                else:  # var_names path
                    if k not in adata.var_names:
                        raise KeyError(
                            f"sample {i} resolves covariate {k!r} differently from "
                            f"sample 0 (sample 0 → var_names, sample {i} → not in "
                            "var_names)."
                        )
                    idx = adata.var_names.get_loc(k)
                    if not isinstance(idx, (int, np.integer)):
                        # A duplicated label makes get_loc return a slice or
                        # boolean mask, which would select several columns.
                        raise ValueError(
                            f"sample {i} var_names contains {k!r} more than once; "
                            "make var_names unique before using it as a covariate."
                        )
                    X_src = adata.X if self._layer is None else adata.layers[self._layer]
                    col = X_src[:, idx]
                    if sp.issparse(col):
                        col = col.toarray()
                    cols.append(np.asarray(col, dtype=np.float64).ravel())

            block = np.column_stack(cols)  # (n_obs, n_cov)
            # Mean-centre each column, matching the gene panel's per-column DC removal.
            block = block - block.mean(axis=0, keepdims=True)
            p = power_spectrum_2d_nufft(
                self._coords[i],
                block,
                grid_shape=self._grid_shapes[i],
                spacing=self._spacings[i],
                unit_scale=self._unit_scales[i],
                eps=self._nufft_eps,
                center_coords=True,
            )
            # power_spectrum_2d_nufft returns (ny, nx, M) for multi-column values.
            cov_2d = np.moveaxis(p, -1, 0)  # (n_cov, ny, nx)
            ny, nx = self._grid_shapes[i]
            if self.feature_mode == "radial":
                cov_feat = radial_bin_spectrum(
                    cov_2d,
                    grid_shape=(ny, nx),
                    n_bins=self._n_radial_bins,
                    fft_solver=self._spectrum_fft_solver,
                    spacing=self._spacings[i],
                    edges=self.freq_edges,
                )
            else:
                k_max = min(self._n_radial_bins, ny // 2, nx // 2)
                # Slices clamp to the axis length, so one expression covers
                # both the "wider than k_max" and "already narrow" cases.
                low = cov_2d[:, :k_max, :k_max]
                cov_feat = low.reshape(low.shape[0], -1)
            out.append(cov_feat)
        return out
# ---------------------------------------------------------------------------
# shared per-sample runner
# ---------------------------------------------------------------------------
def _run_per_sample(
    worker: Any,
    n_samples_total: int,
    *,
    n_chunks_per_sample: int,
    desc: str,
    n_jobs: int,
    progress: bool,
) -> tuple[list[np.ndarray], np.ndarray, np.ndarray]:
    """Invoke ``worker(i, pbar)`` for each sample with a shared tqdm bar.

    Used by :class:`ComparatorIrregular` where each sample is split into
    multiple per-gene-chunk tqdm ticks.

    Parameters
    ----------
    worker : callable
        Called as ``worker(i, pbar)`` and expected to return a 3-tuple
        ``(spectra, dc, presence)`` for sample ``i``.
    n_samples_total : int
        Number of samples; ``worker`` is called for ``i`` in ``range`` of it.
    n_chunks_per_sample : int
        Progress ticks each sample contributes; only sizes the bar total.
    desc : str
        Progress-bar description.
    n_jobs : int
        ``1`` runs sequentially; other values go through joblib threads
        unless ``progress`` is set.
    progress : bool
        Show a progress bar. Forces sequential execution so the single bar
        is updated from one thread.

    Returns
    -------
    tuple
        ``(spectra_list, dc, presence)`` — the first items as a list of
        arrays, and the second / third items stacked along a new axis 0.
    """
    per_sample: list[tuple[Any, Any, Any]] = []
    if progress or n_jobs == 1:
        pbar = tqdm(total=n_samples_total * n_chunks_per_sample, desc=desc) if progress else None
        try:
            for i in range(n_samples_total):
                if pbar is not None:
                    pbar.set_postfix_str(f"sample {i + 1}/{n_samples_total}")
                spec_i, dc_i, pres_i = worker(i, pbar)
                per_sample.append((spec_i, dc_i, pres_i))
        finally:
            # Close the bar even when a worker raises, so a failed run does
            # not leave a dangling progress bar behind.
            if pbar is not None:
                pbar.close()
    else:
        results = Parallel(n_jobs=n_jobs, prefer="threads")(
            delayed(worker)(i, None) for i in range(n_samples_total)
        )
        for spec_i, dc_i, pres_i in results:
            per_sample.append((spec_i, dc_i, pres_i))

    dc = np.stack([np.asarray(r[1]) for r in per_sample], axis=0)
    presence = np.stack([np.asarray(r[2]) for r in per_sample], axis=0)
    return [np.asarray(r[0]) for r in per_sample], dc, presence
# ---------------------------------------------------------------------------
# Gene-name resolution helpers
# ---------------------------------------------------------------------------
def _resolve_anndata_gene_names(
    samples: list[Any],
    gene_names: Sequence[str] | None,
    *,
    layer: str | None,
) -> list[str]:
    """Return the shared gene list after checking every sample against it.

    Parameters
    ----------
    samples : list
        Objects exposing ``var_names`` and ``layers``.
    gene_names : sequence of str or None
        Reference gene order. When None, the first sample's ``var_names``
        is used as the reference.
    layer : str or None
        When given, every sample must contain this key in ``layers``.

    Returns
    -------
    list of str
        A fresh list holding the reference gene names.

    Raises
    ------
    ValueError
        If any sample's ``var_names`` differs from the reference.
    KeyError
        If ``layer`` is set and missing from any sample.
    """
    # Materialise the reference exactly once: this avoids rebuilding the list
    # for every sample and keeps one-shot iterables from being exhausted
    # after the first comparison.
    reference = list(samples[0].var_names) if gene_names is None else list(gene_names)
    for i, s in enumerate(samples):
        if list(s.var_names) != reference:
            raise ValueError(
                f"sample {i} has var_names that do not match the reference "
                "(all AnnData samples must share the same gene axis)."
            )
        if layer is not None and layer not in s.layers:
            raise KeyError(f"sample {i} is missing layer '{layer}'.")
    return reference