"""
Transform and utility functions for PNMF.
This module provides:
1. Conditional inference functions (transform_W, transform_F)
2. Factor extraction functions (log_factors, factors, factor_uncertainty, factor_samples)
3. Model accessor functions (get_loadings, get_prior)
"""
import torch
import torch.nn as nn
import numpy as np
from typing import Union, Optional
from tqdm.auto import tqdm
from .elbo import compute_elbo, compute_log_likelihood_terms
from .priors import GaussianPrior
# =============================================================================
# Factor Extraction Functions
# =============================================================================
def _get_spatial_qF(model, coordinates=None, groups=None):
"""
Get the variational distribution qF from a spatial model.
For spatial models, runs the GP forward pass with the given or stored
coordinates to produce the predictive distribution qF.
Args:
model: Fitted PNMF model with spatial=True.
coordinates: Optional coordinates override. If None, uses stored training coords.
groups: Optional groups override. If None, uses stored training groups.
Returns:
qF: Normal distribution with .mean and .scale of shape (L, N).
"""
coords = coordinates if coordinates is not None else model._coordinates
grps = groups if groups is not None else model._groups
if coords is None:
raise ValueError("No coordinates available. Pass coordinates or fit with spatial=True.")
# Determine device from the prior parameters
device = next(model._prior.parameters()).device
# Convert numpy arrays to tensors if needed
if isinstance(coords, np.ndarray):
coords = torch.from_numpy(coords.astype(np.float32)).to(device)
elif coords.device != device:
coords = coords.to(device)
if grps is not None:
if isinstance(grps, np.ndarray):
grps = torch.from_numpy(grps.astype(np.int64)).to(device)
elif grps.device != device:
grps = grps.to(device)
with torch.no_grad():
# For LCGP: set KNN indices before calling forward()
if hasattr(model, 'local') and model.local:
from gpzoo.knn_utilities import calculate_knn
neighbors = getattr(model, 'neighbors', 'knn')
grps_for_knn = grps if (hasattr(model, 'multigroup') and model.multigroup) else None
groupsZ_for_knn = model._groups if grps_for_knn is not None else None
knn_idx = calculate_knn(
model._prior, coords, strategy=neighbors,
multigroup=grps_for_knn is not None,
groupsX=grps_for_knn, groupsZ=groupsZ_for_knn,
)[:, :-1]
model._prior.knn_idx = knn_idx
if grps is not None:
qF, _, _ = model._prior(X=coords, groupsX=grps)
else:
qF, _, _ = model._prior(X=coords)
return qF
def log_factors(
    model,
    coordinates=None,
    groups=None,
    return_tensor: bool = False,
) -> Union[np.ndarray, torch.Tensor]:
    """
    Return the log-space latent factors, i.e. the mean μ of q(F) = Normal(μ, σ²).

    Parameters
    ----------
    model : PNMF
        A fitted PNMF model.
    coordinates : array-like, optional
        Spatial coordinates. Only used for spatial models, which fall back
        to the stored training coordinates when this is omitted.
    groups : array-like, optional
        Group assignments. Only used for spatial models, which fall back
        to the stored training groups when this is omitted.
    return_tensor : bool, default=False
        If True, return a torch.Tensor instead of a numpy array.

    Returns
    -------
    F_log : ndarray or Tensor of shape (n_samples, n_components)
        Mean of the variational distribution, transposed to samples-first.

    Raises
    ------
    ValueError
        If ``model._prior`` is None (model not fitted).

    Examples
    --------
    >>> from PNMF import PNMF
    >>> from PNMF.transforms import log_factors
    >>> model = PNMF(n_components=5).fit(X)
    >>> F_log = log_factors(model)  # Shape: (n_samples, n_components)
    """
    prior = model._prior
    if prior is None:
        raise ValueError("Model has not been fitted yet.")

    if getattr(model, 'spatial', False):
        posterior_mean = _get_spatial_qF(model, coordinates, groups).mean.detach()
    elif prior.use_natural_gradients:
        # Natural parameters: θ₂ = -1/(2σ²) and θ₁ = μ/σ², so μ = θ₁·σ².
        eta1 = prior.theta1.detach()
        eta2 = prior.theta2.detach()
        variance = -1 / (2 * eta2)
        posterior_mean = eta1 * variance
    else:
        posterior_mean = prior.mean.detach()

    # Stored as (L, N); callers get sklearn-style (N, L).
    samples_first = posterior_mean.T
    return samples_first if return_tensor else samples_first.cpu().numpy()
