from __future__ import annotations
import warnings
from typing import Any
import numpy as np
import scipy.fft
import scipy.sparse as sp
from scipy.special import gamma, kv
from scipy.stats import chi2, norm
from quadsv.kernels.base import Kernel
__all__ = ["FFTKernel", "power_spectrum_2d"]
[docs]
def power_spectrum_2d(
x: np.ndarray,
fft_solver: str = "fft2",
workers: int | None = None,
) -> np.ndarray:
"""
Compute the 2D power spectrum :math:`|\\hat{x}(k)|^2` of one or more grid signals.
The result is *translation-invariant*: shifting the input image leaves the power
spectrum unchanged. This makes the spectrum a natural alignment-free representation
of a spatial pattern. Use :func:`quadsv.comparators.multisample.radial_bin_spectrum` to
further reduce the 2D spectrum to a 1D radial-binned vector that is also
rotation-invariant.
Parameters
----------
x : np.ndarray
Grid signal of shape ``(ny, nx)`` for a single feature, or ``(ny, nx, M)``
for ``M`` stacked features sharing the grid.
fft_solver : {'fft2', 'rfft2'}, default 'fft2'
FFT routine. ``'rfft2'`` returns the half-spectrum of shape
``(ny, nx // 2 + 1)`` and roughly halves memory.
workers : int, optional
Number of parallel workers forwarded to :mod:`scipy.fft`. ``None`` uses the
SciPy default.
Returns
-------
np.ndarray
Power spectrum. Shape ``(ny, n_kx)`` if input was 2D, or ``(ny, n_kx, M)``
if input was 3D, where ``n_kx = nx`` for ``fft2`` and ``nx // 2 + 1`` for
``rfft2``. Layout matches the corresponding :mod:`scipy.fft` routine
(zero-frequency bin at ``[0, 0]``, no fftshift applied).
Raises
------
ValueError
If ``fft_solver`` is not one of ``'fft2'`` or ``'rfft2'``.
Examples
--------
>>> img = np.random.randn(32, 32)
>>> P = power_spectrum_2d(img, fft_solver='rfft2')
>>> P.shape
(32, 17)
"""
if fft_solver not in ("fft2", "rfft2"):
raise ValueError(f"fft_solver must be 'fft2' or 'rfft2', got '{fft_solver}'")
squeeze = x.ndim == 2
if squeeze:
x = x[..., np.newaxis]
if fft_solver == "fft2":
x_hat = scipy.fft.fft2(x, axes=(0, 1), workers=workers)
else:
x_hat = scipy.fft.rfft2(x, axes=(0, 1), workers=workers)
power = np.abs(x_hat) ** 2
if squeeze:
power = power[..., 0]
return power
[docs]
class FFTKernel(Kernel):
"""
FFT-accelerated spatial kernel for dense grid data.
Operates on evenly-spaced grid data (raster data) with spectral decomposition
via FFT under periodic (torus) boundary conditions.
Attributes
----------
ny, nx : int
Grid dimensions (number of rows and columns).
n_grid : int
Total number of grid points (``ny * nx``).
topology : {'square', 'hex'}
Grid topology. ``'hex'`` mirrors 10x Visium hexagonal layouts.
method : str
Kernel method (``'gaussian'``, ``'matern'``, ``'moran'``, ``'graph_laplacian'``,
``'car'``).
params : dict
Resolved kernel parameters (e.g. ``bandwidth``, ``nu``, ``neighbor_degree``,
``rho``) after defaults are merged with user overrides.
fft_solver : {'fft2', 'rfft2'}
FFT routine in use. ``'rfft2'`` stores roughly half the spectrum.
n_rfft : int
Length of the flattened spectrum: ``ny * nx`` for ``fft2`` and
``ny * (nx // 2 + 1)`` for ``rfft2``.
workers : int or None
Number of parallel workers forwarded to :mod:`scipy.fft`.
spectrum : np.ndarray
Flattened (row-major) eigenvalues of the kernel matrix, shape ``(n_rfft,)``.
Eagerly computed in ``__init__``. See :meth:`eigenvalues` for a sorted /
full-FFT-layout accessor.
"""
_available_kernels = ["gaussian", "matern", "moran", "graph_laplacian", "car"]
def __init__( # noqa: C901
self,
shape: tuple[int, int],
spacing: tuple[float, float] = (1.0, 1.0),
topology: str = "square",
method: str = "matern",
workers: int | None = None,
fft_solver: str = "fft2",
*,
centering: bool = True,
**kwargs,
) -> None:
"""
Initialize FFT-accelerated spatial kernel for grid data.
Parameters
----------
shape : tuple of int
Grid dimensions (ny, nx).
spacing : tuple of float, default (1.0, 1.0)
Physical distance between pixels (dy, dx).
topology : {'square', 'hex'}, default 'square'
Grid topology. 'hex' is for Visium-like hexagonal layouts.
method : str, default 'matern'
Kernel method: 'gaussian', 'matern', 'moran', 'graph_laplacian', 'car'.
workers : Optional[int], default None
Number of parallel workers for fft computations.
fft_solver : {'fft2', 'rfft2'}, default 'fft2'
FFT solver to use. 'fft2' (full FFT) or 'rfft2' (real FFT, ~50% memory).
Default is 'fft2' for better compatibility and robustness on most architectures.
**kwargs : dict
Kernel parameters (bandwidth, nu, neighbor_degree, rho).
Examples
--------
>>> kernel = FFTKernel((64, 64), method='gaussian', bandwidth=2.0)
>>> kernel = FFTKernel((64, 64), topology='hex', method='matern')
"""
super().__init__(centering=centering)
ny, nx = shape
if ny < 2 or nx < 2:
raise ValueError(f"Grid dimensions must be >= 2, got ({ny}, {nx})")
"""Number of grid rows."""
"""Number of grid columns."""
self._dy, self._dx = spacing
[docs]
self.n_grid: int = self.ny * self.nx
"""Total number of grid points (``ny * nx``)."""
[docs]
self.n: int = self.n_grid
"""Alias for ``n_grid`` to satisfy the :class:`~quadsv.kernels.Kernel` interface."""
# FFT solver selection
if fft_solver not in ("fft2", "rfft2"):
raise ValueError(f"fft_solver must be 'fft2' or 'rfft2', got '{fft_solver}'")
[docs]
self.fft_solver: str = fft_solver
"""FFT routine in use (``'fft2'`` or ``'rfft2'``)."""
[docs]
self.n_rfft: int = (
self.ny * self.nx if fft_solver == "fft2" else self.ny * (self.nx // 2 + 1)
)
"""Length of the flattened spectrum buffer (``ny*nx`` for ``fft2``, ``ny*(nx//2+1)`` for ``rfft2``)."""
# Sanity Checks
if topology not in ("square", "hex"):
raise ValueError(f"topology must be 'square' or 'hex', got '{topology}'")
if method not in self._available_kernels:
raise ValueError(f"method must be one of {self._available_kernels}, got '{method}'")
[docs]
self.topology: str = topology
"""Grid topology (``'square'`` or ``'hex'``)."""
[docs]
self.method: str = method
"""Kernel method name."""
# Update kernel parameters from defaults. Graph kernels accept an
# additional ``k_neighbors`` as a convenience (k-NN semantic): it's
# converted to the closest ``neighbor_degree`` (FFT-ring semantic)
# based on the grid topology — see _k_neighbors_to_degree below.
params = self._get_default_params(method).copy()
k_neighbors_user = None
if kwargs:
for key, value in kwargs.items():
if key == "k_neighbors" and method in ("moran", "graph_laplacian", "car"):
k_neighbors_user = value
elif key in params:
params[key] = value
else:
raise ValueError(f"Unknown parameter '{key}' for method '{method}'")
[docs]
self.params: dict = params
"""Resolved kernel parameters after defaults are merged with user overrides."""
[docs]
self.workers: int | None = workers
"""Number of parallel workers forwarded to :mod:`scipy.fft`, or ``None`` for the library default."""
# 1. Precompute Distances
# For Periodic: Distances wrap around (min(d, L-d)).
if self.topology == "hex":
self._min_dist_sq = self._precompute_hex_torus()
else:
self._min_dist_sq = self._precompute_square_dists()
# Resolve k_neighbors → neighbor_degree now that the distance grid is built.
if k_neighbors_user is not None:
if "neighbor_degree" in kwargs:
raise ValueError(
"Pass either 'k_neighbors' (k-NN semantic) or "
"'neighbor_degree' (FFT-ring semantic), not both."
)
self.params["neighbor_degree"] = self._k_neighbors_to_degree(k_neighbors_user)
# 2. Precompute Kernel spectrum
[docs]
self.spectrum: np.ndarray = self._compute_eigenvalues()
"""Flattened (row-major) eigenvalues of the kernel matrix, shape ``(n_rfft,)``."""
def _unique_ring_distances(self) -> np.ndarray:
"""Tolerance-grouped unique squared distances on the grid.
Groups numerically-close values into a single "ring" so hex and other
irrational-coordinate topologies report physical shells consistently.
Returns sorted ascending, starting with 0 (self). Internal helper.
"""
flat = np.sort(self._min_dist_sq.ravel())
tol = 1e-6 * max(1.0, float(flat[-1]))
# Take the first element of each tolerance-gap-separated group.
diffs = np.diff(flat)
keep = np.concatenate([[True], diffs > tol])
return flat[keep]
def _k_neighbors_to_degree(self, k_neighbors: int) -> int:
"""Translate a k-NN-style ``k_neighbors`` into an FFT-ring ``neighbor_degree``.
Returns the smallest ``neighbor_degree`` whose cumulative count of
grid cells (excluding self) is ≥ ``k_neighbors``. Topology-aware via
:meth:`_unique_ring_distances`:
- Square: k=4 → 1 (N/S/E/W), k=8 → 2 (+diagonals), k=12 → 3.
- Hex: k=6 → 1, k=12 → 2, k=18 → 3.
"""
if k_neighbors < 1:
raise ValueError(f"k_neighbors must be ≥ 1, got {k_neighbors}")
unique_dists = self._unique_ring_distances()
tol = 1e-6 * max(1.0, float(unique_dists[-1]))
for deg_order in range(1, len(unique_dists)):
cutoff_sq = unique_dists[deg_order]
count = int((self._min_dist_sq <= cutoff_sq + tol).sum()) - 1
if count >= k_neighbors:
return deg_order
return len(unique_dists) - 1
def _format_params(self) -> str:
"""Format kernel params safely without dumping large arrays/matrices."""
if not self.params:
return "None"
parts = []
for k, v in self.params.items():
try:
if isinstance(v, np.ndarray):
parts.append(f"{k}=array(shape={v.shape}, dtype={v.dtype})")
elif sp.issparse(v):
parts.append(f"{k}=sparse(shape={v.shape}, nnz={v.nnz})")
else:
parts.append(f"{k}={v}")
except Exception:
parts.append(f"{k}=?")
return ", ".join(parts)
def __repr__(self) -> str:
"""
Return a detailed, machine-readable representation of the FFTKernel.
Returns
-------
str
String representation in angle-bracket format.
"""
spectrum_info = (
f"spectrum shape={self.spectrum.shape}"
if self.spectrum is not None
else "spectrum=None"
)
return (
f"<FFTKernel method={self.method} shape=({self.ny}, {self.nx}) topology={self.topology} "
f"fft_solver={self.fft_solver} {spectrum_info} params={{ {self._format_params()} }}>"
)
def __str__(self) -> str:
"""
Return a human-friendly, multi-line representation of the FFTKernel.
Returns
-------
str
Multi-line string summary.
"""
lines = [
"FFTKernel",
f"- Method: {self.method}",
f"- Grid shape: ({self.ny}, {self.nx})",
f"- Topology: {self.topology}",
f"- Spacing: ({self._dy}, {self._dx})",
f"- FFT Solver: {self.fft_solver}",
]
if self.spectrum is not None:
lines.append(
f"- Spectrum: shape={self.spectrum.shape}, min={np.min(self.spectrum):.4g}, max={np.max(self.spectrum):.4g}"
)
else:
lines.append("- Spectrum: None")
lines.append(f"- Params: {self._format_params()}")
return "\n".join(lines)
def _get_default_params(self, method: str) -> dict[str, Any]:
"""
Returns default parameters for specific kernel methods.
Parameters
----------
method : str
Kernel method name. Should be one of _available_kernels.
Returns
-------
dict[str, Any]
Method defaults: bandwidth (gaussian/matern), nu (matern), neighbor_degree (moran/graph_laplacian/car), rho (car).
"""
method_defaults = {
"gaussian": {"bandwidth": 2.0},
"matern": {"nu": 1.5, "bandwidth": 2.0},
"moran": {"neighbor_degree": 1},
"graph_laplacian": {"neighbor_degree": 1},
"car": {"rho": 0.9, "neighbor_degree": 1},
}
return method_defaults.get(method, {})
def _precompute_square_dists(self):
"""Computes wrap-around torus distances from (0,0) to (y,x)."""
y = np.arange(self.ny) * self._dy
x = np.arange(self.nx) * self._dx
# Wrap-around distance for periodic boundaries
y = np.minimum(y, (self.ny * self._dy) - y)
x = np.minimum(x, (self.nx * self._dx) - x)
yy, xx = np.meshgrid(y, x, indexing="ij")
return yy**2 + xx**2
def _precompute_hex_torus(self):
"""Squared torus distances on a hexagonal grid (Visium convention).
Spot ``(r, c)`` lies at physical ``(y, x) = (r * sqrt(3)/2, c + 0.5 * (r%2))``
in units of the center-to-center horizontal step — i.e., odd rows are shifted
half a step in +x, matching the 10x Visium ``array_row`` / ``array_col``
layout.
Returns an ``(ny, nx)`` array consistent with the ``(ny, nx)`` signal shape
expected by :meth:`xtKx`. (The previous implementation returned ``(nx, ny)``
and scrambled the spectrum for non-square grids, silently breaking anisotropic
signals on any real Visium slide — Visium is never square. Tests covered only
square hex grids so the bug was invisible. Fixed.)
Periodicity in the y direction is well-defined only when ``ny`` is even (so
the row-parity shift is preserved under wrap-around); callers feeding odd
``ny`` will get a near-periodic but slightly off torus and a warning is
emitted.
"""
if self.ny % 2 != 0:
warnings.warn(
f"Hex topology expects an even number of rows (ny); got ny={self.ny}. "
"Periodic boundary conditions are approximate for odd ny.",
UserWarning,
stacklevel=2,
)
r = np.arange(self.ny) # row index, first (ny) axis
c = np.arange(self.nx) # col index, second (nx) axis
rr, cc = np.meshgrid(r, c, indexing="ij") # both shape (ny, nx)
y_phys = rr * (np.sqrt(3) / 2.0)
x_phys = cc + 0.5 * (rr % 2)
coords_grid = np.stack([y_phys, x_phys], axis=-1) # (ny, nx, 2)
# Torus periods: width in x is nx, height in y is ny * sqrt(3)/2.
P_y = np.array([self.ny * (np.sqrt(3) / 2.0), 0.0])
P_x = np.array([0.0, float(self.nx)])
min_d2 = np.full((self.ny, self.nx), np.inf)
for k in (-1, 0, 1):
for m in (-1, 0, 1):
shift = k * P_y + m * P_x
shifted = coords_grid + shift.reshape(1, 1, 2)
d2 = np.sum(shifted**2, axis=-1)
min_d2 = np.minimum(min_d2, d2)
return min_d2
def _compute_eigenvalues(self): # noqa: C901
"""Spectral decomposition of the kernel using fft2 or rfft2.
Returns eigenvalues in the selected FFT layout.
"""
# --- Continuous Kernels ---
if self.method == "gaussian":
bw = self.params["bandwidth"]
K_img = np.exp(-0.5 * (self._min_dist_sq / bw**2))
if self.fft_solver == "fft2":
spectrum_2d = scipy.fft.fft2(K_img, workers=self.workers)
else:
spectrum_2d = scipy.fft.rfft2(K_img, workers=self.workers)
return np.real(spectrum_2d).ravel()
elif self.method == "matern":
bw = self.params["bandwidth"]
nu = self.params["nu"]
d = np.sqrt(self._min_dist_sq)
mask_zero = d == 0
d[mask_zero] = 1.0 # dummy value, overwritten below
factor = (np.sqrt(2 * nu) * d) / bw
K_img = (2 ** (1 - nu) / gamma(nu)) * (factor**nu) * kv(nu, factor)
K_img[mask_zero] = 1.0 # correct limit: K(x, x) = 1
if self.fft_solver == "fft2":
spectrum_2d = scipy.fft.fft2(K_img, workers=self.workers)
else:
spectrum_2d = scipy.fft.rfft2(K_img, workers=self.workers)
return np.real(spectrum_2d).ravel()
# --- Graph-based Kernels (Moran / Graph Laplacian / CAR) ---
elif self.method in ["moran", "graph_laplacian", "car"]:
degree_order = self.params["neighbor_degree"]
# Tolerance-based ring grouping: hex distances like 1.0 arise from
# sqrt(3)/2 products and split into several numerical clusters
# (e.g. 0.9999998 and 1.0000003) that are *physically the same
# shell*. Without this, degree_order=1 on hex returns only 2 cells
# instead of the full 6-neighbour ring.
unique_dists = self._unique_ring_distances()
if degree_order < len(unique_dists):
cutoff_sq = unique_dists[degree_order]
else:
cutoff_sq = unique_dists[-1]
# Inclusive cutoff with tolerance so we catch the whole ring.
tol = 1e-6 * max(1.0, float(unique_dists[-1]))
W_img = (self._min_dist_sq <= cutoff_sq + tol).astype(float)
W_img[0, 0] = 0.0
# Row-Normalization Factor
# For Periodic: Exact constant degree.
degree = np.sum(W_img)
if degree == 0:
return np.ones(self.n_grid)
# Compute Spectrum of Normalized W
if self.fft_solver == "fft2":
spectrum_2d = scipy.fft.fft2(W_img, workers=self.workers)
else:
spectrum_2d = scipy.fft.rfft2(W_img, workers=self.workers)
lam_W = np.real(spectrum_2d).ravel() / degree
if self.method == "moran":
return lam_W
elif self.method == "graph_laplacian":
return 1.0 - lam_W
elif self.method == "car":
rho = self.params["rho"]
# Cap rho to prevent singularity if rho is too close to 1
if rho >= 1.0:
warnings.warn(
f"rho={rho} >= 1.0 causes singularity in CAR kernel; clamping to 0.99",
UserWarning,
stacklevel=2,
)
rho = 0.99
return 1.0 / (1.0 - rho * lam_W)
else:
raise ValueError("Unknown method")
[docs]
def xtKx(self, x: np.ndarray) -> float | np.ndarray:
"""
Compute the quadratic form x^T K x efficiently using FFT.
Uses Parseval's theorem to compute the result in frequency domain
for O(n log n) complexity instead of O(n²).
Parameters
----------
x : np.ndarray
Input data tensor. Shape (ny, nx) for single feature or (ny, nx, M) for M features.
Returns
-------
float or np.ndarray
Quadratic form value(s). Scalar if input was 2D, shape (M,) if input was 3D.
"""
if x.ndim == 2:
x = x[..., np.newaxis]
ny, nx, M = x.shape
if ny != self.ny or nx != self.nx:
raise ValueError(
f"Data shape ({ny}, {nx}) does not match kernel ({self.ny}, {self.nx})"
)
# HKH quadratic form on z-scored input equals raw K on the centered
# input; subtracting per-feature mean is cheap on the grid.
if self.centering:
x = x - x.mean(axis=(0, 1), keepdims=True)
# Transform using selected FFT solver via the shared power-spectrum helper.
x_power = power_spectrum_2d(x, fft_solver=self.fft_solver, workers=self.workers)
if self.fft_solver == "fft2":
# Reshape spectrum for full fft2: (ny, nx, 1)
lam = self.spectrum.reshape(self.ny, self.nx, 1)
# Weighted Sum (Parseval's Theorem)
weighted_power = np.sum(x_power * lam, axis=(0, 1))
else:
# Reshape spectrum for rfft2: (ny, nx//2+1, 1)
lam = self.spectrum.reshape(self.ny, self.nx // 2 + 1, 1)
# Weighted Sum (Parseval's Theorem) with correction for rfft2
weighted = x_power * lam
weighted_power = 2.0 * np.sum(weighted, axis=(0, 1))
# Correction: Subtract the first column (fx=0) once
# because we added it twice in the line above, but it only exists once.
weighted_power -= np.sum(weighted[:, 0, :], axis=0)
# Correction: If width is even, the last column is Nyquist (fx=N/2).
# It is also unique (real-valued in full spectrum), so subtract it once.
if nx % 2 == 0:
weighted_power -= np.sum(weighted[:, -1, :], axis=0)
# FFT is unnormalized: Parseval requires 1/n normalization
Q = weighted_power / (ny * nx)
# Unwrap if M=1
return Q.item() if M == 1 else Q.ravel()
def _ones_stats(self) -> tuple[float, float]:
"""Return ``(s1, s2) = (𝟏ᵀ K 𝟏, ‖K·𝟏‖²)`` analytically from the DC mode.
On a torus the constant vector ``𝟏`` is the ``k = (0, 0)`` Fourier
mode, so ``K · 𝟏 = λ₀ · 𝟏`` where ``λ₀ = spectrum[0]`` (DC).
``s1 = λ₀ · n_grid``, ``s2 = λ₀² · n_grid``. Zero FFT work.
"""
lam0 = float(self.spectrum.ravel()[0])
return lam0 * self.n_grid, (lam0**2) * self.n_grid
[docs]
def Kx(self, x: np.ndarray) -> np.ndarray:
"""
Apply the kernel operator to ``x`` via FFT in O(n log n).
Implemented as ``K x = F^{-1}(λ · F(x))`` where ``λ`` is the full
eigenvalue spectrum on the torus. The result is returned on the
spatial grid (not the feature-axis first layout used by
:class:`NUFFTKernel`); callers that want a quadratic / bilinear form
should prefer :meth:`xtKx` / :meth:`xtKy`, which avoid the inverse
FFT by using Parseval's theorem.
Parameters
----------
x : np.ndarray
Grid signal of shape ``(ny, nx)`` or ``(ny, nx, M)``.
Returns
-------
np.ndarray
``K @ x`` with the same shape as ``x``.
"""
x = np.asarray(x)
squeeze = x.ndim == 2
if squeeze:
x = x[..., np.newaxis]
ny, nx, M = x.shape
if ny != self.ny or nx != self.nx:
raise ValueError(
f"Data shape ({ny}, {nx}) does not match kernel ({self.ny}, {self.nx})"
)
# HKH x = H · (K · (H x)). Subtract per-feature spatial mean both
# before and after the FFT round-trip. Equivalent to zeroing the
# DC coefficient directly in the frequency domain.
if self.centering:
x = x - x.mean(axis=(0, 1), keepdims=True)
if self.fft_solver == "fft2":
x_hat = scipy.fft.fft2(x, axes=(0, 1), workers=self.workers)
lam = self.spectrum.reshape(ny, nx, 1)
out = np.real(scipy.fft.ifft2(lam * x_hat, axes=(0, 1), workers=self.workers))
else:
x_hat = scipy.fft.rfft2(x, axes=(0, 1), workers=self.workers)
lam = self.spectrum.reshape(ny, nx // 2 + 1, 1)
out = scipy.fft.irfft2(lam * x_hat, s=(ny, nx), axes=(0, 1), workers=self.workers)
if self.centering:
out = out - out.mean(axis=(0, 1), keepdims=True)
return out[..., 0] if squeeze else out
[docs]
def xtKy(self, x: np.ndarray, y: np.ndarray) -> float | np.ndarray:
"""
Bilinear form ``x^T K y`` on the grid via Parseval's theorem.
Parameters
----------
x, y : np.ndarray
Grid signals of shape ``(ny, nx)`` or ``(ny, nx, M)``. Both must
have the same shape.
Returns
-------
float or np.ndarray
Scalar if inputs are 2D; shape ``(M,)`` if 3D.
"""
x = np.asarray(x)
y = np.asarray(y)
if x.shape != y.shape:
raise ValueError(f"x and y must share shape; got {x.shape} vs {y.shape}.")
squeeze = x.ndim == 2
if squeeze:
x = x[..., np.newaxis]
y = y[..., np.newaxis]
ny, nx, _ = x.shape
# x^T HKH y = (H x)^T K (H y); both sides need centering.
if self.centering:
x = x - x.mean(axis=(0, 1), keepdims=True)
y = y - y.mean(axis=(0, 1), keepdims=True)
R_sum = _spectral_cross_product(x, y, self, ny, nx)
R = R_sum / (ny * nx)
if squeeze:
return float(R.item()) if R.size == 1 else R.ravel()
return R.ravel()
[docs]
def eigenvalues(self, k: int | None = None, return_full_layout: bool = False) -> np.ndarray:
"""
Eigenvalues of the kernel matrix.
When ``self.centering`` is True (default) the ``k=(0, 0)`` DC
component is zeroed before returning — this is exactly the
spectrum of ``HKH`` on a torus, since the constant vector ``𝟏``
is the DC Fourier mode. Set ``centering=False`` at construction
to recover the raw ``K`` spectrum.
Parameters
----------
k : int, optional
Number of largest eigenvalues to return. If None, returns all.
return_full_layout : bool, default False
Only for ``fft_solver='rfft2'``. If True, returns eigenvalues
in full FFT layout (ny, nx) flattened.
"""
if self.spectrum is None:
self.spectrum = self._compute_eigenvalues()
# Resolve the raw spectrum array (respecting rfft2 vs fft2).
if (self.fft_solver == "fft2") or (not return_full_layout):
spec = self.spectrum
else: # fft_solver == 'rfft2' and return_full_layout=True
# Convert rfft2 layout to full FFT layout.
full_fft = np.zeros((self.ny, self.nx), dtype=self.spectrum.dtype)
rfft_size = self.nx // 2 + 1
full_fft[:, :rfft_size] = self.spectrum.reshape(self.ny, rfft_size)
for i in range(self.ny):
for j in range(1, rfft_size - 1):
full_fft[i, self.nx - j] = full_fft[i, j].conj()
spec = full_fft.ravel()
if self.centering:
# Drop the constant-mode eigenvalue (DC entry).
spec = spec.copy()
spec[0] = 0.0
if k is None:
return spec
idx = np.argsort(-spec)[:k]
return spec[idx]
[docs]
def trace(self) -> float:
"""``trace(K)`` (raw) or ``trace(HKH)`` (centered).
Closed-form ``Σ_k λ(k)`` — FFT diagonalizes ``K`` in the
Fourier basis, so the trace is an ``O(n)`` sum over the
spectrum. No stochastic path.
"""
return float(np.sum(self.eigenvalues(return_full_layout=True)))
[docs]
def square_trace(self) -> float:
"""``trace(K²)`` (raw) or ``trace((HKH)²)`` (centered).
Closed-form ``Σ_k λ(k)²`` — FFT diagonalization gives the
spectrum directly.
"""
return float(np.sum(self.eigenvalues(return_full_layout=True) ** 2))
def _q_test_fft( # noqa: C901
Xn: np.ndarray,
kernel: FFTKernel,
null_params: dict | None = None,
return_pval: bool = True,
is_standardized: bool = False,
) -> float | np.ndarray | tuple[float, float] | tuple[np.ndarray, np.ndarray]:
"""
FFT-accelerated spatial Q-test for grid data.
Tests whether a spatial variable exhibits significant clustering or dispersion
using FFT-based spectral decomposition. Parseval's theorem reduces the
quadratic form to an elementwise spectral weighting.
Parameters
----------
Xn : np.ndarray
Input data tensor. Shape (ny, nx) for single feature or (ny, nx, M) for M features.
Order follows kernel dimensions. Will be automatically reshaped to 3D if 2D.
kernel : FFTKernel
Pre-constructed FFT kernel object for grid data.
null_params : dict, optional
Pre-computed null distribution parameters from
:func:`quadsv.statistics.compute_null_params`. When supplied, the
cached ``eigenvalues`` / ``mean_Q`` / ``var_Q`` entries are reused
in the p-value stage to avoid recomputing the spectrum on every
call — useful when running the same kernel against many features
(e.g., in :class:`quadsv.DetectorGrid`). If None, the
spectrum and moments are computed on the fly.
return_pval : bool, default True
If True, returns (Q, pval) tuple; if False, returns Q only.
is_standardized : bool, default False
If True, skips Z-score standardization internally. Otherwise standardizes
per-feature (mean 0, std 1) across spatial dimensions.
Returns
-------
Q : float or np.ndarray
Test statistic. Scalar if input was 2D; array of shape (M,) if 3D.
pval : float or np.ndarray, optional
Tail probability under null hypothesis. Only returned if return_pval=True.
Uses Liu's method for most kernels; Normal approximation for Moran's I.
Raises
------
ValueError
If ``Xn`` spatial dimensions don't match kernel shape ``(ny, nx)``.
Notes
-----
Under H₀: data is spatially independent.
Under H₁: mean-shift present.
Computationally: ``Q = zᵀ K z`` where ``z`` is standardized data.
Uses FFT via Parseval's theorem to compute :math:`Q = \\sum_{i,j} \\lambda_{i,j} Z^2_{i,j}`
in O(n' log n') time instead of O(n'³) dense methods.
For Moran's I kernel (which has negative eigenvalues), uses Normal approximation
based on asymptotic theory. For other kernels, uses Liu's chi-squared mixture approximation.
Examples
--------
>>> ny, nx = 32, 32
>>> kernel = FFTKernel((ny, nx), method='gaussian', bandwidth=1.0)
>>> data = np.random.randn(ny, nx)
>>> Q, pval = spatial_q_test(data, kernel)
"""
Xn = np.asarray(Xn).astype(float)
if Xn.ndim == 2:
Xn = Xn[..., np.newaxis]
ny, nx, M = Xn.shape
if ny != kernel.ny or nx != kernel.nx:
raise ValueError(
f"Data shape ({ny}, {nx}) does not match kernel ({kernel.ny}, {kernel.nx})"
)
# 1. Standardization (Z-score across spatial dimensions)
if is_standardized:
z = Xn
else:
# Mean/Std per feature slice
means = np.mean(Xn, axis=(0, 1), keepdims=True)
stds = np.std(Xn, axis=(0, 1), keepdims=True, ddof=1)
# Handle constant features (std=0)
# Create result array
z = np.zeros_like(Xn)
# Mask where std > 0 (shape 1,1,M broadcastable)
valid = stds > 1e-12
# Safe division
np.divide(Xn - means, stds, out=z, where=valid)
# 2. Compute Q statistic: z^T K z
# Helper returns (M,) array or scalar if input was 2D
Q = kernel.xtKx(z)
if not return_pval:
return Q
# 3. P-value approximation. Dispatch on the user-selected null method
# (``null_params['method']``) mirroring the MatrixKernel path in
# :func:`quadsv.statistics.spatial_q_test`: any of 'clt' / 'welch' / 'liu'.
# Default: CLT for Moran (indefinite K → Welch/Liu degenerate),
# Liu for everything else. When `null_params` is supplied the caller's
# cached moments are reused so we don't retraverse the spectrum per feature.
Q_arr = np.atleast_1d(Q).astype(float).ravel()
if null_params is not None and "method" in null_params:
null_approx = str(null_params["method"])
else:
null_approx = "clt" if kernel.method == "moran" else "liu"
if kernel.method == "moran" and null_approx != "clt":
raise ValueError(
f"Moran's I kernel is indefinite; only null_method='clt' is "
f"supported for the Q-test. Got method={null_approx!r}."
)
def _get_mean_var() -> tuple[float, float]:
if null_params is not None and "mean_Q" in null_params and "var_Q" in null_params:
return float(null_params["mean_Q"]), float(null_params["var_Q"])
# Fall back to compute_null_params so we get H-centered moments
# with the finite-n ratio correction — raw trace(K), 2·trace(K²)
# would inflate the null variance (see compute_null_params docstring).
from quadsv.statistics import compute_null_params
p = compute_null_params(kernel, method="clt")
return float(p["mean_Q"]), float(p["var_Q"])
if null_approx == "clt":
mean_Q, var_Q = _get_mean_var()
sigma = float(np.sqrt(var_Q))
if sigma <= 1e-12:
pvals = np.ones_like(Q_arr)
else:
z_scores = (Q_arr - mean_Q) / sigma
pvals = chi2.sf(z_scores**2, df=1)
elif null_approx == "welch":
if null_params is not None and "scale_g" in null_params and "df_h" in null_params:
g = float(null_params["scale_g"])
h = float(null_params["df_h"])
else:
mean_Q, var_Q = _get_mean_var()
if mean_Q <= 0 or var_Q <= 0:
pvals = np.ones_like(Q_arr)
g = h = None
else:
g = var_Q / (2.0 * mean_Q)
h = 2.0 * mean_Q**2 / var_Q
if g is not None:
pvals = chi2.sf(Q_arr / g, df=h)
elif null_approx == "liu":
from quadsv.statistics import _liu_apply, _liu_prepare, _liu_prepare_from_cumulants
# Dirichlet(1/2) variance correction: pass ``n`` so ``sigma_Q``
# uses ``2·(m·c_2 − c_1²)/(m+2)`` rather than the large-n limit
# ``2·c_2``. Matters on broad-spectrum PSD kernels where
# ``c_1² ≈ m·c_2`` (e.g. CAR on a dense regular grid).
n_kernel = int(kernel.n)
coef = None if null_params is None else null_params.get("liu_coef")
if coef is None and null_params is not None and "cumulants" in null_params:
coef = _liu_prepare_from_cumulants(null_params["cumulants"], n=n_kernel)
if coef is None:
# No cached coef and no user-supplied cumulants — auto-build
# from the kernel's own full spectrum (cheap for FFT: O(n)).
if null_params is not None and null_params.keys() - {"method"}:
raise ValueError(
"null_params with method='liu' must contain either "
"'liu_coef' (preferred) or 'cumulants'. Build via "
"compute_null_params(kernel, method='liu')."
)
evals = kernel.eigenvalues(return_full_layout=True)
if evals.min() < -0.1:
raise ValueError(
"Kernel has significant negative eigenvalues; Liu's method may be invalid."
)
sig_evals = evals[evals > 1e-9]
coef = _liu_prepare(sig_evals, n=n_kernel)
pvals = np.atleast_1d(_liu_apply(Q_arr, coef))
else:
raise ValueError(f"Unknown null approximation method: {null_approx!r}")
# Unwrap to scalar if the caller passed a 2D grid for a single feature.
if np.ndim(Q) == 0:
return Q, float(pvals[0])
return Q, pvals
def _standardize_grid(X: np.ndarray) -> np.ndarray:
"""Z-score standardize a grid tensor along spatial dims (0, 1)."""
m = np.mean(X, axis=(0, 1), keepdims=True)
s = np.std(X, axis=(0, 1), keepdims=True, ddof=1)
Z = np.zeros_like(X)
np.divide(X - m, s, out=Z, where=(s > 1e-12))
return Z
def _spectral_cross_product(
Zx: np.ndarray, Zy: np.ndarray, kernel: FFTKernel, ny: int, nx: int
) -> np.ndarray:
"""Compute sum of conj(Zx_hat) * lambda * Zy_hat in frequency domain."""
if kernel.fft_solver == "fft2":
Zx_hat = scipy.fft.fft2(Zx, axes=(0, 1), workers=kernel.workers)
Zy_hat = scipy.fft.fft2(Zy, axes=(0, 1), workers=kernel.workers)
lam = kernel.eigenvalues().reshape(ny, nx, 1)
spectral_prod = np.real(np.conj(Zx_hat) * lam * Zy_hat)
return np.sum(spectral_prod, axis=(0, 1))
# rfft2 case with symmetry correction
Zx_hat = scipy.fft.rfft2(Zx, axes=(0, 1), workers=kernel.workers)
Zy_hat = scipy.fft.rfft2(Zy, axes=(0, 1), workers=kernel.workers)
lam = kernel.eigenvalues().reshape(ny, nx // 2 + 1, 1)
spectral_prod = np.real(np.conj(Zx_hat) * lam * Zy_hat)
R_sum = 2.0 * np.sum(spectral_prod, axis=(0, 1))
R_sum -= np.sum(spectral_prod[:, 0, :], axis=0)
if nx % 2 == 0:
R_sum -= np.sum(spectral_prod[:, -1, :], axis=0)
return R_sum
def _r_test_fft(
Xn: np.ndarray,
Yn: np.ndarray,
kernel: FFTKernel,
null_params: dict | None = None,
return_pval: bool = True,
is_standardized: bool = False,
) -> float | np.ndarray | tuple[float, float] | tuple[np.ndarray, np.ndarray]:
"""
FFT-accelerated spatial R-test (bivariate) for grid data.
Tests for spatial co-variation between two variables using the specified kernel.
Computes the cross-variance statistic ``R = xᵀ K y`` via FFT-based spectral methods.
Parameters
----------
Xn : np.ndarray
First input tensor. Shape (ny, nx) for single feature or (ny, nx, M) for M features.
Will be automatically reshaped to 3D if 2D.
Yn : np.ndarray
Second input tensor. Must have the same shape as Xn.
kernel : FFTKernel
Pre-constructed FFT kernel object for grid data.
null_params : dict, optional
Pre-computed null parameters from
:func:`quadsv.statistics.compute_null_params`. Only the
``var_R = trace(K²)`` entry is consumed here; when None, it is
computed on the fly from ``kernel.square_trace()``.
return_pval : bool, default True
If True, returns (R, pval) tuple; if False, returns R only.
is_standardized : bool, default False
If True, skips standardization. Otherwise standardizes each variable
independently (mean 0, std 1) across spatial dimensions.
Returns
-------
R : float or np.ndarray
Test statistic (cross-variance). Scalar if input was 2D; array of shape (M,) if 3D.
pval : float or np.ndarray, optional
Two-tailed p-value under null hypothesis (no spatial co-variation).
Based on Normal approximation: :math:`z = R / \\sqrt{\\text{Trace}(K^2)}`.
Only returned if return_pval=True.
Raises
------
ValueError
If ``Xn`` and ``Yn`` shapes don't match, or spatial dimensions don't match kernel.
Notes
-----
Under H₀: x and y are spatially independent.
Under H₁: spatial co-clustering or co-dispersion present.
Computationally: R = z_x^T K z_y where z_x, z_y are standardized data.
Uses FFT via Parseval's theorem: :math:`R = \\frac{1}{N} \\sum_{i,j} \\overline{Z_{x}}_{i,j} \\lambda_{i,j} Z_{y_{i,j}}`
P-value calculation assumes asymptotic Normality with variance estimated from
kernel trace: :math:`\\text{Var}(R) \\approx \\text{Trace}(K^2) / N^2`.
Returns two-tailed probability: :math:`p = 2 P(|Z| > |\\text{z-score}|)`.
Examples
--------
>>> ny, nx = 32, 32
>>> kernel = FFTKernel((ny, nx), method='gaussian', bandwidth=1.0)
>>> x_data = np.random.randn(ny, nx)
>>> y_data = np.random.randn(ny, nx)
>>> R, pval = spatial_r_test(x_data, y_data, kernel)
"""
Xn = np.asarray(Xn).astype(float)
Yn = np.asarray(Yn).astype(float)
if Xn.ndim == 2:
Xn = Xn[..., np.newaxis]
if Yn.ndim == 2:
Yn = Yn[..., np.newaxis]
ny, nx, M = Xn.shape
if Xn.shape != Yn.shape:
raise ValueError(f"Xn and Yn shapes must match, got {Xn.shape} and {Yn.shape}")
if ny != kernel.ny or nx != kernel.nx:
raise ValueError(
f"Data shape ({ny}, {nx}) does not match kernel ({kernel.ny}, {kernel.nx})"
)
# 1. Standardization
if is_standardized:
Zx, Zy = Xn, Yn
else:
Zx = _standardize_grid(Xn)
Zy = _standardize_grid(Yn)
# 2. Compute R = Zx^T K Zy via FFT (Parseval's theorem)
R_sum = _spectral_cross_product(Zx, Zy, kernel, ny, nx)
# Apply Parseval's 1/n normalization
n_pixels = ny * nx
R = R_sum / n_pixels
# Unwrap if M=1
if M == 1 and R.size == 1:
R = R.item()
if not return_pval:
return R
# 3. P-values (Normal Approximation). R is Normal under H₀ with
# variance ``trace((HKH)²)`` — NOT ``trace(K²)``, since both X and Y
# are z-scored before R = Zₓᵀ K Zᵧ is formed. Honor a precomputed
# ``var_R`` if the caller supplied one via ``compute_null_params``.
if null_params is not None and "var_R" in null_params:
var_R = float(null_params["var_R"])
else:
# kernel.square_trace() returns trace((HKH)²) by default (centering=True).
var_R = float(kernel.square_trace())
sigma = np.sqrt(var_R)
if sigma > 1e-12:
z_scores = R / sigma
pval = 2 * norm.sf(np.abs(z_scores))
else:
pval = np.ones_like(R) if isinstance(R, np.ndarray) else 1.0
return R, pval