Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/linpde_gp/linfuncops/_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _(
def __rmul__(self, other) -> LinearFunctionOperator:
if np.ndim(other) == 0:
return ScaledLinearFunctionOperator(
linfuncop=self._linfuncop,
self._linfuncop,
scalar=np.asarray(other) * self._scalar,
)

Expand Down
145 changes: 44 additions & 101 deletions src/linpde_gp/randprocs/_gaussian_process/_conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,25 @@

from collections.abc import Iterator, Sequence
import functools
from typing import Optional

import jax
import jax.numpy as jnp
import numpy as np
from numpy.typing import ArrayLike
import probnum as pn
import scipy.linalg

from linpde_gp import linfunctls
from linpde_gp.functions import JaxFunction
from linpde_gp.linfuncops import LinearFunctionOperator
from linpde_gp.linfunctls import LinearFunctional
from linpde_gp.linops import BlockMatrix, BlockMatrix2x2
from linpde_gp.randprocs.covfuncs import JaxCovarianceFunction
from linpde_gp.randprocs.crosscov import ProcessVectorCrossCovariance
from linpde_gp.solvers import (
CholeskySolver,
ConcreteGPSolver,
GPInferenceParams,
GPSolver,
)
from linpde_gp.typing import RandomVariableLike


Expand All @@ -31,25 +34,26 @@ def from_observations(
*,
L: None | LinearFunctional | LinearFunctionOperator = None,
b: None | RandomVariableLike = None,
solver: GPSolver = CholeskySolver(),
):
Y, L, b, kLa, Lm, gram = cls._preprocess_observations(
Y, L, b, kLa, gram = cls._preprocess_observations(
prior=prior,
Y=Y,
X=X,
L=L,
b=b,
)

representer_weights = gram.solve(Y - Lm)
kLas = ConditionalGaussianProcess.PriorPredictiveCrossCovariance((kLa,))

return cls(
prior=prior,
Ys=(Y,),
Ls=(L,),
bs=(b,),
kLas=ConditionalGaussianProcess.PriorPredictiveCrossCovariance((kLa,)),
kLas=kLas,
gram_matrix=gram,
representer_weights=representer_weights,
solver=solver,
)

def __init__(
Expand All @@ -61,7 +65,8 @@ def __init__(
bs: Sequence[pn.randvars.Normal | pn.randvars.Constant | None],
kLas: ConditionalGaussianProcess.PriorPredictiveCrossCovariance,
gram_matrix: pn.linops.LinearOperator,
representer_weights: np.ndarray | None = None,
solver: GPSolver,
full_representer_weights: np.ndarray | None = None,
):
self._prior = prior

Expand All @@ -73,40 +78,36 @@ def __init__(

self._gram_matrix = gram_matrix

self._representer_weights = representer_weights
inference_params = GPInferenceParams(
prior, gram_matrix, Ys, Ls, bs, kLas, None, full_representer_weights
)
self._solver = solver.get_concrete_solver(inference_params)
self._abstract_solver = solver

super().__init__(
mean=ConditionalGaussianProcess.Mean(
prior_mean=self._prior.mean,
kLas=self._kLas,
representer_weights=self.representer_weights,
),
cov=ConditionalGaussianProcess.CovarianceFunction(
prior_covfunc=self._prior.cov,
kLas=self._kLas,
gram_matrix=self.gram,
solver=self._solver,
),
cov=self._solver.posterior_cov,
)

@functools.cached_property
def gram(self) -> pn.linops.LinearOperator:
return self._gram_matrix

@property
def representer_weights(self) -> np.ndarray:
if self._representer_weights is None:
y = np.concatenate(
[
(Y - L(self._prior.mean))
if b is None
else (Y - L(self._prior.mean) - b.mean.reshape(-1, order="C"))
for Y, L, b in zip(self._Ys, self._Ls, self._bs)
],
axis=-1,
)
self._representer_weights = self.gram.solve(y)
def abstract_solver(self) -> GPSolver:
return self._abstract_solver

return self._representer_weights
@property
def solver(self) -> ConcreteGPSolver:
return self._solver

@property
def representer_weights(self) -> np.ndarray:
return self._solver.compute_representer_weights()

class PriorPredictiveCrossCovariance(ProcessVectorCrossCovariance):
def __init__(
Expand Down Expand Up @@ -178,11 +179,11 @@ def __init__(
self,
prior_mean: JaxFunction,
kLas: ConditionalGaussianProcess.PriorPredictiveCrossCovariance,
representer_weights: np.ndarray,
solver: ConcreteGPSolver,
):
self._prior_mean = prior_mean
self._kLas = kLas
self._representer_weights = representer_weights
self._solver = solver

super().__init__(
input_shape=self._prior_mean.input_shape,
Expand All @@ -193,61 +194,14 @@ def _evaluate(self, x: np.ndarray) -> np.ndarray:
m_x = self._prior_mean(x)
kLas_x = self._kLas(x)

return m_x + kLas_x @ self._representer_weights
return m_x + kLas_x @ self._solver.compute_representer_weights()

@functools.partial(jax.jit, static_argnums=0)
def _evaluate_jax(self, x: jnp.ndarray) -> jnp.ndarray:
m_x = self._prior_mean.jax(x)
kLas_x = self._kLas.jax(x)

return m_x + kLas_x @ self._representer_weights

class CovarianceFunction(JaxCovarianceFunction):
def __init__(
self,
prior_covfunc: JaxCovarianceFunction,
kLas: ConditionalGaussianProcess.PriorPredictiveCrossCovariance,
gram_matrix: pn.linops.LinearOperator,
):
self._prior_covfunc = prior_covfunc
self._kLas = kLas
self._gram_matrix = gram_matrix

super().__init__(
input_shape=self._prior_covfunc.input_shape,
output_shape_0=self._prior_covfunc.output_shape_0,
output_shape_1=self._prior_covfunc.output_shape_1,
)

def _evaluate(self, x0: np.ndarray, x1: np.ndarray | None) -> np.ndarray:
k_xx = self._prior_covfunc(x0, x1)
kLas_x0 = self._kLas(x0)
kLas_x1 = self._kLas(x1) if x1 is not None else kLas_x0
cov_update = (
kLas_x0[..., None, :] @ (self._gram_matrix.solve(kLas_x1[..., None]))
)[..., 0, 0]

return k_xx - cov_update

@functools.partial(jax.jit, static_argnums=0)
def _evaluate_jax(self, x0: jnp.ndarray, x1: jnp.ndarray | None) -> jnp.ndarray:
k_xx = self._prior_covfunc.jax(x0, x1)
kLas_x0 = self._kLas.jax(x0)
kLas_x1 = self._kLas.jax(x1) if x1 is not None else kLas_x0
cov_update = (
kLas_x0[..., None, :]
@ (self._gram_matrix.solve(kLas_x1[..., None]))[..., 0, 0]
)

return k_xx - cov_update

def _evaluate_linop(
self, x0: np.ndarray, x1: Optional[np.ndarray]
) -> pn.linops.LinearOperator:
k_xx = self._prior_covfunc.linop(x0, x1)
kLas_x0 = self._kLas.evaluate_linop(x0)
kLas_x1 = self._kLas.evaluate_linop(x1) if x1 is not None else kLas_x0
return k_xx - kLas_x0 @ self._gram_matrix.solve(kLas_x1.T)
return m_x + kLas_x @ self._solver.compute_representer_weights()

def condition_on_observations(
self,
Expand All @@ -256,8 +210,9 @@ def condition_on_observations(
*,
L: LinearFunctional | LinearFunctionOperator | None = None,
b: RandomVariableLike | None = None,
solver: GPSolver = CholeskySolver(),
):
Y, L, b, kLa, pred_mean, gram = self._preprocess_observations(
Y, L, b, kLa, gram = self._preprocess_observations(
prior=self._prior,
Y=Y,
X=X,
Expand All @@ -278,18 +233,17 @@ def condition_on_observations(
gram,
is_spd=True,
)
representer_weights = gram_matrix.schur_update(
self.representer_weights, Y - pred_mean
)

kLas = self._kLas.append(kLa)

return ConditionalGaussianProcess(
prior=self._prior,
Ys=self._Ys + (Y,),
Ls=self._Ls + (L,),
bs=self._bs + (b,),
kLas=self._kLas.append(kLa),
kLas=kLas,
gram_matrix=gram_matrix,
representer_weights=representer_weights,
solver=solver,
)

@classmethod
Expand Down Expand Up @@ -356,11 +310,9 @@ def _preprocess_observations(
Lf = L(prior)
kLa = L(prior.cov, argnum=1)

# Compute predictive mean and covariance matrix
pred_mean = Lf.mean
# Compute predictive covariance matrix
gram = Lf.cov

pred_mean = pred_mean.reshape(-1, order="C")
# Check observations
Y = np.asarray(Y)
if (
Expand Down Expand Up @@ -389,13 +341,12 @@ def _preprocess_observations(
assert Y.size == Lf.cov.shape[1]

if b is not None:
pred_mean = pred_mean + np.asarray(b.mean).reshape(-1, order="C")
gram = gram + pn.linops.aslinop(b.cov)

gram.is_symmetric = True
gram.is_positive_definite = True

return Y, L, b, kLa, pred_mean, gram
return Y, L, b, kLa, gram


pn.randprocs.GaussianProcess.condition_on_observations = (
Expand Down Expand Up @@ -442,7 +393,8 @@ def _(
bs=conditional_gp._bs,
kLas=self(conditional_gp._kLas),
gram_matrix=conditional_gp.gram,
representer_weights=conditional_gp.representer_weights,
solver=conditional_gp.abstract_solver,
full_representer_weights=conditional_gp.representer_weights,
)


Expand All @@ -458,16 +410,7 @@ def _(
crosscov = self(conditional_gp._kLas)

mean = linfunctl_prior.mean + crosscov @ conditional_gp.representer_weights
# TODO: Make this compatible with non-Cholesky solvers?
cov = linfunctl_prior.cov - crosscov @ conditional_gp.gram.inv() @ crosscov.T

return pn.randvars.Normal(mean, cov)


def cho_solve(L, b):
"""Fixes a bug in scipy.linalg.cho_solve"""
(L, lower) = L

if L.shape == (1, 1) and b.shape[0] == 1:
return b / L[0, 0] ** 2

return scipy.linalg.cho_solve((L, lower), b)
12 changes: 12 additions & 0 deletions src/linpde_gp/randprocs/covfuncs/_zero.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import Optional

from jax import numpy as jnp
import numpy as np
import probnum as pn
from probnum.randprocs import covfuncs

from ._jax import JaxCovarianceFunctionMixin
Expand All @@ -25,3 +28,12 @@ def _evaluate_jax(self, x0: jnp.ndarray, x1: jnp.ndarray | None) -> jnp.ndarray:
broadcast_batch_shape + self.output_shape_0 + self.output_shape_1,
dtype=np.result_type(x0, x1),
)

def _evaluate_linop(
self, x0: np.ndarray, x1: Optional[np.ndarray]
) -> pn.linops.LinearOperator:
shape = (
self.output_size_0 * x0.shape[0],
self.output_size_1 * (x1.shape[0] if x1 is not None else x0.shape[0]),
)
return pn.linops.Zero(shape, np.promote_types(x0.dtype, x1.dtype))
2 changes: 2 additions & 0 deletions src/linpde_gp/solvers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from ._cholesky import CholeskySolver
from ._gp_solver import ConcreteGPSolver, GPInferenceParams, GPSolver
Loading