Source code for mmer.core.mixed_effect

import warnings
import numpy as np
from sklearn.base import RegressorMixin
from sklearn.model_selection import GroupShuffleSplit
from tqdm import tqdm
from .corrections import VarianceCorrection
from .solver import build_solver
from .convergence import ConvergenceMonitor
from .inference import aggregate_random_effects
from .terms import RandomEffectTerm, ResidualTerm, RealizedRandomEffect, RealizedResidual


class MixedEffectRegressor:
    """
    Multivariate Mixed Effects Regression (MMER) using Expectation-Maximization.

    Fits mixed model with multiple responses, supporting arbitrary grouping
    factors and linear random slopes. Solves for random effects and residual
    covariances using EM algorithm with stochastic log-determinant estimation.

    Parameters
    ----------
    fixed_effects_model : RegressorMixin
        Base regressor for fixed effects (must support multi-output).
    max_iter : int, default=30
        Maximum number of EM iterations.
    tol : float, default=1e-6
        Convergence tolerance on log-likelihood relative change.
    patience : int, default=3
        Number of iterations to wait for likelihood improvement before early
        stopping. Values below 1 are clamped to 1. Setting to a high value
        effectively disables early stopping and relies solely on `tol`.
    correction_method : str, default='bste'
        Method for variance correction in M-step:

        - 'bste': block stochastic trace estimation
        - 'de': deterministic estimation
    slq_steps : int, default=30
        Number of Lanczos steps for Stochastic Lanczos Quadrature (log-det
        estimation). A range of 30-50 is typically sufficient. Higher values
        yield slightly more accurate estimates but increase computation time
        and risk numerical instability.
    n_probes : int, default=60
        Number of random probes used for both SLQ log-determinant estimation
        and stochastic variance correction. A fixed target around 50-60 is
        usually optimal independent of matrix dimension for O(1/sqrt(p))
        error convergence.
    preconditioner : bool, default=True
        Whether to use residual-based preconditioner for CG solver.
    cg_maxiter : int, default=1000
        Iteration cap forwarded to the solver factory and to the variance
        corrector (`cg_maxiter` argument of both).
    n_jobs : int, default=-1
        Number of parallel jobs for SLQ and trace estimation (-1 uses all
        cores). Setting to number of outputs (`m`) is recommended for optimal
        performance.
    backend : str, default='threading'
        Joblib parallel backend ('threading', 'loky'). Setting to 'threading'
        is highly recommended as almost always the solver is Woodbury-based.
        Setting to 'loky' can be used rarely for cg-based solvers under very
        specific conditions of large `m` and small `n` where woodbury is not
        beneficial.

    Attributes
    ----------
    fe_model : RegressorMixin
        Fitted fixed effects model.
    random_effect_terms : list of RandomEffectTerm
        Fitted random effect terms containing covariance matrices.
    residual_term : ResidualTerm
        Fitted residual term containing residual covariance matrix.
    log_likelihood : list of float
        Log-likelihood values across EM iterations (tracked during fitting on
        ``convergence_monitor.log_likelihood``).
    n : int
        Number of observations.
    m : int
        Number of output dimensions.
    k : int
        Number of grouping factors.

    Examples
    --------
    >>> from sklearn.linear_model import Ridge
    >>> model = MixedEffectRegressor(fixed_effects_model=Ridge())
    >>> results = model.fit(X, y, groups, random_slopes=([0, 1], None))
    >>> predictions = model.predict(X_new)
    """

    def __init__(
        self,
        fixed_effects_model: RegressorMixin,
        max_iter: int = 30,
        tol: float = 1e-6,
        patience: int = 3,
        correction_method: str = "bste",
        slq_steps: int = 30,
        n_probes: int = 60,
        preconditioner: bool = True,
        cg_maxiter: int = 1000,
        n_jobs: int = -1,
        backend: str = "threading",
    ):
        self.fe_model = fixed_effects_model
        self.max_iter = max_iter
        self.tol = tol
        self.patience = max(1, patience)
        self.correction_method = correction_method
        self.slq_steps = slq_steps
        self.n_probes = n_probes
        self.preconditioner = preconditioner
        self.cg_maxiter = cg_maxiter
        self.n_jobs = n_jobs
        self.backend = backend

        # Hand the monitor the clamped patience so that the value stored on
        # the estimator and the value actually enforced cannot diverge.
        self.convergence_monitor = ConvergenceMonitor(
            tol=tol, patience=self.patience
        )
        self.variance_corrector = VarianceCorrection(
            method=correction_method,
            cg_maxiter=cg_maxiter,
            n_jobs=n_jobs,
            backend=backend,
        )

        # State: Terms
        self.random_effect_terms: list[RandomEffectTerm] | None = None
        self.residual_term: ResidualTerm | None = None
        self.has_validation = False
        self.train_idx: np.ndarray | None = None
        self.val_idx: np.ndarray | None = None
        self.force_iterative = False

    def _prepare_terms(
        self,
        y: np.ndarray,
        groups: np.ndarray,
        random_slopes: tuple[list[int] | None, ...] | None,
    ):
        """
        Initialize state RandomEffect and Residual Terms if not present.

        Records the problem dimensions (`n`, `m`, `k`) and creates one
        `RandomEffectTerm` per column of `groups` plus a single `ResidualTerm`.

        Parameters
        ----------
        y : np.ndarray
            Multi-output targets, shape (n, m).
        groups : np.ndarray
            Grouping factors, shape (n, k).
        random_slopes : tuple of (list of int or None), optional
            Per-group slope column indices; None means no slopes for any group.

        Raises
        ------
        ValueError
            If `random_slopes` is given and its length differs from `k`.
        """
        self.n, self.m = y.shape  # number of samples and outputs
        self.k = groups.shape[1]  # number of grouping factors

        # 1. Initialize Random Structure Config
        if random_slopes is None:
            config_random_slopes = (None,) * self.k
        elif len(random_slopes) != self.k:
            raise ValueError(
                f"Length of random_slopes ({len(random_slopes)}) must match number of groups ({self.k})."
            )
        else:
            config_random_slopes = random_slopes

        # 2. Create Terms
        self.random_effect_terms = [
            RandomEffectTerm(group_id=i, covariates_id=slope_cols, m=self.m)
            for i, slope_cols in enumerate(config_random_slopes)
        ]
        self.residual_term = ResidualTerm(m=self.m)
        self.random_slopes = config_random_slopes

    def _realize_objects(self, X: np.ndarray, groups: np.ndarray) -> tuple:
        """
        Factory method to create realized random effects and residual term.

        Parameters
        ----------
        X : np.ndarray
            Covariates, shape (n, p).
        groups : np.ndarray
            Grouping factors, shape (n, k).

        Returns
        -------
        realized_effects : tuple of RealizedRandomEffect
            Realized random effects.
        realized_residual : RealizedResidual
            Realized residual term.
        """
        n_rows = X.shape[0]
        realized_effects = tuple(
            RealizedRandomEffect(term, X, groups)
            for term in self.random_effect_terms
        )
        realized_residual = RealizedResidual(self.residual_term, n_rows)
        return realized_effects, realized_residual

    def prepare_data(
        self,
        X: np.ndarray,
        y: np.ndarray,
        groups: np.ndarray,
        validation_split: float = 0.0,
        validation_group: int = 0,
    ):
        """
        Prepare data for EM algorithm by creating realized objects.

        Generates transient realized random effects and residual for the
        current dataset. Optionally splits data into training and validation
        sets based on group membership.

        Parameters
        ----------
        X : np.ndarray
            Covariates, shape (n, p).
        y : np.ndarray
            Multi-output targets, shape (n, m).
        groups : np.ndarray
            Grouping factors, shape (n, k).
        validation_split : float, default=0.0
            Fraction of groups to use for fixed-effects validation (0.0 means
            no validation). Setting to a non-zero value means fixed effects
            can accept validation data.
        validation_group : int, default=0
            Column index in `groups` to use for group-wise validation splitting.

        Returns
        -------
        marginal_residual : np.ndarray
            Initial marginal residual, raveled shape (m*n,).
        realized_effects : tuple of RealizedRandomEffect
            Realized random effect objects.
        realized_residual : RealizedResidual
            Realized residual term.
        """
        # Setup Validation Split
        if validation_split > 0:
            main_group = groups[:, validation_group]
            # Fixed seed keeps the train/validation partition reproducible.
            gss = GroupShuffleSplit(
                n_splits=1, test_size=validation_split, random_state=42
            )
            self.train_idx, self.val_idx = next(gss.split(X, y, groups=main_group))
            self.has_validation = True
        else:
            self.train_idx = None
            self.val_idx = None
            self.has_validation = False

        # Instantiate Realized Objects (Transient)
        realized_effects, realized_residual = self._realize_objects(X, groups)

        # Initial Marginal Residual (no random effect removed yet)
        marginal_residual = self._compute_marginal_residual(X, y, 0.0)
        return marginal_residual, realized_effects, realized_residual

    def fit(
        self,
        X: np.ndarray,
        y: np.ndarray,
        groups: np.ndarray,
        random_slopes: tuple[list[int] | None, ...] | None = None,
        validation_split: float = 0.0,
        validation_group: int = 0,
    ):
        """
        Fit the MMER model using the EM algorithm.

        On the first call the random-effect and residual terms are created
        from `y`, `groups` and `random_slopes`. On later calls the existing
        terms are reused (warm start): `random_slopes` is then not consulted,
        and `y`/`groups` must have the same number of outputs and grouping
        factors as the data the terms were created from.

        Parameters
        ----------
        X : np.ndarray
            Covariates, shape (n, p) where n is number of observations and p
            is number of features.
        y : np.ndarray
            Multi-output targets, shape (n, m) where m is number of outputs.
        groups : np.ndarray
            Grouping factors, shape (n, k) where k is number of grouping
            factors. Each column represents a different grouping structure.
        random_slopes : tuple of list of int, optional
            Tuple of lists specifying random slopes for each grouping factor.
            Each list contains column indices in X for random slopes
            corresponding to that group. None or empty list implies random
            intercept only for that group. If None, all groups get random
            intercepts only.
        validation_split : float, default=0.0
            Fraction of groups to use for fixed-effects validation (early
            stopping). Must be between 0.0 and 1.0. Set to 0.0 to disable
            validation. Setting to a non-zero value means fixed effects can
            accept validation data.
        validation_group : int, default=0
            Column index in `groups` to use for group-wise validation splitting.

        Returns
        -------
        MixedEffectResults
            Fitted result object containing covariance estimates and diagnostics.

        Raises
        ------
        ValueError
            If `random_slopes` does not match the number of groups, or if a
            warm-start call passes data whose number of outputs or grouping
            factors differs from the existing terms.

        Examples
        --------
        >>> # Fit model with random intercepts only
        >>> results = model.fit(X, y, groups)
        >>> # Fit with random slopes on features 0 and 1 for first group
        >>> results = model.fit(X, y, groups, random_slopes=([0, 1], None))
        >>> # Fit with validation split
        >>> results = model.fit(X, y, groups, validation_split=0.2)
        """
        # Initialize terms if new training
        if self.random_effect_terms is None:
            self._prepare_terms(y, groups, random_slopes)
        else:
            # Warm start: the terms are kept, but `n` must follow the data at
            # hand. A stale `n` would make the reshape in the EM iteration
            # fail with a ValueError that is then misreported as numerical
            # instability.
            n_obs, n_outputs = y.shape
            n_groups = groups.shape[1]
            if n_outputs != self.m or n_groups != self.k:
                raise ValueError(
                    f"Data with {n_outputs} outputs and {n_groups} grouping factors "
                    f"does not match the existing terms ({self.m} outputs, "
                    f"{self.k} grouping factors)."
                )
            self.n = n_obs

        # Reset convergence monitor for new fit
        monitor = self.convergence_monitor
        monitor.reset()

        marginal_residual, realized_effects, realized_residual = self.prepare_data(
            X, y, groups, validation_split, validation_group
        )

        # The context manager closes the bar even when the loop exits early.
        with tqdm(
            range(1, self.max_iter + 1),
            desc="Running MMER Framework | Fitting Model ...",
            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} {elapsed}",
        ) as pbar:
            for iteration in pbar:
                marginal_residual = self._run_em_iteration(
                    X,
                    y,
                    marginal_residual,
                    realized_effects,
                    realized_residual,
                    iteration,
                )
                if not monitor.is_converged:
                    continue

                if not monitor.is_early_stopped:
                    status = "Converged: tolerance reached!"
                elif np.isinf(monitor.log_likelihood[-1]):
                    status = "Finished: numerical limits reached!"
                else:
                    status = "Finished: no further improvement!"
                pbar.set_description(status)
                break

        # Imported here rather than at module level, as in the original code.
        from .mixed_result import MixedEffectResults

        monitor.restore_best_state(self)
        return MixedEffectResults(self)

    def _run_em_iteration(
        self,
        X,
        y,
        marginal_residual,
        realized_effects,
        realized_residual,
        iteration: int = 0,
    ):
        """
        Run one EM iteration.

        Performs the E-step, refits the fixed-effects model on the targets
        with the estimated random effect removed, then performs the M-step.

        Returns
        -------
        np.ndarray
            The marginal residual to carry into the next iteration. It is the
            input `marginal_residual` unchanged when the E-step converged or
            produced no usable estimate.
        """
        total_random_effect, mu, solver = self._e_step(
            marginal_residual, realized_effects, realized_residual
        )
        # `_e_step` returns the (None, None, None) sentinel both on
        # convergence and on a failed E-step; either way there is nothing for
        # the M-step to work with.
        if self.convergence_monitor.is_converged or total_random_effect is None:
            return marginal_residual

        try:
            # (m*n,) output-major vector -> (n, m) matrix aligned with `y`.
            random_effect_matrix = total_random_effect.reshape((self.m, self.n)).T
            marginal_residual = self._compute_marginal_residual(
                X, y, random_effect_matrix
            )
            self._m_step(
                marginal_residual,
                total_random_effect,
                mu,
                realized_effects,
                realized_residual,
                solver,
                iteration,
            )
        except (np.linalg.LinAlgError, RuntimeError, ValueError):
            warnings.warn(
                "Numerical instability encountered during M-step. Reverting to the best valid state.",
                RuntimeWarning,
                stacklevel=2,
            )
            self.convergence_monitor.update(-np.inf, self)
        return marginal_residual

    def _e_step(self, marginal_residual, realized_effects, realized_residual):
        """
        Run E-step.

        Builds the solver, solves for the precision-weighted residual,
        evaluates the log-likelihood and reports it to the convergence
        monitor.

        Returns
        -------
        tuple
            ``(total_random_effect, mu, solver)`` on success, or
            ``(None, None, None)`` when the monitor reports convergence or the
            step failed numerically (a ``-inf`` log-likelihood is recorded in
            the latter case).
        """
        solver = None
        prec_resid = None
        try:
            solver = build_solver(
                realized_effects,
                realized_residual,
                self.preconditioner,
                self.cg_maxiter,
                force_iterative=self.force_iterative,
            )
            # If the solver fell back to IterativeSolver (due to OOM or math conditions),
            # flag it to prevent wasting time on Woodbury attempts in future EM iterations.
            self.force_iterative = solver.is_iterative
            prec_resid = solver.solve(marginal_residual)
            current_log_lh = self._compute_log_likelihood(
                marginal_residual, prec_resid, solver
            )
        except (np.linalg.LinAlgError, RuntimeError, ValueError):
            warnings.warn(
                "Numerical instability encountered during E-step. Reverting to the best valid state.",
                RuntimeWarning,
                stacklevel=2,
            )
            current_log_lh = -np.inf
            solver = None

        # Update convergence monitor
        self.convergence_monitor.update(current_log_lh, self)

        # A failed step leaves no valid solver (and possibly no `prec_resid`),
        # so it must not reach the aggregation below even if the monitor has
        # not declared convergence yet.
        if self.convergence_monitor.is_converged or solver is None:
            return None, None, None

        total_random_effect, mu = aggregate_random_effects(prec_resid, realized_effects)
        return total_random_effect, mu, solver

    def _m_step(
        self,
        marginal_residual: np.ndarray,
        total_random_effect: np.ndarray,
        mu: tuple[np.ndarray, ...],
        realized_effects: tuple[RealizedRandomEffect, ...],
        realized_residual: RealizedResidual,
        solver,
        iteration: int = 0,
    ):
        """
        Run M-step.

        Computes a variance correction for every random effect, derives the
        next covariance of each effect and of the residual, and only then
        writes them to the terms, so a failure part-way through leaves the
        stored covariances untouched.

        Returns
        -------
        MixedEffectRegressor
            `self`.
        """
        eps = marginal_residual - total_random_effect
        correction_sum = np.zeros((self.m, self.m))
        next_covs = []

        for idx, effect in enumerate(realized_effects):
            T_k, W_k = self.variance_corrector.compute_correction(
                idx, solver, n_probes=self.n_probes, iteration=iteration
            )
            correction_sum += T_k
            next_covs.append(effect._compute_next_cov(mu[idx], W_k))

        # Update Terms via Realized Effects logic
        next_resid_cov = realized_residual._compute_next_cov(eps, correction_sum)
        self.residual_term.set_cov(next_resid_cov)
        for term, cov in zip(self.random_effect_terms, next_covs):
            term.set_cov(cov)
        return self

    def _compute_marginal_residual(self, X, y, total_random_effect):
        """
        Fit FE model and compute new marginal residual.

        Parameters
        ----------
        X : np.ndarray
            Covariates passed to the fixed-effects model.
        y : np.ndarray
            Targets, shape (n, m).
        total_random_effect : np.ndarray or float
            Value subtracted from `y` before fitting (must broadcast against
            `y`; ``0.0`` is used for the initial residual).

        Returns
        -------
        np.ndarray
            ``(y - fx).T.ravel()``, i.e. the residual flattened output by
            output, where `fx` is the fixed-effects prediction on `X`.
        """
        single_output = self.m == 1

        y_adj = y - total_random_effect
        if single_output:
            # With a single output the regressor is fitted on a 1-D target.
            y_adj = y_adj.ravel()

        if self.has_validation:
            self.fe_model.fit(
                X[self.train_idx],
                y_adj[self.train_idx],
                X_val=X[self.val_idx],
                y_val=y_adj[self.val_idx],
            )
        else:
            self.fe_model.fit(X, y_adj)

        fx = self.fe_model.predict(X)
        if single_output:
            # Accept both (n,) and (n, 1) predictions. Adding an axis to an
            # (n, 1) array would give (n, 1, 1) and silently broadcast the
            # residual to n*n entries.
            fx = np.asarray(fx).reshape(-1, 1)
        return (y - fx).T.ravel()

    def _compute_log_likelihood(self, marginal_residual, prec_resid, solver):
        """
        Compute log-likelihood.

        Routes to exact Matrix Determinant Lemma (when WoodburySolver is
        active) or Stochastic Lanczos Quadrature (when IterativeSolver is
        active).

        Returns
        -------
        float
            ``-(m*n*log(2*pi) + logdet + r^T prec_resid) / 2`` where `r` is
            `marginal_residual` and `logdet` comes from ``solver.logdet``.
        """
        log_det_V = solver.logdet(
            slq_steps=self.slq_steps,
            n_probes=self.n_probes,
            n_jobs=self.n_jobs,
            backend=self.backend,
        )
        quad_form = marginal_residual.T @ prec_resid
        const_term = self.m * self.n * np.log(2 * np.pi)
        return -(const_term + log_det_V + quad_form) / 2