def get_factors(
    model,
    use_mgf: bool = False,
    coordinates=None,
    groups=None,
    return_tensor: bool = False
) -> Union[np.ndarray, torch.Tensor]:
    """
    Get exp-space latent factors.

    By default (``use_mgf=False``) this returns exp(μ), the exponential of the
    variational mean. Pass ``use_mgf=True`` to get the expected value
    E[exp(F)] = exp(μ + σ²/2) from the Gaussian moment-generating function.

    Parameters
    ----------
    model : PNMF
        A fitted PNMF model.
    use_mgf : bool, default=False
        If True, compute E[exp(F)] = exp(μ + σ²/2) using the MGF.
        If False, compute exp(μ) directly.
    coordinates : array-like, optional
        Spatial coordinates (for spatial models).
    groups : array-like, optional
        Group assignments (for spatial models).
    return_tensor : bool, default=False
        If True, return a torch.Tensor instead of numpy array.

    Returns
    -------
    F_exp : ndarray or Tensor of shape (n_samples, n_components)
        Latent factors in exp-space.

    Raises
    ------
    ValueError
        If ``model._prior`` is None (model not fitted).

    Notes
    -----
    The moment-generating function (MGF) of a Gaussian gives:
        E[exp(F)] = exp(μ + σ²/2)
    This is the true expected value under the variational distribution q(F).
    Using use_mgf=False gives exp(μ), which is a biased estimate of that
    expectation. Note that the default here (False) differs from
    ``factors_from_prior``, whose ``use_mgf`` defaults to True.

    Examples
    --------
    >>> from PNMF import PNMF
    >>> from PNMF.transforms import get_factors
    >>> model = PNMF(n_components=5).fit(X)
    >>> F_exp_biased = get_factors(model)  # exp(μ), shape: (n_samples, n_components)
    >>> F_exp = get_factors(model, use_mgf=True)  # E[exp(F)]
    """
    if model._prior is None:
        raise ValueError("Model has not been fitted yet.")

    if getattr(model, 'spatial', False):
        qF = _get_spatial_qF(model, coordinates, groups)
        mu = qF.mean.detach()
        scale_sq = (qF.scale.detach()) ** 2
    elif model._prior.use_natural_gradients:
        # Natural parameters: θ₂ = -1/(2σ²) and θ₁ = μ/σ².
        theta1 = model._prior.theta1.detach()
        theta2 = model._prior.theta2.detach()
        scale_sq = -1 / (2 * theta2)
        mu = theta1 * scale_sq
    else:
        mu = model._prior.mean.detach()
        # Floor the scale at 1e-8 before squaring.
        scale = model._prior.scale.data.detach().clamp(min=1e-8)
        scale_sq = scale ** 2

    if use_mgf:
        # E[exp(F)] = exp(μ + σ²/2)
        F_exp = torch.exp(mu + scale_sq / 2)
    else:
        # exp(μ) - biased estimate
        F_exp = torch.exp(mu)

    # Transpose from (L, N) to (N, L) for sklearn-style
    F_exp = F_exp.T
    if return_tensor:
        return F_exp
    return F_exp.cpu().numpy()
def factor_uncertainty(
    model,
    return_variance: bool = False,
    coordinates=None,
    groups=None,
    return_tensor: bool = False
) -> Union[np.ndarray, torch.Tensor]:
    """
    Return the spread of q(F) = Normal(μ, σ²): σ by default, σ² on request.

    Parameters
    ----------
    model : PNMF
        A fitted PNMF model.
    return_variance : bool, default=False
        If True, return the variance (σ²); otherwise the standard deviation (σ).
    coordinates : array-like, optional
        Spatial coordinates (for spatial models).
    groups : array-like, optional
        Group assignments (for spatial models).
    return_tensor : bool, default=False
        If True, return a torch.Tensor instead of a numpy array.

    Returns
    -------
    uncertainty : ndarray or Tensor of shape (n_samples, n_components)
        Standard deviation (σ) or variance (σ²) of the variational distribution.

    Raises
    ------
    ValueError
        If ``model._prior`` is None (model not fitted).

    Examples
    --------
    >>> from PNMF import PNMF
    >>> from PNMF.transforms import factor_uncertainty
    >>> model = PNMF(n_components=5).fit(X)
    >>> F_std = factor_uncertainty(model)  # σ, shape: (n_samples, n_components)
    >>> F_var = factor_uncertainty(model, return_variance=True)  # σ²
    """
    prior = model._prior
    if prior is None:
        raise ValueError("Model has not been fitted yet.")

    if getattr(model, 'spatial', False):
        sigma = _get_spatial_qF(model, coordinates, groups).scale.detach()
        spread = sigma ** 2 if return_variance else sigma
    elif prior.use_natural_gradients:
        # θ₂ = -1/(2σ²), so the variance is recovered directly from θ₂.
        variance = -1 / (2 * prior.theta2.detach())
        spread = variance if return_variance else torch.sqrt(variance)
    else:
        # Floor the stored scale at 1e-8.
        sigma = prior.scale.data.detach().clamp(min=1e-8)
        spread = sigma ** 2 if return_variance else sigma

    # Stored as (L, N); callers get sklearn-style (N, L).
    spread = spread.T
    return spread if return_tensor else spread.cpu().numpy()
