Source code for quadsv.comparators.grid

"""
:class:`ComparatorGrid` — cross-sample pattern comparison on a list of
:class:`spatialdata.SpatialData` (regular rasterized bins, FFT
backend).
"""

from __future__ import annotations

import logging
import warnings
from collections.abc import Sequence
from typing import Any

import numpy as np
from joblib import Parallel, delayed
from tqdm.auto import tqdm

# Suppress known deprecation warnings from SpatialData dependencies BEFORE importing them.
warnings.filterwarnings("ignore", category=FutureWarning, message=".*legacy Dask DataFrame.*")
warnings.filterwarnings("ignore", category=UserWarning, message=".*pkg_resources is deprecated.*")

import spatialdata as sd

from quadsv.comparators.base import (
    _ComparatorBase,
    _validate_common,
)
from quadsv.comparators.multisample import compute_sample_spectrum

__all__ = ["ComparatorGrid"]

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# ComparatorGrid — SpatialData / regular bins via rasterize_bins
# ---------------------------------------------------------------------------


[docs] class ComparatorGrid(_ComparatorBase): """ Cross-sample pattern comparison on regular bins via FFT + SpatialData. Accepts a list of :class:`spatialdata.SpatialData` (one per sample). For each sample, :func:`spatialdata.rasterize_bins` turns the designated bin shape + table into a dense ``(n_genes, ny, nx)`` image, which is then fed to the batched 2D FFT. All samples are expected to share the same rasterization schema (``bins`` / ``table_name`` / ``col_key`` / ``row_key`` / ``value_key``) — this mirrors :class:`~quadsv.DetectorGrid`. Parameters ---------- samples : sequence of :class:`spatialdata.SpatialData` bins : str SpatialElement key for the bin shapes in each ``sdata``. table_name : str Table key in each ``sdata.tables``. col_key, row_key : str Column / row-index columns in the table's ``.obs``. value_key : str, optional Expression column in ``.obs``; defaults to ``None`` (rasterizes counts / presence directly off ``.X``). gene_names : sequence of str, optional If None, inferred from the first sample's table. All samples must share ``var_names``. fft_chunk_size : int, default 256 Genes per batched ``scipy.fft`` call on the rasterized block. Keeps transient memory bounded at ``O(ny · nx · chunk · 8 B)``. The raster itself is still built once per sample (full ``(n_genes, ny, nx)`` footprint is unavoidable on SpatialData). Other Parameters ---------------- feature_mode, n_radial_bins, fft_solver, workers, freq_edges, presence_threshold See :class:`_ComparatorBase`. Notes ----- The comparator carries no design / contrast state — supply the cross-sample contrast directly to :meth:`test_diff_freq` / :meth:`test_diff_expr`. A single fitted comparator can therefore serve any number of unrelated comparisons on the same spectra. """ def __init__( self, samples: Sequence[Any], *, bins: str, table_name: str, col_key: str, row_key: str, value_key: str | None = None, gene_names: Sequence[str] | None = None, feature_mode: str = "radial", n_radial_bins: int = 30, fft_solver: str = "rfft2", workers: int | None = None, spacing: tuple[float, float] | None = None, freq_edges: np.ndarray | None = None, presence_threshold: float = 0.0, fft_chunk_size: int = 256, ) -> None: fft_solver = _validate_common(feature_mode, fft_solver, presence_threshold) samples_list = list(samples) if len(samples_list) == 0: raise ValueError("samples must be a non-empty list.") for i, s in enumerate(samples_list): if not isinstance(s, sd.SpatialData): raise TypeError( f"sample {i} is {type(s).__name__}, expected spatialdata.SpatialData." ) resolved = _resolve_spatialdata_gene_names(samples_list, gene_names, table_name)
[docs] self.samples = samples_list
[docs] self.gene_names = list(resolved)
[docs] self.feature_mode = feature_mode
[docs] self.freq_edges = None if freq_edges is None else np.asarray(freq_edges, dtype=float)
# Private (internal-config) state. self._n_radial_bins = int(n_radial_bins) self._fft_solver = fft_solver self._workers = workers self._presence_threshold = float(presence_threshold) self._fft_chunk_size = max(1, int(fft_chunk_size)) self._spectrum_fft_solver = fft_solver self._bins = bins self._table_name = table_name self._col_key = col_key self._row_key = row_key self._value_key = value_key # Grid shape is determined per-sample inside compute_spectra by rasterize_bins # (the raster's .shape[-2:] carries it). Record the placeholder now # and fill in during _compute_spectra; spacing is always (1.0, 1.0) # because rasterize_bins outputs one pixel per bin — users can # override via the spacing kwarg when the bins encode a physical # pitch. self._spacing_override = None if spacing is None else (float(spacing[0]), float(spacing[1])) self._grid_shapes = [] # populated by _compute_spectra self._spacings = None # populated alongside # ------------------------------------------------------------------ def _rasterize_one(self, sdata: Any) -> np.ndarray: """Wrap :func:`spatialdata.rasterize_bins`. Returns a ``(n_genes, ny, nx)`` float array in :attr:`gene_names` order. """ from quadsv._rasterize import rasterize_table table = sdata.tables[self._table_name] img = rasterize_table( sdata, bins=self._bins, table_name=self._table_name, col_key=self._col_key, row_key=self._row_key, value_key=self._value_key, return_region_as_labels=False, ) arr = np.asarray(img.data if hasattr(img, "data") else img, dtype=np.float64) # Expected shape: (n_genes_in_table, ny, nx). Reindex gene axis to # self.gene_names (which was validated to match table var_names). if arr.ndim != 3: raise ValueError( f"rasterize_bins returned shape {arr.shape}, expected (n_genes, ny, nx)." ) table_names = list(table.var_names) if table_names == list(self.gene_names): return arr idx = np.asarray([table_names.index(g) for g in self.gene_names], dtype=int) return arr[idx] def _compute_spectra( self, n_jobs: int, progress: bool ) -> tuple[list[np.ndarray], np.ndarray, np.ndarray]: chunk = self._fft_chunk_size def _one(i: int, pbar: tqdm | None = None) -> tuple[np.ndarray, np.ndarray, np.ndarray]: raster = self._rasterize_one(self.samples[i]) n_genes, ny, nx = raster.shape # Lock grid shape + spacing from the first-seen rasterization. self._grid_shapes_local_i = (ny, nx) frac_nonzero = (raster != 0).reshape(n_genes, -1).mean(axis=1) presence_i = frac_nonzero >= self._presence_threshold dc = raster.mean(axis=(1, 2)) expected_kx = nx if self._spectrum_fft_solver == "fft2" else nx // 2 + 1 spec_stack = np.empty((n_genes, ny, expected_kx), dtype=np.float64) for start in range(0, n_genes, chunk): stop = min(start + chunk, n_genes) # Reuse the shared helper; it always mean-centres each gene # before the FFT. spec_chunk = compute_sample_spectrum( raster[start:stop], fft_solver=self._spectrum_fft_solver, workers=self._workers, return_dc=False, ) spec_stack[start:stop] = spec_chunk if pbar is not None: pbar.update(1) return spec_stack, dc, presence_i # Run the per-sample loop and collect the grid shapes (only known # after rasterize_bins returns). n_samples_total = len(self.samples) raw_2d: list[np.ndarray | None] = [None] * n_samples_total dc_list: list[np.ndarray | None] = [None] * n_samples_total pres_list: list[np.ndarray | None] = [None] * n_samples_total grids: list[tuple[int, int]] = [] run_sequential = progress or n_jobs == 1 if run_sequential: n_chunks_total = sum( int(np.ceil(len(self.gene_names) / chunk)) for _ in range(n_samples_total) ) pbar: tqdm | None = ( tqdm(total=n_chunks_total, desc="FFT spectra (per-gene chunks)") if progress else None ) for i in range(n_samples_total): if pbar is not None: pbar.set_postfix_str(f"sample {i + 1}/{n_samples_total}") r0, r1, r2 = _one(i, pbar=pbar) raw_2d[i] = r0 dc_list[i] = r1 pres_list[i] = r2 grids.append(self._grid_shapes_local_i) if pbar is not None: pbar.close() else: results = Parallel(n_jobs=n_jobs, prefer="threads")( delayed(_one)(i) for i in range(n_samples_total) ) for i, r in enumerate(results): raw_2d[i], dc_list[i], pres_list[i] = r # When running via joblib the `self._grid_shapes_local_i` # side-channel isn't reliable — infer from the returned spec. grids.append(raw_2d[i].shape[-2:]) # Record per-sample grids + spacings for downstream radial binning. self._grid_shapes = grids if self._spacing_override is not None: self._spacings = [self._spacing_override] * len(grids) else: self._spacings = [(1.0, 1.0)] * len(grids) del self._grid_shapes_local_i # tidy attr soup dc = np.stack([np.asarray(x) for x in dc_list], axis=0) presence = np.stack([np.asarray(x) for x in pres_list], axis=0) return [np.asarray(x) for x in raw_2d], dc, presence # ------------------------------------------------------------------ def _covariate_features_from_keys(self, keys: Sequence[str]) -> list[np.ndarray]: """Forward ``keys`` to :func:`spatialdata.rasterize_bins` as ``value_key`` and turn each per-sample raster into covariate features. ``keys`` may name any combination of ``.obs`` columns and ``var_names`` in the comparator's ``table_name`` table — ``rasterize_bins`` accepts both. Each sample's resulting ``(n_keys, ny, nx)`` raster is then funneled through the same spectrum + radial-binning pipeline as the gene panel via :meth:`_covariate_features_from_array`. """ from quadsv._rasterize import rasterize_table keys = list(keys) out: list[np.ndarray] = [] for i, sdata in enumerate(self.samples): img = rasterize_table( sdata, bins=self._bins, table_name=self._table_name, col_key=self._col_key, row_key=self._row_key, value_key=keys, return_region_as_labels=False, ) arr = np.asarray(img.data if hasattr(img, "data") else img, dtype=np.float64) if arr.ndim == 2: # Single key → rasterize_bins drops the leading axis. Re-insert. arr = arr[None, :, :] if arr.ndim != 3: raise ValueError( f"sample {i} covariate raster has shape {arr.shape}; " "expected (n_keys, ny, nx) from rasterize_bins." ) out.append(self._covariate_features_from_array(arr, sample_index=i)) return out
def _resolve_spatialdata_gene_names( samples: list[Any], gene_names: Sequence[str] | None, table_name: str, ) -> list[str]: first = samples[0] if table_name not in first.tables: raise KeyError( f"sample 0 has no table '{table_name}'; available: {list(first.tables.keys())}." ) if gene_names is None: gene_names = list(first.tables[table_name].var_names) for i, s in enumerate(samples): if table_name not in s.tables: raise KeyError(f"sample {i} has no table '{table_name}'.") tbl_names = list(s.tables[table_name].var_names) if tbl_names != list(gene_names): raise ValueError(f"sample {i}'s table has var_names that do not match the reference.") return list(gene_names)