From ffdc82e204246a7880a062f6a55853a744f79cb4 Mon Sep 17 00:00:00 2001 From: Tim Weiland Date: Tue, 21 Mar 2023 12:55:52 +0100 Subject: [PATCH 1/9] Bugfix in ScaledLinearOperator --- src/linpde_gp/linfuncops/_arithmetic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/linpde_gp/linfuncops/_arithmetic.py b/src/linpde_gp/linfuncops/_arithmetic.py index 6c5b7941..265fb667 100644 --- a/src/linpde_gp/linfuncops/_arithmetic.py +++ b/src/linpde_gp/linfuncops/_arithmetic.py @@ -46,7 +46,7 @@ def _( def __rmul__(self, other) -> LinearFunctionOperator: if np.ndim(other) == 0: return ScaledLinearFunctionOperator( - linfuncop=self._linfuncop, + self._linfuncop, scalar=np.asarray(other) * self._scalar, ) From 2146f12cc52f25486ec52f6b7ef33b0ff09ea6ad Mon Sep 17 00:00:00 2001 From: Tim Weiland Date: Tue, 21 Mar 2023 12:56:17 +0100 Subject: [PATCH 2/9] Linop impl for Zero CovarianceFunction --- src/linpde_gp/randprocs/covfuncs/_zero.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/linpde_gp/randprocs/covfuncs/_zero.py b/src/linpde_gp/randprocs/covfuncs/_zero.py index 8ce951c8..050df075 100644 --- a/src/linpde_gp/randprocs/covfuncs/_zero.py +++ b/src/linpde_gp/randprocs/covfuncs/_zero.py @@ -1,5 +1,7 @@ +from typing import Optional from jax import numpy as jnp import numpy as np +import probnum as pn from probnum.randprocs import covfuncs from ._jax import JaxCovarianceFunctionMixin @@ -25,3 +27,10 @@ def _evaluate_jax(self, x0: jnp.ndarray, x1: jnp.ndarray | None) -> jnp.ndarray: broadcast_batch_shape + self.output_shape_0 + self.output_shape_1, dtype=np.result_type(x0, x1), ) + + def _evaluate_linop(self, x0: np.ndarray, x1: Optional[np.ndarray]) -> pn.linops.LinearOperator: + shape = ( + self.output_size_0 * x0.shape[0], + self.output_size_1 * (x1.shape[0] if x1 is not None else x0.shape[0]), + ) + return pn.linops.Zero(shape, np.promote_types(x0.dtype, x1.dtype)) \ No newline at end of file From 12ea37a28303259843daa398abdd82e8afdecab2 Mon Sep 17 00:00:00 2001 From: Tim Weiland Date: Tue, 21 Mar 2023 13:02:13 +0100 Subject: [PATCH 3/9] Add GPSolver Provides an abstract framework for GP solvers that provide their own methods for calculating the representer weights and the posterior covariance. --- .../_gaussian_process/solvers/_gp_solver.py | 75 +++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 src/linpde_gp/randprocs/_gaussian_process/solvers/_gp_solver.py diff --git a/src/linpde_gp/randprocs/_gaussian_process/solvers/_gp_solver.py b/src/linpde_gp/randprocs/_gaussian_process/solvers/_gp_solver.py new file mode 100644 index 00000000..4bfa4d37 --- /dev/null +++ b/src/linpde_gp/randprocs/_gaussian_process/solvers/_gp_solver.py @@ -0,0 +1,75 @@ +import abc +from collections.abc import Sequence + +import numpy as np +import probnum as pn +from linpde_gp.linops import BlockMatrix2x2 +from linpde_gp.linfunctls import LinearFunctional +from dataclasses import dataclass + + +@dataclass +class GPInferenceParams: + prior_mean: pn.functions.Function + prior_gram: pn.linops.LinearOperator + Ys: Sequence[np.ndarray] + Ls: Sequence[LinearFunctional] + bs: Sequence[pn.randvars.Normal | pn.randvars.Constant | None] + prior_representer_weights: np.ndarray + + +class ConcreteGPSolver(abc.ABC): + def __init__(self, gp_params: GPInferenceParams): + self._gp_params = gp_params + + def compute_representer_weights(self): + try: + return self._representer_weights + except AttributeError: + self._representer_weights = self._compute_representer_weights() + return self._representer_weights + + @abc.abstractmethod + def _compute_representer_weights(self): + raise NotImplementedError + + def _get_residual(self, Y, L, b): + return np.reshape( + ( + (Y - L(self._gp_params.prior_mean).reshape(-1, order="C")) + if b is None + else ( + Y + - L(self._gp_params.prior_mean).reshape(-1, order="C") + - b.mean.reshape(-1, order="C") + ) + ), + (-1,), + order="C", + ) + + def _get_full_residual(self): + return np.concatenate( + [ + self._get_residual(Y, L, b) + for Y, L, b in zip( + self._gp_params.Ys, self._gp_params.Ls, self._gp_params.bs + ) + ], + axis=-1, + ) + + @abc.abstractmethod + def compute_posterior_cov( + self, k_xx: np.ndarray, k_xX: np.ndarray, k_Xx: np.ndarray + ): + raise NotImplementedError + + +class GPSolver(abc.ABC): + def __init__(self): + pass + + @abc.abstractmethod + def get_concrete_solver(self, gp_params: GPInferenceParams) -> ConcreteGPSolver: + raise NotImplementedError From 086167ab74a9c805a7223e6323c5e502599fcd4f Mon Sep 17 00:00:00 2001 From: Tim Weiland Date: Tue, 21 Mar 2023 13:04:05 +0100 Subject: [PATCH 4/9] Add CholeskySolver --- .../_gaussian_process/solvers/_cholesky.py | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 src/linpde_gp/randprocs/_gaussian_process/solvers/_cholesky.py diff --git a/src/linpde_gp/randprocs/_gaussian_process/solvers/_cholesky.py b/src/linpde_gp/randprocs/_gaussian_process/solvers/_cholesky.py new file mode 100644 index 00000000..d6d64a6b --- /dev/null +++ b/src/linpde_gp/randprocs/_gaussian_process/solvers/_cholesky.py @@ -0,0 +1,43 @@ +from ._gp_solver import GPSolver, ConcreteGPSolver, GPInferenceParams +import probnum as pn +import numpy as np +from linpde_gp.linops import BlockMatrix2x2 + + +class ConcreteCholeskySolver(ConcreteGPSolver): + def __init__(self, gp_params: GPInferenceParams): + super().__init__(gp_params) + + def _compute_representer_weights(self): + if self._gp_params.prior_representer_weights is not None: + # Update existing representer weights + assert isinstance(self._gp_params.prior_gram, BlockMatrix2x2) + new_residual = self._get_residual( + self._gp_params.Ys[-1], self._gp_params.Ls[-1], self._gp_params.bs[-1] + ) + return self._gp_params.prior_gram.schur_update( + self._gp_params.prior_representer_weights, new_residual + ) + full_residual = self._get_full_residual() + return self._gp_params.prior_gram.solve(full_residual) + + def compute_posterior_cov( + self, k_xx: np.ndarray, k_x0_X: np.ndarray, k_x1_X: np.ndarray + ): + return ( + k_xx + - ( + k_x0_X[..., None, :] + @ (self._gp_params.prior_gram.solve(k_x1_X[..., None])) + )[..., 0, 0] + ) + + +class CholeskySolver(GPSolver): + def __init__(self): + super().__init__() + + def get_concrete_solver( + self, gp_params: GPInferenceParams + ) -> ConcreteCholeskySolver: + return ConcreteCholeskySolver(gp_params) From 3274df1d621d2ab62eca177d4f8bc78f98d0cccb Mon Sep 17 00:00:00 2001 From: Tim Weiland Date: Tue, 21 Mar 2023 13:05:04 +0100 Subject: [PATCH 5/9] Use abstract solvers for inference --- .../_gaussian_process/_conditional.py | 132 +++++----------- .../_gaussian_process/solvers/_cholesky.py | 43 ------ .../_gaussian_process/solvers/_gp_solver.py | 75 --------- src/linpde_gp/solvers/__init__.py | 2 + src/linpde_gp/solvers/_cholesky.py | 84 ++++++++++ src/linpde_gp/solvers/_gp_solver.py | 143 ++++++++++++++++++ src/linpde_gp/solvers/covfuncs/__init__.py | 1 + src/linpde_gp/solvers/covfuncs/_downdate.py | 40 +++++ 8 files changed, 311 insertions(+), 209 deletions(-) delete mode 100644 src/linpde_gp/randprocs/_gaussian_process/solvers/_cholesky.py delete mode 100644 src/linpde_gp/randprocs/_gaussian_process/solvers/_gp_solver.py create mode 100644 src/linpde_gp/solvers/__init__.py create mode 100644 src/linpde_gp/solvers/_cholesky.py create mode 100644 src/linpde_gp/solvers/_gp_solver.py create mode 100644 src/linpde_gp/solvers/covfuncs/__init__.py create mode 100644 src/linpde_gp/solvers/covfuncs/_downdate.py diff --git a/src/linpde_gp/randprocs/_gaussian_process/_conditional.py b/src/linpde_gp/randprocs/_gaussian_process/_conditional.py index 7bcd75f9..aa18bb62 100644 --- a/src/linpde_gp/randprocs/_gaussian_process/_conditional.py +++ b/src/linpde_gp/randprocs/_gaussian_process/_conditional.py @@ -2,23 +2,26 @@ from collections.abc import Iterator, Sequence import functools -from typing import Optional import jax import jax.numpy as jnp import numpy as np from numpy.typing import ArrayLike import probnum as pn -import scipy.linalg from linpde_gp import linfunctls from linpde_gp.functions import JaxFunction from linpde_gp.linfuncops import LinearFunctionOperator from linpde_gp.linfunctls import LinearFunctional from linpde_gp.linops import BlockMatrix, BlockMatrix2x2 -from linpde_gp.randprocs.covfuncs import JaxCovarianceFunction from linpde_gp.randprocs.crosscov import ProcessVectorCrossCovariance from linpde_gp.typing import RandomVariableLike +from linpde_gp.solvers import ( + GPSolver, + GPInferenceParams, + ConcreteGPSolver, + CholeskySolver, +) class ConditionalGaussianProcess(pn.randprocs.GaussianProcess): @@ -31,6 +34,7 @@ def from_observations( *, L: None | LinearFunctional | LinearFunctionOperator = None, b: None | RandomVariableLike = None, + solver: GPSolver = CholeskySolver(), ): Y, L, b, kLa, Lm, gram = cls._preprocess_observations( prior=prior, @@ -40,16 +44,18 @@ def from_observations( b=b, ) - representer_weights = gram.solve(Y - Lm) + kLas = ConditionalGaussianProcess._PriorPredictiveCrossCovariance((kLa,)) + inference_params = GPInferenceParams(prior, gram, (Y,), (L,), (b,), kLas, None) + concrete_solver = solver.get_concrete_solver(inference_params) return cls( prior=prior, Ys=(Y,), Ls=(L,), bs=(b,), - kLas=ConditionalGaussianProcess._PriorPredictiveCrossCovariance((kLa,)), + kLas=kLas, gram_matrix=gram, - representer_weights=representer_weights, + solver=concrete_solver, ) def __init__( @@ -61,7 +67,7 @@ def __init__( bs: Sequence[pn.randvars.Normal | pn.randvars.Constant | None], kLas: ConditionalGaussianProcess._PriorPredictiveCrossCovariance, gram_matrix: pn.linops.LinearOperator, - representer_weights: np.ndarray | None = None, + solver: ConcreteGPSolver, ): self._prior = prior @@ -72,39 +78,29 @@ def __init__( self._kLas = kLas self._gram_matrix = gram_matrix - - self._representer_weights = representer_weights + self._solver = solver super().__init__( mean=ConditionalGaussianProcess.Mean( prior_mean=self._prior.mean, kLas=self._kLas, - representer_weights=self.representer_weights, - ), - cov=ConditionalGaussianProcess.CovarianceFunction( - prior_covfunc=self._prior.cov, - kLas=self._kLas, - gram_matrix=self.gram, + solver=solver, ), + cov=self._solver.posterior_cov, ) @functools.cached_property def gram(self) -> pn.linops.LinearOperator: return self._gram_matrix + @property + def solver(self) -> ConcreteGPSolver: + return self._solver + @property def representer_weights(self) -> np.ndarray: if self._representer_weights is None: - y = np.concatenate( - [ - (Y - L(self._prior.mean)) - if b is None - else (Y - L(self._prior.mean) - b.mean.reshape(-1, order="C")) - for Y, L, b in zip(self._Ys, self._Ls, self._bs) - ], - axis=-1, - ) - self._representer_weights = self.gram.solve(y) + self._representer_weights = self._solver.compute_representer_weights() return self._representer_weights @@ -178,11 +174,11 @@ def __init__( self, prior_mean: JaxFunction, kLas: ConditionalGaussianProcess._PriorPredictiveCrossCovariance, - representer_weights: np.ndarray, + solver: ConcreteGPSolver, ): self._prior_mean = prior_mean self._kLas = kLas - self._representer_weights = representer_weights + self._solver = solver super().__init__( input_shape=self._prior_mean.input_shape, @@ -193,61 +189,14 @@ def _evaluate(self, x: np.ndarray) -> np.ndarray: m_x = self._prior_mean(x) kLas_x = self._kLas(x) - return m_x + kLas_x @ self._representer_weights + return m_x + kLas_x @ self._solver.compute_representer_weights() @functools.partial(jax.jit, static_argnums=0) def _evaluate_jax(self, x: jnp.ndarray) -> jnp.ndarray: m_x = self._prior_mean.jax(x) kLas_x = self._kLas.jax(x) - return m_x + kLas_x @ self._representer_weights - - class CovarianceFunction(JaxCovarianceFunction): - def __init__( - self, - prior_covfunc: JaxCovarianceFunction, - kLas: ConditionalGaussianProcess._PriorPredictiveCrossCovariance, - gram_matrix: pn.linops.LinearOperator, - ): - self._prior_covfunc = prior_covfunc - self._kLas = kLas - self._gram_matrix = gram_matrix - - super().__init__( - input_shape=self._prior_covfunc.input_shape, - output_shape_0=self._prior_covfunc.output_shape_0, - output_shape_1=self._prior_covfunc.output_shape_1, - ) - - def _evaluate(self, x0: np.ndarray, x1: np.ndarray | None) -> np.ndarray: - k_xx = self._prior_covfunc(x0, x1) - kLas_x0 = self._kLas(x0) - kLas_x1 = self._kLas(x1) if x1 is not None else kLas_x0 - cov_update = ( - kLas_x0[..., None, :] @ (self._gram_matrix.solve(kLas_x1[..., None])) - )[..., 0, 0] - - return k_xx - cov_update - - @functools.partial(jax.jit, static_argnums=0) - def _evaluate_jax(self, x0: jnp.ndarray, x1: jnp.ndarray | None) -> jnp.ndarray: - k_xx = self._prior_covfunc.jax(x0, x1) - kLas_x0 = self._kLas.jax(x0) - kLas_x1 = self._kLas.jax(x1) if x1 is not None else kLas_x0 - cov_update = ( - kLas_x0[..., None, :] - @ (self._gram_matrix.solve(kLas_x1[..., None]))[..., 0, 0] - ) - - return k_xx - cov_update - - def _evaluate_linop( - self, x0: np.ndarray, x1: Optional[np.ndarray] - ) -> pn.linops.LinearOperator: - k_xx = self._prior_covfunc.linop(x0, x1) - kLas_x0 = self._kLas.evaluate_linop(x0) - kLas_x1 = self._kLas.evaluate_linop(x1) if x1 is not None else kLas_x0 - return k_xx - kLas_x0 @ self._gram_matrix.solve(kLas_x1.T) + return m_x + kLas_x @ self._solver.compute_representer_weights() def condition_on_observations( self, @@ -256,6 +205,7 @@ def condition_on_observations( *, L: LinearFunctional | LinearFunctionOperator | None = None, b: RandomVariableLike | None = None, + solver: GPSolver = CholeskySolver(), ): Y, L, b, kLa, pred_mean, gram = self._preprocess_observations( prior=self._prior, @@ -278,18 +228,27 @@ def condition_on_observations( gram, is_spd=True, ) - representer_weights = gram_matrix.schur_update( - self.representer_weights, Y - pred_mean + + kLas = self._kLas.append(kLa) + inference_params = GPInferenceParams( + self._prior, + gram_matrix, + self._Ys + (Y,), + self._Ls + (L,), + self._bs + (b,), + kLas, + None, ) + concrete_solver = solver.get_concrete_solver(inference_params) return ConditionalGaussianProcess( prior=self._prior, Ys=self._Ys + (Y,), Ls=self._Ls + (L,), bs=self._bs + (b,), - kLas=self._kLas.append(kLa), + kLas=kLas, gram_matrix=gram_matrix, - representer_weights=representer_weights, + solver=concrete_solver, ) @classmethod @@ -438,7 +397,7 @@ def _( bs=conditional_gp._bs, kLas=self(conditional_gp._kLas), gram_matrix=conditional_gp.gram, - representer_weights=conditional_gp.representer_weights, + solver=conditional_gp.solver, ) @@ -454,16 +413,7 @@ def _( crosscov = self(conditional_gp._kLas) mean = linfunctl_prior.mean + crosscov @ conditional_gp.representer_weights + # TODO: Make this compatible with non-Cholesky solvers? cov = linfunctl_prior.cov - crosscov @ conditional_gp.gram.inv() @ crosscov.T return pn.randvars.Normal(mean, cov) - - -def cho_solve(L, b): - """Fixes a bug in scipy.linalg.cho_solve""" - (L, lower) = L - - if L.shape == (1, 1) and b.shape[0] == 1: - return b / L[0, 0] ** 2 - - return scipy.linalg.cho_solve((L, lower), b) diff --git a/src/linpde_gp/randprocs/_gaussian_process/solvers/_cholesky.py b/src/linpde_gp/randprocs/_gaussian_process/solvers/_cholesky.py deleted file mode 100644 index d6d64a6b..00000000 --- a/src/linpde_gp/randprocs/_gaussian_process/solvers/_cholesky.py +++ /dev/null @@ -1,43 +0,0 @@ -from ._gp_solver import GPSolver, ConcreteGPSolver, GPInferenceParams -import probnum as pn -import numpy as np -from linpde_gp.linops import BlockMatrix2x2 - - -class ConcreteCholeskySolver(ConcreteGPSolver): - def __init__(self, gp_params: GPInferenceParams): - super().__init__(gp_params) - - def _compute_representer_weights(self): - if self._gp_params.prior_representer_weights is not None: - # Update existing representer weights - assert isinstance(self._gp_params.prior_gram, BlockMatrix2x2) - new_residual = self._get_residual( - self._gp_params.Ys[-1], self._gp_params.Ls[-1], self._gp_params.bs[-1] - ) - return self._gp_params.prior_gram.schur_update( - self._gp_params.prior_representer_weights, new_residual - ) - full_residual = self._get_full_residual() - return self._gp_params.prior_gram.solve(full_residual) - - def compute_posterior_cov( - self, k_xx: np.ndarray, k_x0_X: np.ndarray, k_x1_X: np.ndarray - ): - return ( - k_xx - - ( - k_x0_X[..., None, :] - @ (self._gp_params.prior_gram.solve(k_x1_X[..., None])) - )[..., 0, 0] - ) - - -class CholeskySolver(GPSolver): - def __init__(self): - super().__init__() - - def get_concrete_solver( - self, gp_params: GPInferenceParams - ) -> ConcreteCholeskySolver: - return ConcreteCholeskySolver(gp_params) diff --git a/src/linpde_gp/randprocs/_gaussian_process/solvers/_gp_solver.py b/src/linpde_gp/randprocs/_gaussian_process/solvers/_gp_solver.py deleted file mode 100644 index 4bfa4d37..00000000 --- a/src/linpde_gp/randprocs/_gaussian_process/solvers/_gp_solver.py +++ /dev/null @@ -1,75 +0,0 @@ -import abc -from collections.abc import Sequence - -import numpy as np -import probnum as pn -from linpde_gp.linops import BlockMatrix2x2 -from linpde_gp.linfunctls import LinearFunctional -from dataclasses import dataclass - - -@dataclass -class GPInferenceParams: - prior_mean: pn.functions.Function - prior_gram: pn.linops.LinearOperator - Ys: Sequence[np.ndarray] - Ls: Sequence[LinearFunctional] - bs: Sequence[pn.randvars.Normal | pn.randvars.Constant | None] - prior_representer_weights: np.ndarray - - -class ConcreteGPSolver(abc.ABC): - def __init__(self, gp_params: GPInferenceParams): - self._gp_params = gp_params - - def compute_representer_weights(self): - try: - return self._representer_weights - except AttributeError: - self._representer_weights = self._compute_representer_weights() - return self._representer_weights - - @abc.abstractmethod - def _compute_representer_weights(self): - raise NotImplementedError - - def _get_residual(self, Y, L, b): - return np.reshape( - ( - (Y - L(self._gp_params.prior_mean).reshape(-1, order="C")) - if b is None - else ( - Y - - L(self._gp_params.prior_mean).reshape(-1, order="C") - - b.mean.reshape(-1, order="C") - ) - ), - (-1,), - order="C", - ) - - def _get_full_residual(self): - return np.concatenate( - [ - self._get_residual(Y, L, b) - for Y, L, b in zip( - self._gp_params.Ys, self._gp_params.Ls, self._gp_params.bs - ) - ], - axis=-1, - ) - - @abc.abstractmethod - def compute_posterior_cov( - self, k_xx: np.ndarray, k_xX: np.ndarray, k_Xx: np.ndarray - ): - raise NotImplementedError - - -class GPSolver(abc.ABC): - def __init__(self): - pass - - @abc.abstractmethod - def get_concrete_solver(self, gp_params: GPInferenceParams) -> ConcreteGPSolver: - raise NotImplementedError diff --git a/src/linpde_gp/solvers/__init__.py b/src/linpde_gp/solvers/__init__.py new file mode 100644 index 00000000..ee450e2a --- /dev/null +++ b/src/linpde_gp/solvers/__init__.py @@ -0,0 +1,2 @@ +from ._gp_solver import GPInferenceParams, GPSolver, ConcreteGPSolver +from ._cholesky import CholeskySolver diff --git a/src/linpde_gp/solvers/_cholesky.py b/src/linpde_gp/solvers/_cholesky.py new file mode 100644 index 00000000..f5810742 --- /dev/null +++ b/src/linpde_gp/solvers/_cholesky.py @@ -0,0 +1,84 @@ +from ._gp_solver import GPSolver, ConcreteGPSolver, GPInferenceParams +from .covfuncs import DowndateCovarianceFunction +import numpy as np +from jax import numpy as jnp +from linpde_gp.linops import BlockMatrix2x2 +from linpde_gp.randprocs.covfuncs import JaxCovarianceFunction +from typing import Optional + + +class CholeskyCovarianceFunction(DowndateCovarianceFunction): + def __init__(self, gp_params: GPInferenceParams): + self._gp_params = gp_params + super().__init__(gp_params.prior.cov) + + def _downdate(self, x0: np.ndarray, x1: np.ndarray | None) -> np.ndarray: + kLas_x0 = self._gp_params.kLas(x0) + kLas_x1 = self._gp_params.kLas(x1) if x1 is not None else kLas_x0 + return ( + kLas_x0[..., None, :] + @ (self._gp_params.prior_gram.solve(kLas_x1[..., None])) + )[..., 0, 0] + + def _downdate_jax(self, x0: jnp.ndarray, x1: jnp.ndarray | None) -> jnp.ndarray: + kLas_x0 = self._gp_params.kLas.jax(x0) + kLas_x1 = self._gp_params.kLas.jax(x1) if x1 is not None else kLas_x0 + return ( + kLas_x0[..., None, :] + @ (self._gp_params.prior_gram.solve(kLas_x1[..., None])) + )[..., 0, 0] + + +class ConcreteCholeskySolver(ConcreteGPSolver): + """ + Concrete solver that uses the Cholesky decomposition. + + Uses a block Cholesky decomposition if possible. + """ + + def __init__( + self, + gp_params: GPInferenceParams, + load_path: Optional[str] = None, + save_path: Optional[str] = None, + ): + super().__init__(gp_params, load_path, save_path) + + def _compute_representer_weights(self): + if self._gp_params.prior_representer_weights is not None: + # Update existing representer weights + assert isinstance(self._gp_params.prior_gram, BlockMatrix2x2) + new_residual = self._get_residual( + self._gp_params.Ys[-1], self._gp_params.Ls[-1], self._gp_params.bs[-1] + ) + return self._gp_params.prior_gram.schur_update( + self._gp_params.prior_representer_weights, new_residual + ) + full_residual = self._get_full_residual() + return self._gp_params.prior_gram.solve(full_residual) + + @property + def posterior_cov(self) -> JaxCovarianceFunction: + return CholeskyCovarianceFunction(self._gp_params) + + def _save_state(self) -> dict: + # TODO: Actually save the Cholesky decomposition of the linear operator + state = {"representer_weights": self._representer_weights} + return state + + def _load_state(self, dict): + self._representer_weights = dict["representer_weights"] + + +class CholeskySolver(GPSolver): + """Solver that uses the Cholesky decomposition.""" + + def __init__( + self, load_path: Optional[str] = None, save_path: Optional[str] = None + ): + super().__init__(load_path, save_path) + + def get_concrete_solver( + self, gp_params: GPInferenceParams + ) -> ConcreteCholeskySolver: + return ConcreteCholeskySolver(gp_params, self._load_path, self._save_path) diff --git a/src/linpde_gp/solvers/_gp_solver.py b/src/linpde_gp/solvers/_gp_solver.py new file mode 100644 index 00000000..02c23842 --- /dev/null +++ b/src/linpde_gp/solvers/_gp_solver.py @@ -0,0 +1,143 @@ +import abc +from collections.abc import Sequence +from typing import Optional + +import pickle +import numpy as np +import probnum as pn +from linpde_gp.linfunctls import LinearFunctional +from linpde_gp.randprocs.covfuncs import JaxCovarianceFunction +from linpde_gp.randprocs.crosscov import ProcessVectorCrossCovariance +from dataclasses import dataclass + + +@dataclass +class GPInferenceParams: + """ + Parameters for affine Gaussian process inference. + """ + + prior: pn.randprocs.GaussianProcess + prior_gram: pn.linops.LinearOperator + Ys: Sequence[np.ndarray] + Ls: Sequence[LinearFunctional] + bs: Sequence[pn.randvars.Normal | pn.randvars.Constant | None] + kLas: ProcessVectorCrossCovariance + prior_representer_weights: np.ndarray + + +class ConcreteGPSolver(abc.ABC): + """Abstract base class for concrete Gaussian process solvers. + Concrete in the sense that we are dealing with one specific + instance of affine GP regression with a concrete GP and concrete + linear functionals.""" + + def __init__( + self, + gp_params: GPInferenceParams, + load_path: Optional[str] = None, + save_path: Optional[str] = None, + ): + self._gp_params = gp_params + self._load_path = load_path + self._save_path = save_path + + self._representer_weights = None + + if self._load_path is not None: + self.load() + + @abc.abstractmethod + def _compute_representer_weights(self): + """ + Compute the representer weights. + """ + raise NotImplementedError + + def compute_representer_weights(self): + """ + Compute representer weights, or directly return cached + result from previous computation. + """ + if self._representer_weights is None: + self._representer_weights = self._compute_representer_weights() + self.save() + return self._representer_weights + + @property + @abc.abstractmethod + def posterior_cov(self) -> JaxCovarianceFunction: + raise NotImplementedError + + def _get_residual(self, Y, L, b): + return np.reshape( + ( + (Y - L(self._gp_params.prior.mean).reshape(-1, order="C")) + if b is None + else ( + Y + - L(self._gp_params.prior.mean).reshape(-1, order="C") + - b.mean.reshape(-1, order="C") + ) + ), + (-1,), + order="C", + ) + + def _get_full_residual(self): + return np.concatenate( + [ + self._get_residual(Y, L, b) + for Y, L, b in zip( + self._gp_params.Ys, self._gp_params.Ls, self._gp_params.bs + ) + ], + axis=-1, + ) + + @abc.abstractmethod + def _save_state(self) -> dict: + """Save solver state to dict.""" + raise NotImplementedError + + @abc.abstractmethod + def _load_state(self, dict): + """Load solver state from dict.""" + raise NotImplementedError + + def save(self): + """Save solver state to file.""" + if self._save_path is None: + return + solver_state = self._save_state() + with open(self._save_path, "wb") as f: + pickle.dump(solver_state, f) + + def load(self): + """Load solver state from file.""" + if self._load_path is None: + return + with open(self._load_path, "rb") as f: + loaded_state = pickle.load(f) + self._load_state(loaded_state) + + +class GPSolver(abc.ABC): + """ + User-facing interface for Gaussian process solvers used to pass + hyperparameters. + """ + + def __init__( + self, load_path: Optional[str] = None, save_path: Optional[str] = None + ): + self._load_path = load_path + self._save_path = save_path + + @abc.abstractmethod + def get_concrete_solver(self, gp_params: GPInferenceParams) -> ConcreteGPSolver: + """ + Get concrete solver. + Subclasses must implement this method. + """ + raise NotImplementedError diff --git a/src/linpde_gp/solvers/covfuncs/__init__.py b/src/linpde_gp/solvers/covfuncs/__init__.py new file mode 100644 index 00000000..fe0146fe --- /dev/null +++ b/src/linpde_gp/solvers/covfuncs/__init__.py @@ -0,0 +1 @@ +from ._downdate import DowndateCovarianceFunction \ No newline at end of file diff --git a/src/linpde_gp/solvers/covfuncs/_downdate.py b/src/linpde_gp/solvers/covfuncs/_downdate.py new file mode 100644 index 00000000..21d414bd --- /dev/null +++ b/src/linpde_gp/solvers/covfuncs/_downdate.py @@ -0,0 +1,40 @@ +import abc +import functools +from typing import Optional +import numpy as np +import jax +import jax.numpy as jnp +from linpde_gp.randprocs.covfuncs import JaxCovarianceFunction +from probnum.typing import ShapeLike + + +class DowndateCovarianceFunction(JaxCovarianceFunction): + """ + Covariance function that is obtained by downdating a prior covariance function. + """ + + def __init__(self, prior_cov: JaxCovarianceFunction): + self._prior_cov = prior_cov + super().__init__( + input_shape_0=prior_cov.input_shape_0, + input_shape_1=prior_cov.input_shape_1, + output_shape_0=prior_cov.output_shape_0, + output_shape_1=prior_cov.output_shape_1, + ) + + @abc.abstractmethod + def _downdate(self, x0: np.ndarray, x1: np.ndarray | None) -> np.ndarray: + raise NotImplementedError + + @abc.abstractmethod + def _downdate_jax(self, x0: jnp.ndarray, x1: jnp.ndarray | None) -> jnp.ndarray: + raise NotImplementedError + + def _evaluate(self, x0: np.ndarray, x1: np.ndarray | None) -> np.ndarray: + k_xx = self._prior_cov(x0, x1) + return k_xx - self._downdate(x0, x1) + + @functools.partial(jax.jit, static_argnums=0) + def _evaluate_jax(self, x0: jnp.ndarray, x1: jnp.ndarray | None) -> jnp.ndarray: + k_xx = self._prior_cov.jax(x0, x1) + return k_xx - self._downdate_jax(x0, x1) From 093f37cd9ae61f72b300351dcd0395ebab7221c0 Mon Sep 17 00:00:00 2001 From: Tim Weiland Date: Wed, 3 May 2023 15:49:59 +0200 Subject: [PATCH 6/9] Sort imports --- .../randprocs/_gaussian_process/_conditional.py | 8 ++++---- src/linpde_gp/randprocs/covfuncs/_zero.py | 7 +++++-- src/linpde_gp/solvers/__init__.py | 2 +- src/linpde_gp/solvers/_cholesky.py | 11 +++++++---- src/linpde_gp/solvers/_gp_solver.py | 5 +++-- src/linpde_gp/solvers/covfuncs/__init__.py | 2 +- src/linpde_gp/solvers/covfuncs/_downdate.py | 6 +++--- 7 files changed, 24 insertions(+), 17 deletions(-) diff --git a/src/linpde_gp/randprocs/_gaussian_process/_conditional.py b/src/linpde_gp/randprocs/_gaussian_process/_conditional.py index 48db8bd3..b425a7b1 100644 --- a/src/linpde_gp/randprocs/_gaussian_process/_conditional.py +++ b/src/linpde_gp/randprocs/_gaussian_process/_conditional.py @@ -15,13 +15,13 @@ from linpde_gp.linfunctls import LinearFunctional from linpde_gp.linops import BlockMatrix, BlockMatrix2x2 from linpde_gp.randprocs.crosscov import ProcessVectorCrossCovariance -from linpde_gp.typing import RandomVariableLike from linpde_gp.solvers import ( - GPSolver, - GPInferenceParams, - ConcreteGPSolver, CholeskySolver, + ConcreteGPSolver, + GPInferenceParams, + GPSolver, ) +from linpde_gp.typing import RandomVariableLike class ConditionalGaussianProcess(pn.randprocs.GaussianProcess): diff --git a/src/linpde_gp/randprocs/covfuncs/_zero.py b/src/linpde_gp/randprocs/covfuncs/_zero.py index 050df075..3aadd039 100644 --- a/src/linpde_gp/randprocs/covfuncs/_zero.py +++ b/src/linpde_gp/randprocs/covfuncs/_zero.py @@ -1,4 +1,5 @@ from typing import Optional + from jax import numpy as jnp import numpy as np import probnum as pn @@ -28,9 +29,11 @@ def _evaluate_jax(self, x0: jnp.ndarray, x1: jnp.ndarray | None) -> jnp.ndarray: dtype=np.result_type(x0, x1), ) - def _evaluate_linop(self, x0: np.ndarray, x1: Optional[np.ndarray]) -> pn.linops.LinearOperator: + def _evaluate_linop( + self, x0: np.ndarray, x1: Optional[np.ndarray] + ) -> pn.linops.LinearOperator: shape = ( self.output_size_0 * x0.shape[0], self.output_size_1 * (x1.shape[0] if x1 is not None else x0.shape[0]), ) - return pn.linops.Zero(shape, np.promote_types(x0.dtype, x1.dtype)) \ No newline at end of file + return pn.linops.Zero(shape, np.promote_types(x0.dtype, x1.dtype)) diff --git a/src/linpde_gp/solvers/__init__.py b/src/linpde_gp/solvers/__init__.py index ee450e2a..151fbf12 100644 --- a/src/linpde_gp/solvers/__init__.py +++ b/src/linpde_gp/solvers/__init__.py @@ -1,2 +1,2 @@ -from ._gp_solver import GPInferenceParams, GPSolver, ConcreteGPSolver from ._cholesky import CholeskySolver +from ._gp_solver import ConcreteGPSolver, GPInferenceParams, GPSolver diff --git a/src/linpde_gp/solvers/_cholesky.py b/src/linpde_gp/solvers/_cholesky.py index f5810742..589e7c62 100644 --- a/src/linpde_gp/solvers/_cholesky.py +++ b/src/linpde_gp/solvers/_cholesky.py @@ -1,10 +1,13 @@ -from ._gp_solver import GPSolver, ConcreteGPSolver, GPInferenceParams -from .covfuncs import DowndateCovarianceFunction -import numpy as np +from typing import Optional + from jax import numpy as jnp +import numpy as np + from linpde_gp.linops import BlockMatrix2x2 from linpde_gp.randprocs.covfuncs import JaxCovarianceFunction -from typing import Optional + +from ._gp_solver import ConcreteGPSolver, GPInferenceParams, GPSolver +from .covfuncs import DowndateCovarianceFunction class CholeskyCovarianceFunction(DowndateCovarianceFunction): diff --git a/src/linpde_gp/solvers/_gp_solver.py b/src/linpde_gp/solvers/_gp_solver.py index 02c23842..e23f85a4 100644 --- a/src/linpde_gp/solvers/_gp_solver.py +++ b/src/linpde_gp/solvers/_gp_solver.py @@ -1,14 +1,15 @@ import abc from collections.abc import Sequence +from dataclasses import dataclass +import pickle from typing import Optional -import pickle import numpy as np import probnum as pn + from linpde_gp.linfunctls import LinearFunctional from linpde_gp.randprocs.covfuncs import JaxCovarianceFunction from linpde_gp.randprocs.crosscov import ProcessVectorCrossCovariance -from dataclasses import dataclass @dataclass diff --git a/src/linpde_gp/solvers/covfuncs/__init__.py b/src/linpde_gp/solvers/covfuncs/__init__.py index fe0146fe..a3a0fc91 100644 --- a/src/linpde_gp/solvers/covfuncs/__init__.py +++ b/src/linpde_gp/solvers/covfuncs/__init__.py @@ -1 +1 @@ -from ._downdate import DowndateCovarianceFunction \ No newline at end of file +from ._downdate import DowndateCovarianceFunction diff --git a/src/linpde_gp/solvers/covfuncs/_downdate.py b/src/linpde_gp/solvers/covfuncs/_downdate.py index 21d414bd..ec27731e 100644 --- a/src/linpde_gp/solvers/covfuncs/_downdate.py +++ b/src/linpde_gp/solvers/covfuncs/_downdate.py @@ -1,11 +1,11 @@ import abc import functools -from typing import Optional -import numpy as np + import jax import jax.numpy as jnp +import numpy as np + from linpde_gp.randprocs.covfuncs import JaxCovarianceFunction -from probnum.typing import ShapeLike class DowndateCovarianceFunction(JaxCovarianceFunction): From a87ce98a50f52b8cfb4fceaae1eecf3d593cc8bf Mon Sep 17 00:00:00 2001 From: Tim Weiland Date: Wed, 3 May 2023 17:24:08 +0200 Subject: [PATCH 7/9] Fix applying linfuncop to ConditionalGP --- .../randprocs/_gaussian_process/_conditional.py | 11 ++++++----- src/linpde_gp/solvers/_cholesky.py | 4 ++-- src/linpde_gp/solvers/_gp_solver.py | 7 +++++-- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/linpde_gp/randprocs/_gaussian_process/_conditional.py b/src/linpde_gp/randprocs/_gaussian_process/_conditional.py index b425a7b1..4609cce3 100644 --- a/src/linpde_gp/randprocs/_gaussian_process/_conditional.py +++ b/src/linpde_gp/randprocs/_gaussian_process/_conditional.py @@ -66,6 +66,7 @@ def __init__( kLas: ConditionalGaussianProcess.PriorPredictiveCrossCovariance, gram_matrix: pn.linops.LinearOperator, solver: GPSolver, + full_representer_weights: np.ndarray | None = None, ): self._prior = prior @@ -77,7 +78,9 @@ def __init__( self._gram_matrix = gram_matrix - inference_params = GPInferenceParams(prior, gram_matrix, Ys, Ls, bs, kLas, None) + inference_params = GPInferenceParams( + prior, gram_matrix, Ys, Ls, bs, kLas, None, full_representer_weights + ) self._solver = solver.get_concrete_solver(inference_params) self._abstract_solver = solver @@ -104,10 +107,7 @@ def solver(self) -> ConcreteGPSolver: @property def representer_weights(self) -> np.ndarray: - if self._representer_weights is None: - self._representer_weights = self._solver.compute_representer_weights() - - return self._representer_weights + return self._solver.compute_representer_weights() class PriorPredictiveCrossCovariance(ProcessVectorCrossCovariance): def __init__( @@ -397,6 +397,7 @@ def _( kLas=self(conditional_gp._kLas), gram_matrix=conditional_gp.gram, solver=conditional_gp.abstract_solver, + full_representer_weights=conditional_gp.representer_weights, ) diff --git a/src/linpde_gp/solvers/_cholesky.py b/src/linpde_gp/solvers/_cholesky.py index 589e7c62..a5eb51a6 100644 --- a/src/linpde_gp/solvers/_cholesky.py +++ b/src/linpde_gp/solvers/_cholesky.py @@ -48,14 +48,14 @@ def __init__( super().__init__(gp_params, load_path, save_path) def _compute_representer_weights(self): - if self._gp_params.prior_representer_weights is not None: + if self._gp_params.prev_representer_weights is not None: # Update existing representer weights assert isinstance(self._gp_params.prior_gram, BlockMatrix2x2) new_residual = self._get_residual( self._gp_params.Ys[-1], self._gp_params.Ls[-1], self._gp_params.bs[-1] ) return self._gp_params.prior_gram.schur_update( - self._gp_params.prior_representer_weights, new_residual + self._gp_params.prev_representer_weights, new_residual ) full_residual = self._get_full_residual() return self._gp_params.prior_gram.solve(full_residual) diff --git a/src/linpde_gp/solvers/_gp_solver.py b/src/linpde_gp/solvers/_gp_solver.py index e23f85a4..0eb8b259 100644 --- a/src/linpde_gp/solvers/_gp_solver.py +++ b/src/linpde_gp/solvers/_gp_solver.py @@ -24,7 +24,8 @@ class GPInferenceParams: Ls: Sequence[LinearFunctional] bs: Sequence[pn.randvars.Normal | pn.randvars.Constant | None] kLas: ProcessVectorCrossCovariance - prior_representer_weights: np.ndarray + prev_representer_weights: Optional[np.ndarray] + full_representer_weights: Optional[np.ndarray] class ConcreteGPSolver(abc.ABC): @@ -43,7 +44,9 @@ def __init__( self._load_path = load_path self._save_path = save_path - self._representer_weights = None + # Typically None, but in some cases (e.g. applying a linear function + # operator to a trained GP), the representer weights are already known + self._representer_weights = self._gp_params.full_representer_weights if self._load_path is not None: self.load() From ccc90f8e3e78c1204e8982065fab99c9367b7464 Mon Sep 17 00:00:00 2001 From: Tim Weiland Date: Wed, 3 May 2023 17:37:07 +0200 Subject: [PATCH 8/9] Pylint fixes --- .../_gaussian_process/_conditional.py | 11 ++++------ src/linpde_gp/solvers/_cholesky.py | 20 +++---------------- src/linpde_gp/solvers/_gp_solver.py | 19 ++++++------------ src/linpde_gp/solvers/covfuncs/_downdate.py | 4 +--- 4 files changed, 14 insertions(+), 40 deletions(-) diff --git a/src/linpde_gp/randprocs/_gaussian_process/_conditional.py b/src/linpde_gp/randprocs/_gaussian_process/_conditional.py index 4609cce3..7249bc27 100644 --- a/src/linpde_gp/randprocs/_gaussian_process/_conditional.py +++ b/src/linpde_gp/randprocs/_gaussian_process/_conditional.py @@ -36,7 +36,7 @@ def from_observations( b: None | RandomVariableLike = None, solver: GPSolver = CholeskySolver(), ): - Y, L, b, kLa, Lm, gram = cls._preprocess_observations( + Y, L, b, kLa, gram = cls._preprocess_observations( prior=prior, Y=Y, X=X, @@ -212,7 +212,7 @@ def condition_on_observations( b: RandomVariableLike | None = None, solver: GPSolver = CholeskySolver(), ): - Y, L, b, kLa, pred_mean, gram = self._preprocess_observations( + Y, L, b, kLa, gram = self._preprocess_observations( prior=self._prior, Y=Y, X=X, @@ -310,11 +310,9 @@ def _preprocess_observations( Lf = L(prior) kLa = L(prior.cov, argnum=1) - # Compute predictive mean and covariance matrix - pred_mean = Lf.mean + # Compute predictive covariance matrix gram = Lf.cov - pred_mean = pred_mean.reshape(-1, order="C") # Check observations Y = np.asarray(Y) if ( @@ -343,13 +341,12 @@ def _preprocess_observations( assert Y.size == Lf.cov.shape[1] if b is not None: - pred_mean = pred_mean + np.asarray(b.mean).reshape(-1, order="C") gram = gram + pn.linops.aslinop(b.cov) gram.is_symmetric = True gram.is_positive_definite = True - return Y, L, b, kLa, pred_mean, gram + return Y, L, b, kLa, gram pn.randprocs.GaussianProcess.condition_on_observations = ( diff --git a/src/linpde_gp/solvers/_cholesky.py b/src/linpde_gp/solvers/_cholesky.py index a5eb51a6..4d6f030b 100644 --- a/src/linpde_gp/solvers/_cholesky.py +++ b/src/linpde_gp/solvers/_cholesky.py @@ -33,20 +33,11 @@ def _downdate_jax(self, x0: jnp.ndarray, x1: jnp.ndarray | None) -> jnp.ndarray: class ConcreteCholeskySolver(ConcreteGPSolver): - """ - Concrete solver that uses the Cholesky decomposition. + """Concrete solver that uses the Cholesky decomposition. Uses a block Cholesky decomposition if possible. """ - def __init__( - self, - gp_params: GPInferenceParams, - load_path: Optional[str] = None, - save_path: Optional[str] = None, - ): - super().__init__(gp_params, load_path, save_path) - def _compute_representer_weights(self): if self._gp_params.prev_representer_weights is not None: # Update existing representer weights @@ -69,18 +60,13 @@ def _save_state(self) -> dict: state = {"representer_weights": self._representer_weights} return state - def _load_state(self, dict): - self._representer_weights = dict["representer_weights"] + def _load_state(self, state: dict): + self._representer_weights = state["representer_weights"] class CholeskySolver(GPSolver): """Solver that uses the Cholesky decomposition.""" - def __init__( - self, load_path: Optional[str] = None, save_path: Optional[str] = None - ): - super().__init__(load_path, save_path) - def get_concrete_solver( self, gp_params: GPInferenceParams ) -> ConcreteCholeskySolver: diff --git a/src/linpde_gp/solvers/_gp_solver.py b/src/linpde_gp/solvers/_gp_solver.py index 0eb8b259..43dbff63 100644 --- a/src/linpde_gp/solvers/_gp_solver.py +++ b/src/linpde_gp/solvers/_gp_solver.py @@ -14,9 +14,7 @@ @dataclass class GPInferenceParams: - """ - Parameters for affine Gaussian process inference. - """ + """Parameters for affine Gaussian process inference.""" prior: pn.randprocs.GaussianProcess prior_gram: pn.linops.LinearOperator @@ -53,14 +51,11 @@ def __init__( @abc.abstractmethod def _compute_representer_weights(self): - """ - Compute the representer weights. - """ + """Compute the representer weights.""" raise NotImplementedError def compute_representer_weights(self): - """ - Compute representer weights, or directly return cached + """Compute representer weights, or directly return cached result from previous computation. """ if self._representer_weights is None: @@ -105,7 +100,7 @@ def _save_state(self) -> dict: raise NotImplementedError @abc.abstractmethod - def _load_state(self, dict): + def _load_state(self, state: dict): """Load solver state from dict.""" raise NotImplementedError @@ -127,8 +122,7 @@ def load(self): class GPSolver(abc.ABC): - """ - User-facing interface for Gaussian process solvers used to pass + """User-facing interface for Gaussian process solvers used to pass hyperparameters. """ @@ -140,8 +134,7 @@ def __init__( @abc.abstractmethod def get_concrete_solver(self, gp_params: GPInferenceParams) -> ConcreteGPSolver: - """ - Get concrete solver. + """Get concrete solver. Subclasses must implement this method. """ raise NotImplementedError diff --git a/src/linpde_gp/solvers/covfuncs/_downdate.py b/src/linpde_gp/solvers/covfuncs/_downdate.py index ec27731e..6bbf84d2 100644 --- a/src/linpde_gp/solvers/covfuncs/_downdate.py +++ b/src/linpde_gp/solvers/covfuncs/_downdate.py @@ -9,9 +9,7 @@ class DowndateCovarianceFunction(JaxCovarianceFunction): - """ - Covariance function that is obtained by downdating a prior covariance function. - """ + """Covariance function that is obtained by downdating a prior covariance function.""" def __init__(self, prior_cov: JaxCovarianceFunction): self._prior_cov = prior_cov From 15ca50a5003e0d6064acbbfb93bfb9802762a099 Mon Sep 17 00:00:00 2001 From: Tim Weiland Date: Wed, 3 May 2023 17:42:46 +0200 Subject: [PATCH 9/9] Final pylint fixes --- src/linpde_gp/solvers/_cholesky.py | 2 -- src/linpde_gp/solvers/covfuncs/_downdate.py | 4 +++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/linpde_gp/solvers/_cholesky.py b/src/linpde_gp/solvers/_cholesky.py index 4d6f030b..f60e83a3 100644 --- a/src/linpde_gp/solvers/_cholesky.py +++ b/src/linpde_gp/solvers/_cholesky.py @@ -1,5 +1,3 @@ -from typing import Optional - from jax import numpy as jnp import numpy as np diff --git a/src/linpde_gp/solvers/covfuncs/_downdate.py b/src/linpde_gp/solvers/covfuncs/_downdate.py index 6bbf84d2..50172ac3 100644 --- a/src/linpde_gp/solvers/covfuncs/_downdate.py +++ b/src/linpde_gp/solvers/covfuncs/_downdate.py @@ -9,7 +9,9 @@ class DowndateCovarianceFunction(JaxCovarianceFunction): - """Covariance function that is obtained by downdating a prior covariance function.""" + """Covariance function that is obtained by downdating + a prior covariance function. + """ def __init__(self, prior_cov: JaxCovarianceFunction): self._prior_cov = prior_cov