def factor_samples(
    model,
    n_samples: int = 100,
    return_exp: bool = False,
    coordinates=None,
    groups=None,
    return_tensor: bool = False
) -> Union[np.ndarray, torch.Tensor]:
    """
    Draw samples of the latent factors from the variational posterior q(F).

    Parameters
    ----------
    model : PNMF
        A fitted PNMF model.
    n_samples : int, default=100
        Number of samples to draw.
    return_exp : bool, default=False
        If True, exponentiate the draws (exp(F)); otherwise return F itself.
    coordinates : array-like, optional
        Spatial coordinates (for spatial models).
    groups : array-like, optional
        Group assignments (for spatial models).
    return_tensor : bool, default=False
        If True, return a torch.Tensor instead of a numpy array.

    Returns
    -------
    samples : ndarray or Tensor of shape (n_samples, n_data_samples, n_components)
        Detached samples from the variational posterior.

    Raises
    ------
    ValueError
        If ``model._prior`` is None (model not fitted).

    Notes
    -----
    Draws are taken with ``rsample`` (reparameterization trick) and detached
    before being returned.

    Examples
    --------
    >>> from PNMF import PNMF
    >>> from PNMF.transforms import factor_samples
    >>> model = PNMF(n_components=5).fit(X)
    >>> samples = factor_samples(model, n_samples=100)  # (100, n_data, 5)
    >>> exp_samples = factor_samples(model, n_samples=50, return_exp=True)
    """
    prior = model._prior
    if prior is None:
        raise ValueError("Model has not been fitted yet.")

    # Obtain q(F): spatial models need a forward pass on coordinates,
    # non-spatial priors are simply called with no arguments.
    if getattr(model, 'spatial', False):
        posterior = _get_spatial_qF(model, coordinates, groups)
    else:
        posterior, _ = prior()

    # rsample yields (n_samples, L, N).
    draws = posterior.rsample((n_samples,))
    if return_exp:
        draws = draws.exp()

    # Swap the last two axes -> (n_samples, N, L), sklearn-style.
    draws = draws.transpose(1, 2).detach()
    return draws if return_tensor else draws.cpu().numpy()
# =============================================================================
# Model Accessor Functions
# =============================================================================
def get_loadings(model, return_tensor: bool = False) -> Union[np.ndarray, torch.Tensor]:
    """
    Return the loadings matrix W held by a fitted PNMF model.

    Parameters
    ----------
    model : PNMF
        A fitted PNMF model.
    return_tensor : bool, default=False
        If True, return a torch.Tensor instead of a numpy array.

    Returns
    -------
    W : ndarray or Tensor of shape (n_features, n_components)
        The loadings matrix, read from ``model._model.W`` and detached.

    Raises
    ------
    ValueError
        If ``model._model`` is None.

    Notes
    -----
    Documented as equivalent to ``model.components_.T``; this accessor only
    requires that ``model._model`` has been created.

    Examples
    --------
    >>> from PNMF import PNMF
    >>> from PNMF.transforms import get_loadings
    >>> model = PNMF(n_components=5).fit(X)
    >>> W = get_loadings(model)  # Shape: (n_features, n_components)
    """
    inner = model._model
    if inner is None:
        raise ValueError("Model has not been fitted yet.")

    loadings = inner.W.data.detach()
    return loadings if return_tensor else loadings.cpu().numpy()
def get_prior(model) -> GaussianPrior:
    """
    Return the prior object stored on a fitted PNMF model.

    Parameters
    ----------
    model : PNMF
        A fitted PNMF model.

    Returns
    -------
    prior : GaussianPrior
        The object held in ``model._prior`` (returned as-is, not copied).

    Raises
    ------
    ValueError
        If ``model._prior`` is None (model not fitted).

    Notes
    -----
    Intended for advanced users who want to inspect or manipulate the
    variational distribution directly.

    Examples
    --------
    >>> from PNMF import PNMF
    >>> from PNMF.transforms import get_prior
    >>> model = PNMF(n_components=5).fit(X)
    >>> prior = get_prior(model)
    >>> qF, pF = prior()  # Get distributions directly
    """
    prior = model._prior
    if prior is None:
        raise ValueError("Model has not been fitted yet.")
    return prior
