# Source code for PNMF.custom_modules

"""
Utility classes for constrained parameters.

Adapted from GPzoo: https://github.com/luisdiaz1997/GPzoo
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from abc import abstractmethod
from typing import Union, Tuple, Optional


class NaturalToMuS(torch.autograd.Function):
    """
    Custom autograd function for converting natural parameters to (μ, s).

    Natural parameterization for Gaussian variational distribution:
        θ₁ = μ/s²  (natural parameter for mean)
        θ₂ = -1/(2s²)  (natural parameter for precision)

    Forward pass: converts (θ₁, θ₂) → (μ, s)
    Backward pass: returns gradients w.r.t. expectation parameters (η₁, η₂)
                   where η₁ = μ and η₂ = s² + μ²

    This enables natural gradient descent for variational inference.

    Example:
        >>> theta1 = torch.randn(10, 50, requires_grad=True)
        >>> theta2 = -0.5 * torch.ones(10, 50, requires_grad=True)
        >>> mu, s = NaturalToMuS.apply(theta1, theta2)
    """

    @staticmethod
    def forward(ctx, theta1, theta2, jitter=1e-6):
        """
        Convert natural parameters to (μ, s).

        Natural parameters:
            θ₁ = μ/s²
            θ₂ = -1/(2s²)

        The variance is floored at ``jitter`` once, and that same floored
        value is used for both outputs so that μ and s always describe the
        same Gaussian (μ == s² · θ₁), even when θ₂ is not safely negative.

        Args:
            ctx: Context object for backward pass
            theta1: Natural parameter θ₁ of shape (L, N)
            theta2: Natural parameter θ₂ of shape (L, N)
            jitter: Small value for numerical stability; added to the
                denominator and used as the lower bound on s²

        Returns:
            mu: Mean parameter of shape (L, N)
            s: Standard deviation of shape (L, N)
        """
        # s² = -1/(2θ₂), floored at jitter so it stays strictly positive.
        s_squared = (-1.0 / (2.0 * theta2 + jitter)).clamp(min=jitter)
        s = torch.sqrt(s_squared)

        # μ = s² * θ₁ — uses the floored variance, consistent with s above.
        mu = s_squared * theta1

        ctx.save_for_backward(mu, s)
        return mu, s

    @staticmethod
    def backward(ctx, dout_dmu, dout_ds):
        """
        Compute gradients w.r.t. expectation parameters (η₁, η₂).

        The natural gradients are computed as:
            ∂L/∂η₁ = ∂L/∂μ - 2μ·∂L/∂(s²)
            ∂L/∂η₂ = ∂L/∂(s²)

        where ∂L/∂(s²) = (∂L/∂s) * (∂s/∂s²) = (∂L/∂s) / (2s)

        Args:
            ctx: Context object with saved tensors
            dout_dmu: Gradient of loss w.r.t. μ
            dout_ds: Gradient of loss w.r.t. s

        Returns:
            dout_deta1: Gradient w.r.t. η₁
            dout_deta2: Gradient w.r.t. η₂
            None: For jitter argument
        """
        mu, s = ctx.saved_tensors

        # Convert gradient w.r.t. s to gradient w.r.t. s²
        # ∂s/∂s² = 1/(2s); the 1e-8 keeps the division finite.
        dout_ds_squared = dout_ds / (2.0 * s + 1e-8)

        # Natural gradients (gradients w.r.t. expectation params)
        # η₁ = μ
        # η₂ = s² + μ²
        dout_deta1 = dout_dmu - 2.0 * mu * dout_ds_squared
        dout_deta2 = dout_ds_squared

        return dout_deta1, dout_deta2, None


class ConstrainedParameter(nn.Module):
    """
    Base class for parameters with constraints.

    Subclasses implement _to_constrained() and _to_unconstrained() to define
    the bijective mapping between raw (unconstrained) and constrained spaces.

    The .data property returns the constrained value, and the underlying
    unconstrained parameter is stored in ._raw (an nn.Parameter).

    Args:
        shape: Shape of the constrained parameter.
        mode: Constraint mode (subclass-specific).
    """

    def __init__(self, shape: Union[int, Tuple[int, ...]], mode: str):
        super().__init__()
        self.mode = mode
        # Normalise an int shape to a 1-tuple so _shape is always a tuple.
        self._shape = (shape,) if isinstance(shape, int) else tuple(shape)
        # Subclass should set self._raw in its __init__ after calling super().__init__

    @property
    def shape(self) -> Tuple[int, ...]:
        """Shape of the constrained parameter."""
        return self._shape

    @abstractmethod
    def _to_constrained(self, raw: torch.Tensor) -> torch.Tensor:
        """Convert raw parameter to constrained space."""
        pass

    @abstractmethod
    def _to_unconstrained(self, constrained: torch.Tensor) -> torch.Tensor:
        """Convert constrained parameter to raw space."""
        pass

    @abstractmethod
    def _init_raw(self) -> torch.Tensor:
        """Initialize raw parameter."""
        pass

    # ==================== Common interface ====================

    @property
    def data(self) -> torch.Tensor:
        """Returns the constrained value."""
        return self._to_constrained(self._raw)

    @data.setter
    def data(self, target: torch.Tensor):
        """Sets parameter from a target constrained tensor."""
        self._raw.data = self._to_unconstrained(target)

    @property
    def requires_grad(self) -> bool:
        """Returns requires_grad status of the underlying parameter."""
        return self._raw.requires_grad

    @requires_grad.setter
    def requires_grad(self, value: bool):
        """Sets requires_grad on the underlying parameter."""
        self._raw.requires_grad = value

    @property
    def device(self) -> torch.device:
        """Returns the device of the underlying parameter."""
        return self._raw.device

    @property
    def dtype(self) -> torch.dtype:
        """Returns the dtype of the underlying parameter."""
        return self._raw.dtype

    @property
    def raw(self) -> nn.Parameter:
        """Direct access to the underlying unconstrained parameter."""
        return self._raw

    def freeze(self):
        """Freeze the parameter (disable gradient computation)."""
        self._raw.requires_grad = False

    def unfreeze(self):
        """Unfreeze the parameter (enable gradient computation)."""
        self._raw.requires_grad = True

    def project(self):
        """Project parameters to satisfy constraints. Override for projected modes."""
        pass

    # ==================== Tensor-like interface ====================

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """
        Let torch functions accept ConstrainedParameter objects directly.

        Every ConstrainedParameter found in ``args``/``kwargs`` (including
        inside plain lists and tuples, e.g. ``torch.cat([p, t])``) is replaced
        by its constrained value (``.data``) before ``func`` is called.

        Declared as a classmethod, which is the form the torch override
        protocol expects; calling it on an instance still works.
        """
        if kwargs is None:
            kwargs = {}

        def unwrap(x):
            if isinstance(x, ConstrainedParameter):
                return x.data
            # Exact-type check: only rebuild plain containers, so tuple
            # subclasses (torch.Size, namedtuples) pass through untouched.
            # Without this, a wrapped parameter inside a list would be handed
            # back to func unchanged and dispatch here again, recursing forever.
            if type(x) is list:
                return [unwrap(item) for item in x]
            if type(x) is tuple:
                return tuple(unwrap(item) for item in x)
            return x

        new_args = [unwrap(arg) for arg in args]
        new_kwargs = {k: unwrap(v) for k, v in kwargs.items()}
        return func(*new_args, **new_kwargs)

    def __getitem__(self, idx):
        """Allow subscripting to index into the constrained value."""
        return self.data[idx]

    def __len__(self):
        """Size of the first dimension, or 1 when the shape is empty."""
        return self._shape[0] if self._shape else 1

    def dim(self):
        """Number of dimensions of the constrained parameter."""
        return len(self._shape)

    def size(self, dim=None):
        """Return the full shape as torch.Size, or the extent of one dimension."""
        if dim is None:
            return torch.Size(self._shape)
        return self._shape[dim]

    def numel(self):
        """Number of elements in the underlying parameter."""
        return self._raw.numel()

    def detach(self):
        """Returns a detached copy of the constrained value."""
        return self.data.detach()

    def clone(self):
        """Returns a cloned copy of the constrained value."""
        return self.data.clone()

    def cpu(self):
        """Move to CPU."""
        self._raw.data = self._raw.data.cpu()
        return self

    def cuda(self, device=None):
        """Move to CUDA."""
        self._raw.data = self._raw.data.cuda(device)
        return self

    def to(self, *args, **kwargs):
        """Move to device/dtype."""
        self._raw.data = self._raw.data.to(*args, **kwargs)
        return self

    def float(self):
        """Convert to float32."""
        self._raw.data = self._raw.data.float()
        return self

    def double(self):
        """Convert to float64."""
        self._raw.data = self._raw.data.double()
        return self

    def half(self):
        """Convert to float16."""
        self._raw.data = self._raw.data.half()
        return self

    def contiguous(self):
        """Returns a contiguous copy of the constrained value."""
        return self.data.contiguous()

    def numpy(self):
        """Returns numpy array of the constrained value."""
        return self.data.detach().cpu().numpy()

    @property
    def grad(self):
        """Returns the gradient of the underlying parameter."""
        return self._raw.grad

    @property
    def is_cuda(self):
        """Check if on CUDA."""
        return self._raw.is_cuda

    @property
    def is_leaf(self):
        """Check if leaf tensor."""
        return self._raw.is_leaf


class PositiveParameter(ConstrainedParameter):
    """
    Parameter constrained to be positive.

    Supports any shape - batching works automatically.

    Args:
        shape: int or tuple for the parameter shape.
        mode: 'softplus', 'exp', 'projected', or 'multiplicative' for ensuring positivity.
        elbo_mode: ELBO computation mode for multiplicative updates.
            Only used when mode='multiplicative'.
            One of 'lower-bound', 'expanded', or 'simple'.

    The mode determines how positivity is enforced:
        - 'softplus': Uses softplus(x) = log(1 + exp(x)) transformation
        - 'exp': Uses exp(x) transformation
        - 'projected': Uses projected gradient descent (clamps to >= 0 after each step)
        - 'multiplicative': Uses multiplicative updates (no gradient-based optimization).
          Requires calling multiplicative_update() manually after computing update terms.

    For multiplicative mode, the update rule is:
        W = W * numerator / denominator

    The numerator and denominator depend on the elbo_mode:
        - 'lower-bound': Fully analytic using Jensen's bound
        - 'expanded': Hybrid MC + analytic expectation
        - 'simple': Full Monte Carlo
    """

    def __init__(
        self,
        shape: Union[int, Tuple[int, ...]],
        mode: str = 'softplus',
        elbo_mode: str = 'expanded'
    ):
        if mode not in ['softplus', 'exp', 'projected', 'multiplicative']:
            raise ValueError(
                f"Unknown mode: {mode}. Choose 'softplus', 'exp', 'projected', or 'multiplicative'"
            )
        if mode == 'multiplicative' and elbo_mode not in ['lower-bound', 'expanded', 'simple']:
            raise ValueError(
                f"Unknown elbo_mode: {elbo_mode}. Choose 'lower-bound', 'expanded', or 'simple'"
            )
        super().__init__(shape, mode)
        self.elbo_mode = elbo_mode

        # For multiplicative mode, we don't need gradients (updated via multiplicative_update)
        requires_grad = (mode != 'multiplicative')
        self._raw = nn.Parameter(self._init_raw(), requires_grad=requires_grad)

    def _init_raw(self) -> torch.Tensor:
        """Draw the initial raw tensor: uniform [0, 1) for modes that store the
        value directly, standard normal for the transformed modes."""
        if self.mode in ['projected', 'multiplicative']:
            return torch.rand(self._shape)
        else:
            return torch.randn(self._shape)

    def _to_constrained(self, raw: torch.Tensor) -> torch.Tensor:
        """Map the raw tensor to the positive value according to ``self.mode``."""
        if self.mode == 'softplus':
            return F.softplus(raw)
        elif self.mode == 'exp':
            return torch.exp(raw)
        else:  # projected or multiplicative
            return raw

    def _to_unconstrained(self, constrained: torch.Tensor) -> torch.Tensor:
        """Map a positive target value back to the raw tensor for ``self.mode``."""
        if self.mode == 'softplus':
            # Inverse softplus: log(exp(x) - 1), written as x + log(1 - exp(-x)).
            # The direct form overflows to inf once exp(x) exceeds the dtype
            # range and cancels to log(0) for tiny x; this form does neither.
            return constrained + torch.log(-torch.expm1(-constrained))
        elif self.mode == 'exp':
            return torch.log(constrained)
        else:  # projected or multiplicative
            return constrained.clamp(min=0.0)

    def project(self):
        """
        Project parameters to satisfy constraints.

        For 'projected' mode, this clamps values to be >= 0.
        Call this after optimizer.step() when using projected gradients.
        """
        if self.mode == 'projected':
            with torch.no_grad():
                self._raw.data.clamp_(min=0.0)

    def multiplicative_update(
        self,
        X: torch.Tensor,
        terms: dict,
        idy: Optional[torch.Tensor] = None,
        eps: float = 1e-8
    ):
        """
        Perform multiplicative update for W using precomputed terms.

        Reuses all tensors already computed by compute_log_likelihood_terms():
            - exp_mu, exp_mu_sigma (always available)
            - rate_mean (lower-bound: W @ exp(μ), already computed)
            - exp_F_samples (MC modes: collected during the loop)

        For MC modes, rate_e = W @ exp_F_e is computed per-sample in a loop
        under no_grad, avoiding the (E, D, N) batch matmul entirely.

        Update rule: W_jl <- W_jl * numerator_jl / denominator_jl

        Args:
            X: Data tensor of shape (D, N) or (D_batch, N).
            terms: dict from compute_log_likelihood_terms(return_samples=True).
                Required keys depend on elbo_mode:
                - All modes: 'exp_mu', 'exp_mu_sigma'
                - 'lower-bound': 'rate_mean'
                - 'expanded'/'simple': 'exp_F_samples'
            idy: Optional feature indices for batched updates. Shape: (D_batch,).
            eps: Small constant for numerical stability.

        Raises:
            ValueError: If the parameter was not created with mode='multiplicative'.
        """
        if self.mode != 'multiplicative':
            raise ValueError(
                f"multiplicative_update() requires mode='multiplicative', got mode='{self.mode}'"
            )

        with torch.no_grad():
            # Get W (full or batched)
            W_full = self._raw  # (D, L)
            W = W_full[idy] if idy is not None else W_full  # (D_batch, L) or (D, L)

            if self.elbo_mode == 'lower-bound':
                exp_mu = terms['exp_mu']  # (L, N)
                rate = terms['rate_mean']  # (D_batch, N)
                ratio = X / (rate + eps)  # (D_batch, N)
                numerator = torch.matmul(ratio, exp_mu.T)  # (D_batch, L)
                denominator = terms['exp_mu_sigma'].sum(dim=1)  # (L,)
            else:
                # MC modes — loop over samples to avoid (E, D_batch, N) tensors
                exp_F = terms['exp_F_samples']  # (E, L, N)
                E_samples = exp_F.shape[0]
                D_batch = W.shape[0]
                L = W.shape[1]

                numerator = torch.zeros(D_batch, L, dtype=X.dtype, device=X.device)
                # Only 'simple' estimates the denominator by Monte Carlo.
                denominator_acc = (
                    torch.zeros(L, dtype=X.dtype, device=X.device)
                    if self.elbo_mode == 'simple'
                    else None
                )

                for e in range(E_samples):
                    exp_F_e = exp_F[e]  # (L, N)
                    rate_e = torch.matmul(W, exp_F_e)  # (D_batch, N)
                    ratio_e = X / (rate_e + eps)  # (D_batch, N)
                    numerator += torch.matmul(ratio_e, exp_F_e.T)  # (D_batch, L)
                    if denominator_acc is not None:
                        denominator_acc += exp_F_e.sum(dim=1)  # (L,)

                if self.elbo_mode == 'expanded':
                    numerator = numerator / E_samples  # (D_batch, L)
                    denominator = terms['exp_mu_sigma'].sum(dim=1)  # (L,) — analytic
                else:  # simple
                    denominator = denominator_acc  # (L,) — MC

            # Apply multiplicative update: W <- W * num / denom
            updated_W = W * numerator / (denominator + eps)
            # Floor at eps so entries never reach exactly zero.
            updated_W.clamp_(min=eps)

            if idy is not None:
                self._raw.data[idy] = updated_W
            else:
                self._raw.data = updated_W

    def __repr__(self):
        grad_str = ", requires_grad=True" if self._raw.requires_grad else ""
        elbo_str = f", elbo_mode='{self.elbo_mode}'" if self.mode == 'multiplicative' else ""
        return f"PositiveParameter(shape={self._shape}, mode='{self.mode}'{elbo_str}{grad_str})"