diff --git a/src/linpde_gp/linfuncops/_arithmetic.py b/src/linpde_gp/linfuncops/_arithmetic.py index 6c5b7941..265fb667 100644 --- a/src/linpde_gp/linfuncops/_arithmetic.py +++ b/src/linpde_gp/linfuncops/_arithmetic.py @@ -46,7 +46,7 @@ def _( def __rmul__(self, other) -> LinearFunctionOperator: if np.ndim(other) == 0: return ScaledLinearFunctionOperator( - linfuncop=self._linfuncop, + self._linfuncop, scalar=np.asarray(other) * self._scalar, ) diff --git a/src/linpde_gp/randprocs/_gaussian_process/_conditional.py b/src/linpde_gp/randprocs/_gaussian_process/_conditional.py index 211c4bd0..7249bc27 100644 --- a/src/linpde_gp/randprocs/_gaussian_process/_conditional.py +++ b/src/linpde_gp/randprocs/_gaussian_process/_conditional.py @@ -2,22 +2,25 @@ from collections.abc import Iterator, Sequence import functools -from typing import Optional import jax import jax.numpy as jnp import numpy as np from numpy.typing import ArrayLike import probnum as pn -import scipy.linalg from linpde_gp import linfunctls from linpde_gp.functions import JaxFunction from linpde_gp.linfuncops import LinearFunctionOperator from linpde_gp.linfunctls import LinearFunctional from linpde_gp.linops import BlockMatrix, BlockMatrix2x2 -from linpde_gp.randprocs.covfuncs import JaxCovarianceFunction from linpde_gp.randprocs.crosscov import ProcessVectorCrossCovariance +from linpde_gp.solvers import ( + CholeskySolver, + ConcreteGPSolver, + GPInferenceParams, + GPSolver, +) from linpde_gp.typing import RandomVariableLike @@ -31,8 +34,9 @@ def from_observations( *, L: None | LinearFunctional | LinearFunctionOperator = None, b: None | RandomVariableLike = None, + solver: GPSolver = CholeskySolver(), ): - Y, L, b, kLa, Lm, gram = cls._preprocess_observations( + Y, L, b, kLa, gram = cls._preprocess_observations( prior=prior, Y=Y, X=X, @@ -40,16 +44,16 @@ def from_observations( b=b, ) - representer_weights = gram.solve(Y - Lm) + kLas = ConditionalGaussianProcess.PriorPredictiveCrossCovariance((kLa,)) return cls( prior=prior, Ys=(Y,), Ls=(L,), bs=(b,), - kLas=ConditionalGaussianProcess.PriorPredictiveCrossCovariance((kLa,)), + kLas=kLas, gram_matrix=gram, - representer_weights=representer_weights, + solver=solver, ) def __init__( @@ -61,7 +65,8 @@ def __init__( bs: Sequence[pn.randvars.Normal | pn.randvars.Constant | None], kLas: ConditionalGaussianProcess.PriorPredictiveCrossCovariance, gram_matrix: pn.linops.LinearOperator, - representer_weights: np.ndarray | None = None, + solver: GPSolver, + full_representer_weights: np.ndarray | None = None, ): self._prior = prior @@ -73,19 +78,19 @@ def __init__( self._gram_matrix = gram_matrix - self._representer_weights = representer_weights + inference_params = GPInferenceParams( + prior, gram_matrix, Ys, Ls, bs, kLas, None, full_representer_weights + ) + self._solver = solver.get_concrete_solver(inference_params) + self._abstract_solver = solver super().__init__( mean=ConditionalGaussianProcess.Mean( prior_mean=self._prior.mean, kLas=self._kLas, - representer_weights=self.representer_weights, - ), - cov=ConditionalGaussianProcess.CovarianceFunction( - prior_covfunc=self._prior.cov, - kLas=self._kLas, - gram_matrix=self.gram, + solver=self._solver, ), + cov=self._solver.posterior_cov, ) @functools.cached_property @@ -93,20 +98,16 @@ def gram(self) -> pn.linops.LinearOperator: return self._gram_matrix @property - def representer_weights(self) -> np.ndarray: - if self._representer_weights is None: - y = np.concatenate( - [ - (Y - L(self._prior.mean)) - if b is None - else (Y - L(self._prior.mean) - b.mean.reshape(-1, order="C")) - for Y, L, b in zip(self._Ys, self._Ls, self._bs) - ], - axis=-1, - ) - self._representer_weights = self.gram.solve(y) + def abstract_solver(self) -> GPSolver: + return self._abstract_solver - return self._representer_weights + @property + def solver(self) -> ConcreteGPSolver: + return self._solver + + @property + def representer_weights(self) -> np.ndarray: + return self._solver.compute_representer_weights() class PriorPredictiveCrossCovariance(ProcessVectorCrossCovariance): def __init__( @@ -178,11 +179,11 @@ def __init__( self, prior_mean: JaxFunction, kLas: ConditionalGaussianProcess.PriorPredictiveCrossCovariance, - representer_weights: np.ndarray, + solver: ConcreteGPSolver, ): self._prior_mean = prior_mean self._kLas = kLas - self._representer_weights = representer_weights + self._solver = solver super().__init__( input_shape=self._prior_mean.input_shape, @@ -193,61 +194,14 @@ def _evaluate(self, x: np.ndarray) -> np.ndarray: m_x = self._prior_mean(x) kLas_x = self._kLas(x) - return m_x + kLas_x @ self._representer_weights + return m_x + kLas_x @ self._solver.compute_representer_weights() @functools.partial(jax.jit, static_argnums=0) def _evaluate_jax(self, x: jnp.ndarray) -> jnp.ndarray: m_x = self._prior_mean.jax(x) kLas_x = self._kLas.jax(x) - return m_x + kLas_x @ self._representer_weights - - class CovarianceFunction(JaxCovarianceFunction): - def __init__( - self, - prior_covfunc: JaxCovarianceFunction, - kLas: ConditionalGaussianProcess.PriorPredictiveCrossCovariance, - gram_matrix: pn.linops.LinearOperator, - ): - self._prior_covfunc = prior_covfunc - self._kLas = kLas - self._gram_matrix = gram_matrix - - super().__init__( - input_shape=self._prior_covfunc.input_shape, - output_shape_0=self._prior_covfunc.output_shape_0, - output_shape_1=self._prior_covfunc.output_shape_1, - ) - - def _evaluate(self, x0: np.ndarray, x1: np.ndarray | None) -> np.ndarray: - k_xx = self._prior_covfunc(x0, x1) - kLas_x0 = self._kLas(x0) - kLas_x1 = self._kLas(x1) if x1 is not None else kLas_x0 - cov_update = ( - kLas_x0[..., None, :] @ (self._gram_matrix.solve(kLas_x1[..., None])) - )[..., 0, 0] - - return k_xx - cov_update - - @functools.partial(jax.jit, static_argnums=0) - def _evaluate_jax(self, x0: jnp.ndarray, x1: jnp.ndarray | None) -> jnp.ndarray: - k_xx = self._prior_covfunc.jax(x0, x1) - kLas_x0 = self._kLas.jax(x0) - kLas_x1 = self._kLas.jax(x1) if x1 is not None else kLas_x0 - cov_update = ( - kLas_x0[..., None, :] - @ (self._gram_matrix.solve(kLas_x1[..., None]))[..., 0, 0] - ) - - return k_xx - cov_update - - def _evaluate_linop( - self, x0: np.ndarray, x1: Optional[np.ndarray] - ) -> pn.linops.LinearOperator: - k_xx = self._prior_covfunc.linop(x0, x1) - kLas_x0 = self._kLas.evaluate_linop(x0) - kLas_x1 = self._kLas.evaluate_linop(x1) if x1 is not None else kLas_x0 - return k_xx - kLas_x0 @ self._gram_matrix.solve(kLas_x1.T) + return m_x + kLas_x @ self._solver.compute_representer_weights() def condition_on_observations( self, @@ -256,8 +210,9 @@ def condition_on_observations( *, L: LinearFunctional | LinearFunctionOperator | None = None, b: RandomVariableLike | None = None, + solver: GPSolver = CholeskySolver(), ): - Y, L, b, kLa, pred_mean, gram = self._preprocess_observations( + Y, L, b, kLa, gram = self._preprocess_observations( prior=self._prior, Y=Y, X=X, @@ -278,18 +233,17 @@ def condition_on_observations( gram, is_spd=True, ) - representer_weights = gram_matrix.schur_update( - self.representer_weights, Y - pred_mean - ) + + kLas = self._kLas.append(kLa) return ConditionalGaussianProcess( prior=self._prior, Ys=self._Ys + (Y,), Ls=self._Ls + (L,), bs=self._bs + (b,), - kLas=self._kLas.append(kLa), + kLas=kLas, gram_matrix=gram_matrix, - representer_weights=representer_weights, + solver=solver, ) @classmethod @@ -356,11 +310,9 @@ def _preprocess_observations( Lf = L(prior) kLa = L(prior.cov, argnum=1) - # Compute predictive mean and covariance matrix - pred_mean = Lf.mean + # Compute predictive covariance matrix gram = Lf.cov - pred_mean = pred_mean.reshape(-1, order="C") # Check observations Y = np.asarray(Y) if ( @@ -389,13 +341,12 @@ def _preprocess_observations( assert Y.size == Lf.cov.shape[1] if b is not None: - pred_mean = pred_mean + np.asarray(b.mean).reshape(-1, order="C") gram = gram + pn.linops.aslinop(b.cov) gram.is_symmetric = True gram.is_positive_definite = True - return Y, L, b, kLa, pred_mean, gram + return Y, L, b, kLa, gram pn.randprocs.GaussianProcess.condition_on_observations = ( @@ -442,7 +393,8 @@ def _( bs=conditional_gp._bs, kLas=self(conditional_gp._kLas), gram_matrix=conditional_gp.gram, - representer_weights=conditional_gp.representer_weights, + solver=conditional_gp.abstract_solver, + full_representer_weights=conditional_gp.representer_weights, ) @@ -458,16 +410,7 @@ def _( crosscov = self(conditional_gp._kLas) mean = linfunctl_prior.mean + crosscov @ conditional_gp.representer_weights + # TODO: Make this compatible with non-Cholesky solvers? cov = linfunctl_prior.cov - crosscov @ conditional_gp.gram.inv() @ crosscov.T return pn.randvars.Normal(mean, cov) - - -def cho_solve(L, b): - """Fixes a bug in scipy.linalg.cho_solve""" - (L, lower) = L - - if L.shape == (1, 1) and b.shape[0] == 1: - return b / L[0, 0] ** 2 - - return scipy.linalg.cho_solve((L, lower), b) diff --git a/src/linpde_gp/randprocs/covfuncs/_zero.py b/src/linpde_gp/randprocs/covfuncs/_zero.py index 8ce951c8..3aadd039 100644 --- a/src/linpde_gp/randprocs/covfuncs/_zero.py +++ b/src/linpde_gp/randprocs/covfuncs/_zero.py @@ -1,5 +1,8 @@ +from typing import Optional + from jax import numpy as jnp import numpy as np +import probnum as pn from probnum.randprocs import covfuncs from ._jax import JaxCovarianceFunctionMixin @@ -25,3 +28,12 @@ def _evaluate_jax(self, x0: jnp.ndarray, x1: jnp.ndarray | None) -> jnp.ndarray: broadcast_batch_shape + self.output_shape_0 + self.output_shape_1, dtype=np.result_type(x0, x1), ) + + def _evaluate_linop( + self, x0: np.ndarray, x1: Optional[np.ndarray] + ) -> pn.linops.LinearOperator: + shape = ( + self.output_size_0 * x0.shape[0], + self.output_size_1 * (x1.shape[0] if x1 is not None else x0.shape[0]), + ) + return pn.linops.Zero(shape, np.promote_types(x0.dtype, x1.dtype)) diff --git a/src/linpde_gp/solvers/__init__.py b/src/linpde_gp/solvers/__init__.py new file mode 100644 index 00000000..151fbf12 --- /dev/null +++ b/src/linpde_gp/solvers/__init__.py @@ -0,0 +1,2 @@ +from ._cholesky import CholeskySolver +from ._gp_solver import ConcreteGPSolver, GPInferenceParams, GPSolver diff --git a/src/linpde_gp/solvers/_cholesky.py b/src/linpde_gp/solvers/_cholesky.py new file mode 100644 index 00000000..f60e83a3 --- /dev/null +++ b/src/linpde_gp/solvers/_cholesky.py @@ -0,0 +1,71 @@ +from jax import numpy as jnp +import numpy as np + +from linpde_gp.linops import BlockMatrix2x2 +from linpde_gp.randprocs.covfuncs import JaxCovarianceFunction + +from ._gp_solver import ConcreteGPSolver, GPInferenceParams, GPSolver +from .covfuncs import DowndateCovarianceFunction + + +class CholeskyCovarianceFunction(DowndateCovarianceFunction): + def __init__(self, gp_params: GPInferenceParams): + self._gp_params = gp_params + super().__init__(gp_params.prior.cov) + + def _downdate(self, x0: np.ndarray, x1: np.ndarray | None) -> np.ndarray: + kLas_x0 = self._gp_params.kLas(x0) + kLas_x1 = self._gp_params.kLas(x1) if x1 is not None else kLas_x0 + return ( + kLas_x0[..., None, :] + @ (self._gp_params.prior_gram.solve(kLas_x1[..., None])) + )[..., 0, 0] + + def _downdate_jax(self, x0: jnp.ndarray, x1: jnp.ndarray | None) -> jnp.ndarray: + kLas_x0 = self._gp_params.kLas.jax(x0) + kLas_x1 = self._gp_params.kLas.jax(x1) if x1 is not None else kLas_x0 + return ( + kLas_x0[..., None, :] + @ (self._gp_params.prior_gram.solve(kLas_x1[..., None])) + )[..., 0, 0] + + +class ConcreteCholeskySolver(ConcreteGPSolver): + """Concrete solver that uses the Cholesky decomposition. + + Uses a block Cholesky decomposition if possible. + """ + + def _compute_representer_weights(self): + if self._gp_params.prev_representer_weights is not None: + # Update existing representer weights + assert isinstance(self._gp_params.prior_gram, BlockMatrix2x2) + new_residual = self._get_residual( + self._gp_params.Ys[-1], self._gp_params.Ls[-1], self._gp_params.bs[-1] + ) + return self._gp_params.prior_gram.schur_update( + self._gp_params.prev_representer_weights, new_residual + ) + full_residual = self._get_full_residual() + return self._gp_params.prior_gram.solve(full_residual) + + @property + def posterior_cov(self) -> JaxCovarianceFunction: + return CholeskyCovarianceFunction(self._gp_params) + + def _save_state(self) -> dict: + # TODO: Actually save the Cholesky decomposition of the linear operator + state = {"representer_weights": self._representer_weights} + return state + + def _load_state(self, state: dict): + self._representer_weights = state["representer_weights"] + + +class CholeskySolver(GPSolver): + """Solver that uses the Cholesky decomposition.""" + + def get_concrete_solver( + self, gp_params: GPInferenceParams + ) -> ConcreteCholeskySolver: + return ConcreteCholeskySolver(gp_params, self._load_path, self._save_path) diff --git a/src/linpde_gp/solvers/_gp_solver.py b/src/linpde_gp/solvers/_gp_solver.py new file mode 100644 index 00000000..43dbff63 --- /dev/null +++ b/src/linpde_gp/solvers/_gp_solver.py @@ -0,0 +1,140 @@ +import abc +from collections.abc import Sequence +from dataclasses import dataclass +import pickle +from typing import Optional + +import numpy as np +import probnum as pn + +from linpde_gp.linfunctls import LinearFunctional +from linpde_gp.randprocs.covfuncs import JaxCovarianceFunction +from linpde_gp.randprocs.crosscov import ProcessVectorCrossCovariance + + +@dataclass +class GPInferenceParams: + """Parameters for affine Gaussian process inference.""" + + prior: pn.randprocs.GaussianProcess + prior_gram: pn.linops.LinearOperator + Ys: Sequence[np.ndarray] + Ls: Sequence[LinearFunctional] + bs: Sequence[pn.randvars.Normal | pn.randvars.Constant | None] + kLas: ProcessVectorCrossCovariance + prev_representer_weights: Optional[np.ndarray] + full_representer_weights: Optional[np.ndarray] + + +class ConcreteGPSolver(abc.ABC): + """Abstract base class for concrete Gaussian process solvers. + Concrete in the sense that we are dealing with one specific + instance of affine GP regression with a concrete GP and concrete + linear functionals.""" + + def __init__( + self, + gp_params: GPInferenceParams, + load_path: Optional[str] = None, + save_path: Optional[str] = None, + ): + self._gp_params = gp_params + self._load_path = load_path + self._save_path = save_path + + # Typically None, but in some cases (e.g. applying a linear function + # operator to a trained GP), the representer weights are already known + self._representer_weights = self._gp_params.full_representer_weights + + if self._load_path is not None: + self.load() + + @abc.abstractmethod + def _compute_representer_weights(self): + """Compute the representer weights.""" + raise NotImplementedError + + def compute_representer_weights(self): + """Compute representer weights, or directly return cached + result from previous computation. + """ + if self._representer_weights is None: + self._representer_weights = self._compute_representer_weights() + self.save() + return self._representer_weights + + @property + @abc.abstractmethod + def posterior_cov(self) -> JaxCovarianceFunction: + raise NotImplementedError + + def _get_residual(self, Y, L, b): + return np.reshape( + ( + (Y - L(self._gp_params.prior.mean).reshape(-1, order="C")) + if b is None + else ( + Y + - L(self._gp_params.prior.mean).reshape(-1, order="C") + - b.mean.reshape(-1, order="C") + ) + ), + (-1,), + order="C", + ) + + def _get_full_residual(self): + return np.concatenate( + [ + self._get_residual(Y, L, b) + for Y, L, b in zip( + self._gp_params.Ys, self._gp_params.Ls, self._gp_params.bs + ) + ], + axis=-1, + ) + + @abc.abstractmethod + def _save_state(self) -> dict: + """Save solver state to dict.""" + raise NotImplementedError + + @abc.abstractmethod + def _load_state(self, state: dict): + """Load solver state from dict.""" + raise NotImplementedError + + def save(self): + """Save solver state to file.""" + if self._save_path is None: + return + solver_state = self._save_state() + with open(self._save_path, "wb") as f: + pickle.dump(solver_state, f) + + def load(self): + """Load solver state from file.""" + if self._load_path is None: + return + with open(self._load_path, "rb") as f: + loaded_state = pickle.load(f) + self._load_state(loaded_state) + + +class GPSolver(abc.ABC): + """User-facing interface for Gaussian process solvers used to pass + hyperparameters. + """ + + def __init__( + self, load_path: Optional[str] = None, save_path: Optional[str] = None + ): + self._load_path = load_path + self._save_path = save_path + + @abc.abstractmethod + def get_concrete_solver(self, gp_params: GPInferenceParams) -> ConcreteGPSolver: + """Get concrete solver. + Subclasses must implement this method. + """ + raise NotImplementedError diff --git a/src/linpde_gp/solvers/covfuncs/__init__.py b/src/linpde_gp/solvers/covfuncs/__init__.py new file mode 100644 index 00000000..a3a0fc91 --- /dev/null +++ b/src/linpde_gp/solvers/covfuncs/__init__.py @@ -0,0 +1 @@ +from ._downdate import DowndateCovarianceFunction diff --git a/src/linpde_gp/solvers/covfuncs/_downdate.py b/src/linpde_gp/solvers/covfuncs/_downdate.py new file mode 100644 index 00000000..50172ac3 --- /dev/null +++ b/src/linpde_gp/solvers/covfuncs/_downdate.py @@ -0,0 +1,40 @@ +import abc +import functools + +import jax +import jax.numpy as jnp +import numpy as np + +from linpde_gp.randprocs.covfuncs import JaxCovarianceFunction + + +class DowndateCovarianceFunction(JaxCovarianceFunction): + """Covariance function that is obtained by downdating + a prior covariance function. + """ + + def __init__(self, prior_cov: JaxCovarianceFunction): + self._prior_cov = prior_cov + super().__init__( + input_shape_0=prior_cov.input_shape_0, + input_shape_1=prior_cov.input_shape_1, + output_shape_0=prior_cov.output_shape_0, + output_shape_1=prior_cov.output_shape_1, + ) + + @abc.abstractmethod + def _downdate(self, x0: np.ndarray, x1: np.ndarray | None) -> np.ndarray: + raise NotImplementedError + + @abc.abstractmethod + def _downdate_jax(self, x0: jnp.ndarray, x1: jnp.ndarray | None) -> jnp.ndarray: + raise NotImplementedError + + def _evaluate(self, x0: np.ndarray, x1: np.ndarray | None) -> np.ndarray: + k_xx = self._prior_cov(x0, x1) + return k_xx - self._downdate(x0, x1) + + @functools.partial(jax.jit, static_argnums=0) + def _evaluate_jax(self, x0: jnp.ndarray, x1: jnp.ndarray | None) -> jnp.ndarray: + k_xx = self._prior_cov.jax(x0, x1) + return k_xx - self._downdate_jax(x0, x1)