# =============================================================================
# Conditional Inference Functions
# =============================================================================
# =============================================================================
# Utility Functions for Working with Priors
# =============================================================================
def log_factors_from_prior(
    prior: GaussianPrior,
    return_tensor: bool = False
) -> Union[np.ndarray, torch.Tensor]:
    """
    Return the log-space latent factors (variational mean μ) of a prior object.

    Parameters
    ----------
    prior : GaussianPrior
        A GaussianPrior object.
    return_tensor : bool, default=False
        If True, return a torch.Tensor instead of a numpy array.

    Returns
    -------
    F_log : ndarray or Tensor of shape (n_samples, n_components)
        Latent factors in log-space (the stored mean, transposed).

    Examples
    --------
    >>> from PNMF.transforms import transform_F, log_factors_from_prior
    >>> prior = transform_F(X_new, W, return_prior=True)
    >>> F_log = log_factors_from_prior(prior)
    """
    if not prior.use_natural_gradients:
        posterior_mean = prior.mean.detach()
    else:
        # Natural parameters: θ₂ = -1/(2σ²) and θ₁ = μ/σ², so μ = θ₁·σ².
        eta1 = prior.theta1.detach()
        eta2 = prior.theta2.detach()
        variance = -1 / (2 * eta2)
        posterior_mean = eta1 * variance

    samples_first = posterior_mean.T
    return samples_first if return_tensor else samples_first.cpu().numpy()
def factors_from_prior(
    prior: GaussianPrior,
    use_mgf: bool = True,
    return_tensor: bool = False
) -> Union[np.ndarray, torch.Tensor]:
    """
    Return the exp-space latent factors of a prior object.

    Parameters
    ----------
    prior : GaussianPrior
        A GaussianPrior object.
    use_mgf : bool, default=True
        If True, compute E[exp(F)] = exp(μ + σ²/2).
        If False, compute exp(μ) directly.
    return_tensor : bool, default=False
        If True, return a torch.Tensor instead of a numpy array.

    Returns
    -------
    F_exp : ndarray or Tensor of shape (n_samples, n_components)
        Latent factors in exp-space.

    Examples
    --------
    >>> from PNMF.transforms import transform_F, factors_from_prior
    >>> prior = transform_F(X_new, W, return_prior=True)
    >>> F_exp = factors_from_prior(prior)
    """
    if prior.use_natural_gradients:
        # Natural parameters: θ₂ = -1/(2σ²) and θ₁ = μ/σ².
        eta1 = prior.theta1.detach()
        eta2 = prior.theta2.detach()
        variance = -1 / (2 * eta2)
        posterior_mean = eta1 * variance
    else:
        posterior_mean = prior.mean.detach()
        # Floor the stored scale at 1e-8 before squaring.
        sigma = prior.scale.data.detach().clamp(min=1e-8)
        variance = sigma ** 2

    log_of_result = posterior_mean + variance / 2 if use_mgf else posterior_mean
    exp_factors = torch.exp(log_of_result).T
    return exp_factors if return_tensor else exp_factors.cpu().numpy()
def uncertainty_from_prior(
    prior: GaussianPrior,
    return_variance: bool = False,
    return_tensor: bool = False
) -> Union[np.ndarray, torch.Tensor]:
    """
    Return the spread (σ or σ²) of the variational distribution of a prior object.

    Parameters
    ----------
    prior : GaussianPrior
        A GaussianPrior object.
    return_variance : bool, default=False
        If True, return the variance (σ²); otherwise the standard deviation (σ).
    return_tensor : bool, default=False
        If True, return a torch.Tensor instead of a numpy array.

    Returns
    -------
    uncertainty : ndarray or Tensor of shape (n_samples, n_components)
        Standard deviation or variance.

    Examples
    --------
    >>> from PNMF.transforms import transform_F, uncertainty_from_prior
    >>> prior = transform_F(X_new, W, return_prior=True)
    >>> F_std = uncertainty_from_prior(prior)
    """
    if prior.use_natural_gradients:
        # θ₂ = -1/(2σ²), so the variance is recovered directly from θ₂.
        variance = -1 / (2 * prior.theta2.detach())
        spread = variance if return_variance else torch.sqrt(variance)
    else:
        # Floor the stored scale at 1e-8.
        sigma = prior.scale.data.detach().clamp(min=1e-8)
        spread = sigma ** 2 if return_variance else sigma

    spread = spread.T
    return spread if return_tensor else spread.cpu().numpy()