diff --git a/.gitignore b/.gitignore index 2aef808d..917cd8eb 100644 --- a/.gitignore +++ b/.gitignore @@ -38,17 +38,25 @@ pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports +reports/ htmlcov/ .tox/ .coverage .coverage.* .cache +.junit.xml nosetests.xml coverage.xml *.cover .hypothesis/ .pytest_cache/ +# notebooks +docs/notebooks/checkpoint +docs/notebooks/ckpts/ +docs/notebooks/logs/ +docs/notebooks/weights* + # Translations *.mo *.pot diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index f7a9ad67..da214008 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -11,8 +11,12 @@ Because GitHub's [graph of contributors](http://github.com/secondmind-labs/GPflu [James A. Leedham](https://github.com/JamesALeedham) [Felix Leibfried](https://github.com/fleibfried), [John A. McLeod](https://github.com/johnamcleod), +[Jesper Nielsen](https://github.com/jesnie), +[Sebastian Ober](https://github.com/sebastianober), +[Sebastian Popescu](https://github.com/SebastianPopescu), [Hugh Salimbeni](https://github.com/hughsalimbeni), +[Hrvoje Stojic](https://github.com/hstojic), [Marcin B. Tomczak](https://github.com/marctom) -Feel free to add yourself when you first contribute to GPflux's code, tests, or documentation! \ No newline at end of file +Feel free to add yourself when you first contribute to GPflux's code, tests, or documentation! diff --git a/docs/conf.py b/docs/conf.py index 02eb90f6..d30e482f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -69,11 +69,11 @@ "python": ("https://docs.python.org/3/", None), "tensorflow": ( "https://www.tensorflow.org/api_docs/python", - "https://github.com/GPflow/tensorflow-intersphinx/raw/master/tf2_py_objects.inv" + "https://github.com/GPflow/tensorflow-intersphinx/raw/master/tf2_py_objects.inv", ), "tensorflow_probability": ( "https://www.tensorflow.org/probability/api_docs/python", - "https://github.com/GPflow/tensorflow-intersphinx/raw/master/tfp_py_objects.inv" + "https://github.com/GPflow/tensorflow-intersphinx/raw/master/tfp_py_objects.inv", ), "gpflow": ("https://gpflow.readthedocs.io/en/master/", None), } @@ -116,7 +116,9 @@ } # If True, show link to rst source on rendered HTML pages -html_show_sourcelink = False # Remove 'view source code' from top of page (for html, not python) +html_show_sourcelink = ( + False # Remove 'view source code' from top of page (for html, not python) +) # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, diff --git a/docs/notebooks/efficient_posterior_sampling.py b/docs/notebooks/efficient_posterior_sampling.py index cb9a561b..22d7a22b 100644 --- a/docs/notebooks/efficient_posterior_sampling.py +++ b/docs/notebooks/efficient_posterior_sampling.py @@ -80,7 +80,9 @@ from gpflow.models import GPR from gpflux.layers.basis_functions.fourier_features import RandomFourierFeaturesCosine -from gpflux.sampling.kernel_with_feature_decomposition import KernelWithFeatureDecomposition +from gpflux.feature_decomposition_kernels.kernel_with_feature_decomposition import ( + KernelWithFeatureDecomposition, +) # %% [markdown] """ diff --git a/docs/notebooks/efficient_sampling.py b/docs/notebooks/efficient_sampling.py index 5b6a6da5..225f67ad 100644 --- a/docs/notebooks/efficient_sampling.py +++ b/docs/notebooks/efficient_sampling.py @@ -37,7 +37,7 @@ from gpflow.config import default_float from gpflux.layers.basis_functions.fourier_features import RandomFourierFeaturesCosine -from gpflux.sampling import KernelWithFeatureDecomposition +from gpflux.feature_decomposition_kernels import KernelWithFeatureDecomposition from gpflux.models.deep_gp import sample_dgp diff --git a/docs/notebooks/multi_output_efficient_sampling.py b/docs/notebooks/multi_output_efficient_sampling.py new file mode 100644 index 00000000..1783965b --- /dev/null +++ b/docs/notebooks/multi_output_efficient_sampling.py @@ -0,0 +1,177 @@ +# -*- coding: utf-8 -*- +# --- +# jupyter: +# jupytext: +# cell_markers: '"""' +# formats: ipynb,py:percent +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.4.2 +# kernelspec: +# display_name: Python 3 +# language: python +# name: python3 +# --- + +# %% [markdown] +""" +# Efficient sampling with Gaussian processes and Random Fourier Features + +Gaussian processes (GPs) provide a mathematically elegant framework for learning unknown functions from data. They are robust to overfitting, allow to incorporate prior assumptions into the model and provide calibrated uncertainty estimates for their predictions. This makes them prime candidates in settings where data is scarce, noisy or very costly to obtain, and are natural tools in applications such as Bayesian optimisation (BO). + +Despite their favorable properties, the use of GPs still has practical limitations. One of them is the computational complexity to draw predictive samples from the model, which quickly becomes prohibitive as the sample size grows, and creates a well-known bottleneck for GP-based Thompson sampling (GP-TS) for instance. +Recent work proposes to combine GP’s weight-space and function-space views to draw samples more efficiently from (approximate) posterior GPs with encouraging results in low-dimensional regimes. + +In GPflux, this functionality is unlocked by grouping a kernel (e.g., `gpflow.kernels.Matern52`) with its feature decomposition using `gpflux.sampling.KernelWithFeatureDecomposition`. See the notebooks on [weight space approximation](weight_space_approximation.ipynb) and [efficient posterior sampling](efficient_posterior_sampling.ipynb) for a thorough explanation. +""" +# %% +import numpy as np +import tensorflow as tf +import matplotlib.pyplot as plt + +import gpflow +import gpflux + +from gpflow.config import default_float + +from gpflux.layers.basis_functions.fourier_features import MultiOutputRandomFourierFeaturesCosine +from gpflux.feature_decomposition_kernels import ( + KernelWithFeatureDecomposition, + SharedMultiOutputKernelWithFeatureDecomposition, + SeparateMultiOutputKernelWithFeatureDecomposition, +) +from gpflux.models.deep_gp import sample_dgp + +tf.keras.backend.set_floatx("float64") + +# %% [markdown] +""" +## Load Snelson dataset +""" + +# %% +d = np.load("../../tests/snelson1d.npz") +X, Y = data = d["X"], d["Y"] +num_data, input_dim = X.shape + +# %% [markdown] +r""" +## Setting up the kernel and its feature decomposition + +The `KernelWithFeatureDecomposition` instance represents a kernel together with its finite feature decomposition, +$$ +k(x, x') = \sum_{i=0}^L \lambda_i \phi_i(x) \phi_i(x'), +$$ +where $\lambda_i$ and $\phi_i(\cdot)$ are the coefficients (eigenvalues) and features (eigenfunctions), respectively, and $L$ is the finite cutoff. See [the notebook on weight space approximation](weight_space_approximation.ipynb) for a detailed explanation of how to construct this decomposition using Random Fourier Features (RFF). +""" + +# %% +# kernel = gpflow.kernels.Matern52() +kernel1 = gpflow.kernels.Matern52() +kernel2 = gpflow.kernels.SquaredExponential() +# kernel = gpflow.kernels.SeparateIndependent( kernels = [kernel1, kernel2]) +kernel = gpflow.kernels.SharedIndependent(kernel=kernel1, output_dim=2) + +Z_1 = np.linspace(X.min(), X.max(), 10).reshape(-1, 1).astype(np.float64) +Z_2 = np.linspace(X.min(), X.max(), 10).reshape(-1, 1).astype(np.float64) + +inducing_variable_1 = gpflow.inducing_variables.InducingPoints(Z_1) +inducing_variable_2 = gpflow.inducing_variables.InducingPoints(Z_2) +# inducing_variable = gpflow.inducing_variables.SeparateIndependentInducingVariables(inducing_variable_list= [inducing_variable_1, inducing_variable_2]) +inducing_variable = gpflow.inducing_variables.SharedIndependentInducingVariables( + inducing_variable=inducing_variable_1 +) + +gpflow.utilities.set_trainable(inducing_variable, False) +P = 2 +num_rff = 1000 +eigenfunctions = MultiOutputRandomFourierFeaturesCosine(kernel, num_rff, dtype=default_float()) +eigenvalues = np.ones((P, num_rff, 1), dtype=default_float()) +# kernel_with_features = SeparateMultiOutputKernelWithFeatureDecomposition(kernel, eigenfunctions, eigenvalues) +kernel_with_features = SharedMultiOutputKernelWithFeatureDecomposition( + kernel, eigenfunctions, eigenvalues +) +# %% [markdown] +""" +## Building and training the single-layer GP + +### Initialise the single-layer GP +Because `KernelWithFeatureDecomposition` is just a `gpflow.kernels.Kernel`, we can construct a GP layer with it. +""" +# %% +layer = gpflux.layers.GPLayer( + kernel_with_features, + inducing_variable, + num_data, + whiten=True, + num_latent_gps=2, + mean_function=gpflow.mean_functions.Zero(), +) +likelihood_layer = gpflux.layers.LikelihoodLayer(gpflow.likelihoods.Gaussian()) # noqa: E231 +dgp = gpflux.models.DeepGP([layer], likelihood_layer) +model = dgp.as_training_model() +# %% [markdown] +""" +### Fit model to data +""" + +# %% +model.compile(tf.optimizers.Adam(learning_rate=0.1)) + +callbacks = [ + tf.keras.callbacks.ReduceLROnPlateau( + monitor="loss", + patience=5, + factor=0.95, + verbose=0, + min_lr=1e-6, + ) +] + +history = model.fit( + {"inputs": X, "targets": tf.tile(Y, [1, 2])}, + batch_size=num_data, + epochs=100, + callbacks=callbacks, + verbose=0, +) +# %% [markdown] +""" +## Drawing samples + +Now that the model is trained we can draw efficient and consistent samples from the posterior GP. By "consistent" we mean that the `sample_dgp` function returns a function object that can be evaluated multiple times at different locations, but importantly, the returned function values will come from the same GP sample. This functionality is implemented by the `gpflux.sampling.efficient_sample` function. +""" + +# %% +from typing import Callable + +x_margin = 5 +n_x = 1000 +X_test = np.linspace(X.min() - x_margin, X.max() + x_margin, n_x).reshape(-1, 1) + +f_mean, f_var = dgp.predict_f(X_test) +f_scale = np.sqrt(f_var) + + +fig, axs = plt.subplots(1, 2) + + +for dim in range(2): + + # Plot samples + n_sim = 10 + for _ in range(n_sim): + # `sample_dgp` returns a callable - which we subsequently evaluate + f_sample: Callable[[tf.Tensor], tf.Tensor] = sample_dgp(dgp) + axs[dim].plot(X_test, f_sample(X_test).numpy()[..., dim]) + + # Plot GP mean and uncertainty intervals and data + axs[dim].plot(X_test, f_mean[..., dim], "C0") + axs[dim].plot(X_test, f_mean[..., dim] + f_scale[..., dim], "C0--") + axs[dim].plot(X_test, f_mean[..., dim] - f_scale[..., dim], "C0--") + axs[dim].plot(X, Y, "kx", alpha=0.2) + axs[dim].set_xlim(X.min() - x_margin, X.max() + x_margin) + axs[dim].set_ylim(Y.min() - x_margin, Y.max() + x_margin) +plt.show() diff --git a/docs/notebooks/weight_space_approximation.py b/docs/notebooks/weight_space_approximation.py index ef4b4fe4..512b395b 100644 --- a/docs/notebooks/weight_space_approximation.py +++ b/docs/notebooks/weight_space_approximation.py @@ -62,7 +62,7 @@ from gpflow.inducing_variables import InducingPoints from gpflux.layers.basis_functions.fourier_features import RandomFourierFeaturesCosine -from gpflux.sampling.kernel_with_feature_decomposition import KernelWithFeatureDecomposition +from gpflux.feature_decomposition_kernels import KernelWithFeatureDecomposition # %% [markdown] """ diff --git a/gpflux/feature_decomposition_kernels/__init__.py b/gpflux/feature_decomposition_kernels/__init__.py new file mode 100644 index 00000000..b2dff876 --- /dev/null +++ b/gpflux/feature_decomposition_kernels/__init__.py @@ -0,0 +1,14 @@ +from .kernel_with_feature_decomposition import KernelWithFeatureDecomposition, _ApproximateKernel +from .multioutput import ( + SeparateMultiOutputKernelWithFeatureDecomposition, + SharedMultiOutputKernelWithFeatureDecomposition, + _MultiOutputApproximateKernel, +) + +__all__ = [ + "_ApproximateKernel", + "KernelWithFeatureDecomposition", + "_MultiOutputApproximateKernel", + "SharedMultiOutputKernelWithFeatureDecomposition", + "SeparateMultiOutputKernelWithFeatureDecomposition", +] diff --git a/gpflux/sampling/kernel_with_feature_decomposition.py b/gpflux/feature_decomposition_kernels/kernel_with_feature_decomposition.py similarity index 94% rename from gpflux/sampling/kernel_with_feature_decomposition.py rename to gpflux/feature_decomposition_kernels/kernel_with_feature_decomposition.py index bad1f80c..54801a2c 100644 --- a/gpflux/sampling/kernel_with_feature_decomposition.py +++ b/gpflux/feature_decomposition_kernels/kernel_with_feature_decomposition.py @@ -58,10 +58,10 @@ def __init__( :param feature_functions: A Keras layer for which the call evaluates the ``L`` features of the kernel :math:`\phi_i(\cdot)`. For ``X`` with the shape ``[N, D]``, ``feature_functions(X)`` returns a tensor with the shape ``[N, L]``. - :param feature_coefficients: A tensor with the shape ``[L, 1]`` with coefficients + :param feature_coefficients: A tensor with the shape ``[L, 1]`' with coefficients associated with the features, :math:`\lambda_i`. """ - self._feature_functions = feature_functions + self._feature_functions = feature_functions # [N, L] self._feature_coefficients = feature_coefficients # [L, 1] def K(self, X: TensorType, X2: Optional[TensorType] = None) -> tf.Tensor: @@ -72,19 +72,23 @@ def K(self, X: TensorType, X2: Optional[TensorType] = None) -> tf.Tensor: else: phi2 = self._feature_functions(X2) # [N2, L] - r = tf.matmul( - phi, tf.transpose(self._feature_coefficients) * phi2, transpose_b=True + r = tf.linalg.matmul( + phi, + tf.linalg.matrix_transpose(self._feature_coefficients) * phi2, + transpose_b=True, ) # [N, N2] N1, N2 = tf.shape(phi)[0], tf.shape(phi2)[0] + tf.debugging.assert_equal(tf.shape(r), [N1, N2]) return r def K_diag(self, X: TensorType) -> tf.Tensor: """Approximate the true kernel by an inner product between feature functions.""" phi_squared = self._feature_functions(X) ** 2 # [N, L] - r = tf.reduce_sum(phi_squared * tf.transpose(self._feature_coefficients), axis=1) # [N,] - N = tf.shape(X)[0] + r = tf.reduce_sum(phi_squared * tf.transpose(self._feature_coefficients), axis=-1) # [N,] + N = tf.shape(X)[0] if tf.experimental.numpy.ndim(X) == 1 else tf.shape(X)[0] + tf.debugging.assert_equal(tf.shape(r), [N]) # noqa: E231 return r @@ -156,8 +160,9 @@ def __init__( else: self._kernel = kernel - self._feature_functions = feature_functions + self._feature_functions = feature_functions # [N, L] self._feature_coefficients = feature_coefficients # [L, 1] + tf.ensure_shape(self._feature_coefficients, tf.TensorShape([None, 1])) @property diff --git a/gpflux/feature_decomposition_kernels/multioutput/__init__.py b/gpflux/feature_decomposition_kernels/multioutput/__init__.py new file mode 100644 index 00000000..9b419a1c --- /dev/null +++ b/gpflux/feature_decomposition_kernels/multioutput/__init__.py @@ -0,0 +1,11 @@ +from .kernel_with_feature_decomposition import ( + SeparateMultiOutputKernelWithFeatureDecomposition, + SharedMultiOutputKernelWithFeatureDecomposition, + _MultiOutputApproximateKernel, +) + +__all__ = [ + "_MultiOutputApproximateKernel", + "SharedMultiOutputKernelWithFeatureDecomposition", + "SeparateMultiOutputKernelWithFeatureDecomposition", +] diff --git a/gpflux/feature_decomposition_kernels/multioutput/kernel_with_feature_decomposition.py b/gpflux/feature_decomposition_kernels/multioutput/kernel_with_feature_decomposition.py new file mode 100644 index 00000000..56170c1d --- /dev/null +++ b/gpflux/feature_decomposition_kernels/multioutput/kernel_with_feature_decomposition.py @@ -0,0 +1,433 @@ +# +# Copyright (c) 2021 The GPflux Contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +r""" +The classes in this module encapsulate kernels :math:`k(\cdot, \cdot)` with +their features :math:`\phi_i(\cdot)` and coefficients :math:`\lambda_i` so +that: + +.. math:: + + k(x, x') = \sum_{i=0}^\infty \lambda_i \phi_i(x) \phi_i(x'). + +The kernels are used for efficient sampling. See the tutorial notebooks +`Efficient sampling <../../../../notebooks/efficient_sampling.ipynb>`_ +and `Weight Space Approximation with Random Fourier Features +<../../../../notebooks/weight_space_approximation.ipynb>`_ +for an in-depth overview. +""" +from typing import Any, Optional, Tuple, Union + +import tensorflow as tf + +import gpflow +from gpflow.base import TensorType +from gpflow.kernels import SquaredExponential + +NoneType = type(None) + + +# TODO -- this needs to be a subclass of ApproximateKernel and maybe MultioutputKernel as well +# NOTE -- I think the MultiOutputKernel needs to be here for the dispatcher in the covariances +class _MultiOutputApproximateKernel(gpflow.kernels.MultioutputKernel): + r""" + #TODO -- update documentation to suit multioutput case + TODO: unless # [P, N, L] etc. are used in static analysis, we can probably unify this with + _ApproximateKernel + This class approximates a kernel by the finite feature decomposition: + + .. math:: k(x, x') = \sum_{i=0}^L \lambda_i \phi_i(x) \phi_i(x'), + + where :math:`\lambda_i` and :math:`\phi_i(\cdot)` are the coefficients + and features, respectively. + + This class deals with the case of multiple kernels producing multiple outputs. + They are computed together here for improved efficiency dervied from vectorizing the whole + lot together before putting it through tensorflow (improve this comment) + + """ + + def __init__( + self, + feature_functions: tf.keras.layers.Layer, + feature_coefficients: TensorType, + ): + r""" + :param feature_functions: A Keras layer for which the call evaluates the + ``L`` features of the kernel :math:`\phi_i(\cdot)`. For ``X`` with the shape ``[N, D]``, + ``feature_functions(X)`` returns a tensor with the shape ``[P, N, L]``. + :param feature_coefficients: A tensor with the shape ``[P, L, 1]`` with coefficients + associated with the features, :math:`\lambda_i`. + """ + self._feature_functions = feature_functions # [P, N, L] + self._feature_coefficients = feature_coefficients # [P, L, 1] + + # Concretises ABC methods from MultioutputKernel + @property + def num_latent_gps(self) -> int: + # In this case number of latent GPs (L) == output_dim (P) + return tf.shape(self._feature_coefficients)[0] + + # Concretises ABC methods from MultioutputKernel + @property + def latent_kernels(self) -> Any: + """In this scenario we do not have access to the underlying kernels, + so we are just returning the feature_functions""" + return self._feature_functions + + def K( + self, + X: TensorType, + X2: Optional[TensorType] = None, + full_output_cov: bool = True, + ) -> tf.Tensor: + """Approximate the true kernel by an inner product between feature functions.""" + phi = self._feature_functions(X) # [P, N, L] + + L = tf.shape(phi)[-1] + # NOTE - there are some differences between using FourierFeatures and FourierFeaturesCosine, + # extra check to ensure right shape and to guide debugging in notebooks + tf.debugging.assert_equal(tf.shape(self._feature_coefficients), [self.num_latent_gps, L, 1]) + + if X2 is None: + phi2 = phi + else: + phi2 = self._feature_functions(X2) # [P, N2, L] + + r = tf.linalg.matmul( + phi, + tf.linalg.matrix_transpose(self._feature_coefficients) * phi2, + transpose_b=True, + ) # [P, N, N2] + + N1, N2 = tf.shape(phi)[1], tf.shape(phi2)[1] + tf.debugging.assert_equal(tf.shape(r), [self.num_latent_gps, N1, N2]) + return r + + def K_diag(self, X: TensorType, full_output_cov: bool = False) -> tf.Tensor: + """Approximate the true kernel by an inner product between feature functions.""" + + phi_squared = self._feature_functions(X) ** 2 # [P, N, L] + r = tf.reduce_sum( + phi_squared * tf.linalg.matrix_transpose(self._feature_coefficients), + axis=-1, + ) # [P,N,] + N = tf.shape(X)[0] + + tf.debugging.assert_equal(tf.shape(r), [self.num_latent_gps, N]) # noqa: E231 + return r + + +# The difference between shared and separate case is the *kernel*. +# Shared: kernel is SharedIndependent +# Separate: kernel is SeparateIndependent. +# Use static typing (?) +# TODO: DRY +class SharedMultiOutputKernelWithFeatureDecompositionBase(gpflow.kernels.SharedIndependent): + + """ + 'Wrapper' class to solve the issue with full_cov:bool = False + inherited from gpflow.kernels.MultiOutputKernel + which doesn't work well with GPRPosterior, + as it does not use dispatchers from gpflow.covariances. + """ + + # Overriding __call__ from gpflow.kernels.MultioutputKernel + def __call__( + self, + X: TensorType, + X2: Optional[TensorType] = None, + *, + full_cov: bool = True, # NOTE -- otherwise will throw errors (possibly posterior related) + full_output_cov: bool = True, + presliced: bool = False, + ) -> tf.Tensor: + if not presliced: + X, X2 = self.slice(X, X2) + if not full_cov and X2 is not None: + raise ValueError( + "Ambiguous inputs: passing in `X2` is not compatible with `full_cov=False`." + ) + if not full_cov: + return self.K_diag(X, full_output_cov=full_output_cov) + return self.K(X, X2, full_output_cov=full_output_cov) + + +class SharedMultiOutputKernelWithFeatureDecomposition( + SharedMultiOutputKernelWithFeatureDecompositionBase +): + r""" + This class represents a gpflow.kernels.SharedIndependent kernel together + with its finite feature decomposition: + + .. math:: k(x, x') = \sum_{i=0}^L \lambda_i \phi_i(x) \phi_i(x'), + + where :math:`\lambda_i` and :math:`\phi_i(\cdot)` are the coefficients and + features, respectively. + + The decomposition can be derived from Mercer or Bochner's theorem. For example, + feature-coefficient pairs could be eigenfunction-eigenvalue pairs (Mercer) or + Fourier features with constant coefficients (Bochner). + + In some cases (e.g., [1]_ and [2]_) the left-hand side (that is, the + covariance function :math:`k(\cdot, \cdot)`) is unknown and the kernel + can only be approximated using its feature decomposition. + In other cases (e.g., [3]_ and [4]_), both the covariance function and feature + decomposition are available in closed form. + + .. [1] + Solin, Arno, and Simo Särkkä. "Hilbert space methods for + reduced-rank Gaussian process regression." Statistics and Computing + (2020). + .. [2] + Borovitskiy, Viacheslav, et al. "Matérn Gaussian processes on + Riemannian manifolds." In Advances in Neural Information Processing + Systems (2020). + .. [3] + Ali Rahimi and Benjamin Recht. Random features for large-scale kernel + machines. In Advances in Neural Information Processing Systems (2007). + .. [4] + Dutordoir, Vincent, Nicolas Durrande, and James Hensman. "Sparse + Gaussian processes with spherical harmonic features." In International + Conference on Machine Learning (2020). + """ + + def __init__( + self, + kernel: Union[gpflow.kernels.Kernel, NoneType], + feature_functions: tf.keras.layers.Layer, + feature_coefficients: TensorType, + *, + output_dim: Optional[int] = None, + ): + r""" + :param kernel: The kernel corresponding to the feature decomposition. + If ``None``, there is no analytical expression associated with the infinite + sum and we approximate the kernel based on the feature decomposition. + + .. note:: + + In certain cases, the analytical expression for the kernel is + not available. In this case, passing `None` is allowed, and + :meth:`K` and :meth:`K_diag` will be computed using the + approximation provided by the feature decomposition. + + :param feature_functions: A Keras layer for which the call evaluates the + ``L`` features of the kernel :math:`\phi_i(\cdot)`. For ``X`` with the shape ``[N, D]``, + ``feature_functions(X)`` returns a tensor with the shape ``[P, N, L]``. + :param feature_coefficients: A tensor with the shape ``[P, L, 1]`` with coefficients + associated with the features, :math:`\lambda_i`. + #TODO -- add output_dim + """ + + if kernel is None: + # NOTE -- this is a subclass of gpflow.kernels.SharedIndependent + # (needed to be used with dispatchers from gpflow.covariances) + # so it needs to be initialized somehow. + # TODO -- Not sure if most efficient way + _dummy_kernel = SquaredExponential() + super().__init__(_dummy_kernel, output_dim) + self._kernel = _MultiOutputApproximateKernel(feature_functions, feature_coefficients) + else: + super().__init__(kernel.kernel, kernel.output_dim) + self._kernel = kernel + + self._feature_functions = feature_functions # [P, N, L] + self._feature_coefficients = feature_coefficients # [P, L, 1] + + tf.ensure_shape(self._feature_coefficients, tf.TensorShape([None, None, 1])) + + @property + def num_latent_gps(self) -> int: + return self.output_dim + + @property + def latent_kernels(self) -> Tuple[gpflow.kernels.Kernel, ...]: + """The underlying kernels in the multioutput kernel""" + if isinstance(self._kernel, _MultiOutputApproximateKernel): + return ( + self._kernel.latent_kernels + ) # NOTE -- this will return self._feature_functions from ApproximateKernel + else: + return (self._kernel,) + + @property + def feature_functions(self) -> tf.keras.layers.Layer: + r"""Return the kernel's features :math:`\phi_i(\cdot)`.""" + return self._feature_functions + + @property + def feature_coefficients(self) -> tf.Tensor: + r"""Return the kernel's coefficients :math:`\lambda_i`.""" + return self._feature_coefficients + + def K( + self, + X: TensorType, + X2: Optional[TensorType] = None, + full_output_cov: bool = True, + ) -> tf.Tensor: + return self._kernel.K(X, X2, full_output_cov) + + def K_diag(self, X: TensorType, full_output_cov: bool = True) -> tf.Tensor: + return self._kernel.K_diag(X, full_output_cov) + + +# NOTE -- this is the same as the Shared case above +class SeparateMultiOutputKernelWithFeatureDecompositionBase(gpflow.kernels.SeparateIndependent): + + """ + 'Wrapper' class to solve the issue with full_cov:bool = False + inherited from gpflow.kernels.MultiOutputKernel + which doesn't work well with GPRPosterior, + as it does not use dispatchers from gpflow.covariances. + """ + + # Overriding __call__ from gpflow.kernels.MultioutputKernel + def __call__( + self, + X: TensorType, + X2: Optional[TensorType] = None, + *, + full_cov: bool = True, # NOTE -- otherwise will throw errors, might be problematic here + full_output_cov: bool = True, + presliced: bool = False, + ) -> tf.Tensor: + if not presliced: + X, X2 = self.slice(X, X2) + if not full_cov and X2 is not None: + raise ValueError( + "Ambiguous inputs: passing in `X2` is not compatible with `full_cov=False`." + ) + if not full_cov: + return self.K_diag(X, full_output_cov=full_output_cov) + return self.K(X, X2, full_output_cov=full_output_cov) + + +# NOTE -- this is the same as the Shared case above +class SeparateMultiOutputKernelWithFeatureDecomposition( + SeparateMultiOutputKernelWithFeatureDecompositionBase +): + r""" + This class represents a gpflow.kernel.SeparateIndependent + together with its finite feature decomposition: + + .. math:: k(x, x') = \sum_{i=0}^L \lambda_i \phi_i(x) \phi_i(x'), + + where :math:`\lambda_i` and :math:`\phi_i(\cdot)` are the coefficients and + features, respectively. + + The decomposition can be derived from Mercer or Bochner's theorem. For example, + feature-coefficient pairs could be eigenfunction-eigenvalue pairs (Mercer) or + Fourier features with constant coefficients (Bochner). + + In some cases (e.g., [1]_ and [2]_) the left-hand side (that is, the + covariance function :math:`k(\cdot, \cdot)`) is unknown and the kernel + can only be approximated using its feature decomposition. + In other cases (e.g., [3]_ and [4]_), both the covariance function and feature + decomposition are available in closed form. + + .. [1] + Solin, Arno, and Simo Särkkä. "Hilbert space methods for + reduced-rank Gaussian process regression." Statistics and Computing + (2020). + .. [2] + Borovitskiy, Viacheslav, et al. "Matérn Gaussian processes on + Riemannian manifolds." In Advances in Neural Information Processing + Systems (2020). + .. [3] + Ali Rahimi and Benjamin Recht. Random features for large-scale kernel + machines. In Advances in Neural Information Processing Systems (2007). + .. [4] + Dutordoir, Vincent, Nicolas Durrande, and James Hensman. "Sparse + Gaussian processes with spherical harmonic features." In International + Conference on Machine Learning (2020). + """ + + def __init__( + self, + kernel: Union[gpflow.kernels.Kernel, NoneType], + feature_functions: tf.keras.layers.Layer, + feature_coefficients: TensorType, + *, + output_dim: Optional[int] = None, + ): + r""" + :param kernel: The kernel corresponding to the feature decomposition. + If ``None``, there is no analytical expression associated with the infinite + sum and we approximate the kernel based on the feature decomposition. + + .. note:: + + In certain cases, the analytical expression for the kernel is + not available. In this case, passing `None` is allowed, and + :meth:`K` and :meth:`K_diag` will be computed using thekernel + approximation provided by the feature decomposition. + + :param feature_functions: A Keras layer for which the call evaluates the + ``L`` features of the kernel :math:`\phi_i(\cdot)`. For ``X`` with the shape ``[N, D]``, + ``feature_functions(X)`` returns a tensor with the shape ``[P, N, L]``. + :param feature_coefficients: A tensor with the shape ``[P, L, 1]`` with coefficients + associated with the features, :math:`\lambda_i`. + """ + + if kernel is None: + # NOTE -- this is a subclass of gpflow.kernels.SeparateIndependent + # (needed to be used with dispatchers from gpflow.covariances) + # so it needs to be initialized somehow. Not sure if efficient + # TODO -- this is dodgy, needs smarter solution + _dummy_kernels = [SquaredExponential() for _ in range(output_dim)] + super().__init__(_dummy_kernels) + self._kernel = _MultiOutputApproximateKernel(feature_functions, feature_coefficients) + else: + + super().__init__(kernel.kernels) + self._kernel = kernel + + self._feature_functions = feature_functions # [P, N, L] + self._feature_coefficients = feature_coefficients # [P, L, 1] + + tf.ensure_shape(self._feature_coefficients, tf.TensorShape([None, None, 1])) + + @property + def num_latent_gps(self) -> int: + + return len(self._kernel.kernels) + + @property + def latent_kernels(self) -> Tuple[gpflow.kernels.Kernel, ...]: + """The underlying kernels in the multioutput kernel""" + return tuple(self._kernel.kernels) + + @property + def feature_functions(self) -> tf.keras.layers.Layer: + r"""Return the kernel's features :math:`\phi_i(\cdot)`.""" + return self._feature_functions + + @property + def feature_coefficients(self) -> tf.Tensor: + r"""Return the kernel's coefficients :math:`\lambda_i`.""" + return self._feature_coefficients + + def K( + self, + X: TensorType, + X2: Optional[TensorType] = None, + full_output_cov: bool = True, + ) -> tf.Tensor: + return self._kernel.K(X, X2, full_output_cov) + + def K_diag(self, X: TensorType, full_output_cov: bool = False) -> tf.Tensor: + return self._kernel.K_diag(X, full_output_cov) diff --git a/gpflux/layers/basis_functions/fourier_features/__init__.py b/gpflux/layers/basis_functions/fourier_features/__init__.py index 42c09d47..064fc39a 100644 --- a/gpflux/layers/basis_functions/fourier_features/__init__.py +++ b/gpflux/layers/basis_functions/fourier_features/__init__.py @@ -18,6 +18,10 @@ :class:`gpflux.sampling.KernelWithFeatureDecomposition` """ +from gpflux.layers.basis_functions.fourier_features.multioutput.random import ( + MultiOutputRandomFourierFeatures, + MultiOutputRandomFourierFeaturesCosine, +) from gpflux.layers.basis_functions.fourier_features.quadrature import QuadratureFourierFeatures from gpflux.layers.basis_functions.fourier_features.random import ( OrthogonalRandomFeatures, @@ -30,4 +34,6 @@ "OrthogonalRandomFeatures", "RandomFourierFeatures", "RandomFourierFeaturesCosine", + "MultiOutputRandomFourierFeatures", + "MultiOutputRandomFourierFeaturesCosine", ] diff --git a/gpflux/layers/basis_functions/fourier_features/base.py b/gpflux/layers/basis_functions/fourier_features/base.py index 80461c80..6282606b 100644 --- a/gpflux/layers/basis_functions/fourier_features/base.py +++ b/gpflux/layers/basis_functions/fourier_features/base.py @@ -22,7 +22,6 @@ import gpflow from gpflow.base import TensorType - from gpflux.types import ShapeType diff --git a/gpflux/layers/basis_functions/fourier_features/multioutput/__init__.py b/gpflux/layers/basis_functions/fourier_features/multioutput/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/gpflux/layers/basis_functions/fourier_features/multioutput/base.py b/gpflux/layers/basis_functions/fourier_features/multioutput/base.py new file mode 100644 index 00000000..b707de82 --- /dev/null +++ b/gpflux/layers/basis_functions/fourier_features/multioutput/base.py @@ -0,0 +1,148 @@ +# +# Copyright (c) 2021 The GPflux Contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" Shared functionality for stationary kernel basis functions. """ + +from abc import ABC, abstractmethod +from typing import Mapping + +import tensorflow as tf + +import gpflow +from gpflow.base import TensorType +from gpflux.types import ShapeType + + +class MultiOutputFourierFeaturesBase(ABC, tf.keras.layers.Layer): + def __init__( + self, kernel: gpflow.kernels.MultioutputKernel, n_components: int, **kwargs: Mapping + ): + """ + :param kernel: kernel to approximate using a set of Fourier bases. + Expects a Multioutput Kernel + :param n_components: number of components (e.g. Monte Carlo samples, + quadrature nodes, etc.) used to numerically approximate the kernel. + """ + super(MultiOutputFourierFeaturesBase, self).__init__(**kwargs) + # NOTE -- same as univariate case from here till the end of __init__ + self.kernel = kernel + self.n_components = n_components + if kwargs.get("input_dim", None): + self._input_dim = kwargs["input_dim"] + self.build(tf.TensorShape([self._input_dim])) + else: + self._input_dim = None + + def call(self, inputs: TensorType) -> tf.Tensor: + """ + Evaluate the basis functions at ``inputs``. + + :param inputs: The evaluation points, a tensor with the shape ``[N, D]``. + :return: A tensor with the shape ``[P, N, M]``.mypy + """ + P = self.kernel.num_latent_gps + D = tf.shape(inputs)[-1] + + if isinstance(self.kernel, gpflow.kernels.SeparateIndependent): + + for kernel in self.kernel.kernels: + print(kernel.lengthscales.unconstrained_variable.value()) + + _lengthscales = tf.concat( + [ + kernel.lengthscales[None, None, ...] + if tf.rank(kernel.lengthscales.unconstrained_variable.value()) == 1 + else kernel.lengthscales[None, None, None, ...] + for kernel in self.kernel.kernels + ], + axis=0, + ) # [P, 1, D] + tf.debugging.assert_equal(tf.shape(_lengthscales), [P, 1, D]) + + elif isinstance(self.kernel, gpflow.kernels.SharedIndependent): + # NOTE -- each kernel.kernel.lengthscales has to be of the shape [D,] + _lengthscales = tf.tile( + self.kernel.kernel.lengthscales[None, None, ...] + if tf.rank(self.kernel.kernel.lengthscales.unconstrained_variable.value()) == 1 + else self.kernel.kernel.lengthscales[None, None, None, ...], + [P, 1, 1], + ) # [P, 1, D] + tf.debugging.assert_equal(tf.shape(_lengthscales), [P, 1, D]) + else: + raise ValueError("kernel is not supported.") + + X = tf.divide( + inputs, # [N, D] or [P, M, D] + _lengthscales, # [P, 1, D] + ) # [P, N, D] or [P, M, D] + + const = self._compute_constant()[..., None, None] # [P,1,1] + bases = self._compute_bases(X) # [P, N, L] for X*, or [P,M,L] in the case of Z + output = const * bases # [P, N, L] for X*, or [P,M,L] in the case of Z + + tf.ensure_shape(output, self.compute_output_shape(X.shape)) + return output + + def compute_output_shape(self, input_shape: ShapeType) -> tf.TensorShape: + """ + Computes the output shape of the layer. + See `tf.keras.layers.Layer.compute_output_shape() + `_. + """ + # TODO: Keras docs say "If the layer has not been built, this method + # will call `build` on the layer." -- do we need to do so? + + tensor_shape = tf.TensorShape(input_shape).with_rank(3) + output_dim = self._compute_output_dim(input_shape) + return tensor_shape[:-1].concatenate(output_dim) + + # NOTE -- same as univariate case + def get_config(self) -> Mapping: + """ + Returns the config of the layer. + See `tf.keras.layers.Layer.get_config() + `_. + """ + config = super(MultiOutputFourierFeaturesBase, self).get_config() + config.update( + { + "kernel": self.kernel, + "n_components": self.n_components, + "input_dim": self._input_dim, + } + ) + + return config + + # NOTE -- same as univariate case + @abstractmethod + def _compute_output_dim(self, input_shape: ShapeType) -> int: + pass + + # NOTE -- same as univariate case + @abstractmethod + def _compute_constant(self) -> tf.Tensor: + """ + Compute normalizing constant for basis functions. + """ + pass + + # NOTE -- same as univariate case + @abstractmethod + def _compute_bases(self, inputs: TensorType) -> tf.Tensor: + """ + Compute basis functions. + """ + pass diff --git a/gpflux/layers/basis_functions/fourier_features/multioutput/random/__init__.py b/gpflux/layers/basis_functions/fourier_features/multioutput/random/__init__.py new file mode 100644 index 00000000..37be97ef --- /dev/null +++ b/gpflux/layers/basis_functions/fourier_features/multioutput/random/__init__.py @@ -0,0 +1,30 @@ +# +# Copyright (c) 2021 The GPflux Contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" A kernel's features and coefficients using Random Fourier Features (RFF). """ + +from gpflux.layers.basis_functions.fourier_features.multioutput.random.base import ( + MultiOutputRandomFourierFeatures, + MultiOutputRandomFourierFeaturesCosine, +) +from gpflux.layers.basis_functions.fourier_features.multioutput.random.orthogonal import ( + MultiOutputOrthogonalRandomFeatures, +) + +__all__ = [ + "MultiOutputRandomFourierFeatures", + "MultiOutputRandomFourierFeaturesCosine", + "MultiOutputOrthogonalRandomFeatures", +] diff --git a/gpflux/layers/basis_functions/fourier_features/multioutput/random/base.py b/gpflux/layers/basis_functions/fourier_features/multioutput/random/base.py new file mode 100644 index 00000000..4959f7b5 --- /dev/null +++ b/gpflux/layers/basis_functions/fourier_features/multioutput/random/base.py @@ -0,0 +1,306 @@ +# +# Copyright (c) 2021 The GPflux Contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Mapping, Optional, Tuple, Type + +import numpy as np +import tensorflow as tf + +import gpflow +from gpflow.base import DType, TensorType + +from gpflux.layers.basis_functions.fourier_features.multioutput.base import ( + MultiOutputFourierFeaturesBase, +) +from gpflux.layers.basis_functions.fourier_features.utils import ( + _bases_concat, + _bases_cosine, + _matern_number, +) +from gpflux.types import ShapeType + +""" +Kernels supported by :class:`RandomFourierFeatures`. + +You can build RFF for shift-invariant stationary kernels from which you can +sample frequencies from their power spectrum, following Bochner's theorem. +""" +RFF_SUPPORTED_KERNELS: Tuple[Type[gpflow.kernels.Stationary], ...] = ( + gpflow.kernels.SquaredExponential, + gpflow.kernels.Matern12, + gpflow.kernels.Matern32, + gpflow.kernels.Matern52, +) + + +# TODO -- import this from the univariate folder +def _sample_students_t(nu: float, shape: ShapeType, dtype: DType) -> TensorType: + """ + Draw samples from a (central) Student's t-distribution using the following: + BETA ~ Gamma(nu/2, nu/2) (shape-rate parameterization) + X ~ Normal(0, 1/BETA) + then: + X ~ StudentsT(nu) + + Note this is equivalent to the more commonly used parameterization + Z ~ Chi2(nu) = Gamma(nu/2, 1/2) + EPSILON ~ Normal(0, 1) + X = EPSILON * sqrt(nu/Z) + + To see this, note + Z/nu ~ Gamma(nu/2, nu/2) + and + X ~ Normal(0, nu/Z) + The equivalence becomes obvious when we set BETA = Z/nu + """ + # Normal(0, 1) + normal_rvs = tf.random.normal(shape=shape, dtype=dtype) + shape = tf.concat([shape[:-1], [1]], axis=0) + # Gamma(nu/2, nu/2) + gamma_rvs = tf.random.gamma(shape, alpha=0.5 * nu, beta=0.5 * nu, dtype=dtype) + # StudentsT(nu) + students_t_rvs = tf.math.rsqrt(gamma_rvs) * normal_rvs + return students_t_rvs + + +class MultiOutputRandomFourierFeaturesBase(MultiOutputFourierFeaturesBase): + def __init__( + self, kernel: gpflow.kernels.MultioutputKernel, n_components: int, **kwargs: Mapping + ): + + if isinstance(kernel, gpflow.kernels.SeparateIndependent): + for ker in kernel.kernels: + assert isinstance(ker, RFF_SUPPORTED_KERNELS), "Unsupported Kernel" + elif isinstance(kernel, gpflow.kernels.SharedIndependent): + assert isinstance(kernel.kernel, RFF_SUPPORTED_KERNELS), "Unsupported Kernel" + else: + raise ValueError("kernel specified is not supported.") + super(MultiOutputRandomFourierFeaturesBase, self).__init__(kernel, n_components, **kwargs) + + # NOTE -- same as univariate case + def build(self, input_shape: ShapeType) -> None: + """ + Creates the variables of the layer. + See `tf.keras.layers.Layer.build() + `_. + """ + input_dim = input_shape[-1] + self._weights_build(input_dim, n_components=self.n_components) + super(MultiOutputRandomFourierFeaturesBase, self).build(input_shape) + + def _weights_build(self, input_dim: int, n_components: int) -> None: + + shape = (self.kernel.num_latent_gps, n_components, input_dim) + + self.W = self.add_weight( + name="weights", + trainable=False, + shape=shape, + dtype=self.dtype, + initializer=self._weights_init, + ) + tf.ensure_shape(self.W, shape) + + def _weights_init(self, shape: TensorType, dtype: Optional[DType] = None) -> TensorType: + + if isinstance(self.kernel, gpflow.kernels.SeparateIndependent): + + list_inits = [] + for ker in self.kernel.kernels: + if isinstance(ker, gpflow.kernels.SquaredExponential): + list_inits.append(tf.random.normal(shape[1:], dtype=dtype)) + else: + p = _matern_number(ker) + nu = 2.0 * p + 1.0 # degrees of freedom + list_inits.append(_sample_students_t(nu, shape[1:], dtype)) + return tf.stack(list_inits, axis=0) + + elif isinstance(self.kernel, gpflow.kernels.SharedIndependent): + + if isinstance(self.kernel.kernel, gpflow.kernels.SquaredExponential): + return tf.random.normal(shape, dtype=dtype) + else: + p = _matern_number(self.kernel.kernel) + nu = 2.0 * p + 1.0 # degrees of freedom + return _sample_students_t(nu, shape, dtype) + else: + raise ValueError("kernel is not supported.") + + # NOTE -- same as univariate case + @staticmethod + def rff_constant(variance: TensorType, output_dim: int) -> tf.Tensor: + """ + Normalizing constant for random Fourier features. + """ + return tf.sqrt(tf.math.truediv(2.0 * variance, output_dim)) + + +class MultiOutputRandomFourierFeatures(MultiOutputRandomFourierFeaturesBase): + r""" + #TODO -- update documentation to suit multioutput case + Random Fourier features (RFF) is a method for approximating kernels. The essential + element of the RFF approach :cite:p:`rahimi2007random` is the realization that Bochner's theorem + for stationary kernels can be approximated by a Monte Carlo sum. + + We will approximate the kernel :math:`k(\mathbf{x}, \mathbf{x}')` + by :math:`\Phi(\mathbf{x})^\top \Phi(\mathbf{x}')` + where :math:`\Phi: \mathbb{R}^{D} \to \mathbb{R}^{M}` is a finite-dimensional feature map. + + The feature map is defined as: + + .. math:: + + \Phi(\mathbf{x}) = \sqrt{\frac{2 \sigma^2}{\ell}} + \begin{bmatrix} + \cos(\boldsymbol{\theta}_1^\top \mathbf{x}) \\ + \sin(\boldsymbol{\theta}_1^\top \mathbf{x}) \\ + \vdots \\ + \cos(\boldsymbol{\theta}_{\frac{M}{2}}^\top \mathbf{x}) \\ + \sin(\boldsymbol{\theta}_{\frac{M}{2}}^\top \mathbf{x}) + \end{bmatrix} + + where :math:`\sigma^2` is the kernel variance. + The features are parameterised by random weights: + + - :math:`\boldsymbol{\theta} \sim p(\boldsymbol{\theta})` + where :math:`p(\boldsymbol{\theta})` is the spectral density of the kernel. + + At least for the squared exponential kernel, this variant of the feature + mapping has more desirable theoretical properties than its counterpart form + from phase-shifted cosines :class:`RandomFourierFeaturesCosine` :cite:p:`sutherland2015error`. + """ + + # NOTE -- same as univariate case + def _compute_output_dim(self, input_shape: ShapeType) -> int: + return 2 * self.n_components + + def _compute_bases(self, inputs: TensorType) -> tf.Tensor: + """ + Compute basis functions. + + :return: A tensor with the shape ``[P, N, 2M]`` . + """ + return _bases_concat(inputs, self.W) + + def _compute_constant(self) -> tf.Tensor: + """ + Compute normalizing constant for basis functions. + + :return: A tensor with the shape ``[]`` (i.e. a scalar). + """ + + if hasattr(self.kernel, "kernels"): + _kernel_variance = tf.stack([ker.variance for ker in self.kernel.kernels], axis=0) + tf.ensure_shape( + _kernel_variance, + [ + self.kernel.num_latent_gps, + ], + ) + + else: + _kernel_variance = self.kernel.kernel.variance + + return self.rff_constant(_kernel_variance, output_dim=2 * self.n_components) + + +class MultiOutputRandomFourierFeaturesCosine(MultiOutputRandomFourierFeaturesBase): + r""" + #TODO -- update documentation to suit multioutput case + Random Fourier Features (RFF) is a method for approximating kernels. The essential + element of the RFF approach :cite:p:`rahimi2007random` is the realization that Bochner's theorem + for stationary kernels can be approximated by a Monte Carlo sum. + + We will approximate the kernel :math:`k(\mathbf{x}, \mathbf{x}')` + by :math:`\Phi(\mathbf{x})^\top \Phi(\mathbf{x}')` where + :math:`\Phi: \mathbb{R}^{D} \to \mathbb{R}^{M}` is a finite-dimensional feature map. + + The feature map is defined as: + + .. math:: + \Phi(\mathbf{x}) = \sqrt{\frac{2 \sigma^2}{\ell}} + \begin{bmatrix} + \cos(\boldsymbol{\theta}_1^\top \mathbf{x} + \tau) \\ + \vdots \\ + \cos(\boldsymbol{\theta}_M^\top \mathbf{x} + \tau) + \end{bmatrix} + + where :math:`\sigma^2` is the kernel variance. + The features are parameterised by random weights: + + - :math:`\boldsymbol{\theta} \sim p(\boldsymbol{\theta})` + where :math:`p(\boldsymbol{\theta})` is the spectral density of the kernel + - :math:`\tau \sim \mathcal{U}(0, 2\pi)` + + Equivalent to :class:`RandomFourierFeatures` by elementary trigonometric identities. + """ + + def build(self, input_shape: ShapeType) -> None: + """ + Creates the variables of the layer. + See `tf.keras.layers.Layer.build() + `_. + """ + self._bias_build(n_components=self.n_components) + super(MultiOutputRandomFourierFeaturesCosine, self).build(input_shape) + + def _bias_build(self, n_components: int) -> None: + + shape = (self.kernel.num_latent_gps, 1, n_components) + + self.b = self.add_weight( + name="bias", + trainable=False, + shape=shape, + dtype=self.dtype, + initializer=self._bias_init, + ) + + # NOTE -- same as univariate case + def _bias_init(self, shape: TensorType, dtype: Optional[DType] = None) -> TensorType: + return tf.random.uniform(shape=shape, maxval=2.0 * np.pi, dtype=dtype) + + # NOTE -- same as univariate case + def _compute_output_dim(self, input_shape: ShapeType) -> int: + return self.n_components + + # NOTE -- same as univariate case + def _compute_bases(self, inputs: TensorType) -> tf.Tensor: + """ + Compute basis functions. + + :return: A tensor with the shape ``[N, M]``. + """ + return _bases_cosine(inputs, self.W, self.b) + + def _compute_constant(self) -> tf.Tensor: + """ + Compute normalizing constant for basis functions. + + :return: A tensor with the shape ``[]`` (i.e. a scalar). + """ + + if hasattr(self.kernel, "kernels"): + _kernel_variance = tf.stack([ker.variance for ker in self.kernel.kernels], axis=0) + tf.ensure_shape( + _kernel_variance, + [ + self.kernel.num_latent_gps, + ], + ) + else: + _kernel_variance = self.kernel.kernel.variance + + return self.rff_constant(_kernel_variance, output_dim=self.n_components) diff --git a/gpflux/layers/basis_functions/fourier_features/multioutput/random/orthogonal.py b/gpflux/layers/basis_functions/fourier_features/multioutput/random/orthogonal.py new file mode 100644 index 00000000..c5ae8453 --- /dev/null +++ b/gpflux/layers/basis_functions/fourier_features/multioutput/random/orthogonal.py @@ -0,0 +1,99 @@ +# +# Copyright (c) 2021 The GPflux Contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Mapping, Optional, Tuple, Type + +import numpy as np +import tensorflow as tf + +import gpflow +from gpflow.base import DType, TensorType + +from gpflux.layers.basis_functions.fourier_features.multioutput.random.base import ( + MultiOutputRandomFourierFeatures, +) +from gpflux.types import ShapeType + +""" +Kernels supported by :class:`OrthogonalRandomFeatures`. + +This random matrix sampling scheme only applies to the :class:`gpflow.kernels.SquaredExponential` +kernel. +For Matern kernels please use :class:`RandomFourierFeatures` +or :class:`RandomFourierFeaturesCosine`. +""" +ORF_SUPPORTED_KERNELS: Tuple[Type[gpflow.kernels.Stationary], ...] = ( + gpflow.kernels.SquaredExponential, +) + + +def _sample_chi_squared(nu: float, shape: ShapeType, dtype: DType) -> TensorType: + """ + Draw samples from Chi-squared distribution with `nu` degrees of freedom. + + See https://mathworld.wolfram.com/Chi-SquaredDistribution.html for further + details regarding relationship to Gamma distribution. + """ + return tf.random.gamma(shape=shape, alpha=0.5 * nu, beta=0.5, dtype=dtype) + + +def _sample_chi(nu: float, shape: ShapeType, dtype: DType) -> TensorType: + """ + Draw samples from Chi-distribution with `nu` degrees of freedom. + """ + s = _sample_chi_squared(nu, shape, dtype) + return tf.sqrt(s) + + +def _ceil_divide(a: float, b: float) -> int: + """ + Ceiling division. Returns the smallest integer `m` s.t. `m*b >= a`. + """ + return -np.floor_divide(-a, b) + + +class MultiOutputOrthogonalRandomFeatures(MultiOutputRandomFourierFeatures): + r""" + Orthogonal random Fourier features (ORF) :cite:p:`yu2016orthogonal` for more + efficient and accurate kernel approximations than :class:`RandomFourierFeatures`. + """ + + def __init__(self, kernel: gpflow.kernels.Kernel, n_components: int, **kwargs: Mapping): + + if isinstance(kernel, gpflow.kernels.SeparateIndependent): + for ker in kernel.kernels: + assert isinstance(ker, ORF_SUPPORTED_KERNELS), "Unsupported Kernel" + elif isinstance(kernel, gpflow.kernels.SharedIndependent): + assert isinstance(kernel.kernel, ORF_SUPPORTED_KERNELS), "Unsupported Kernel" + else: + raise ValueError("kernel specified is not supported.") + + super(MultiOutputOrthogonalRandomFeatures, self).__init__(kernel, n_components, **kwargs) + + def _weights_init(self, shape: TensorType, dtype: Optional[DType] = None) -> TensorType: + n_out, n_components, input_dim = shape # P, M, D + n_reps = _ceil_divide(n_components, input_dim) # K, smallest integer s.t. K*D >= M + + W = tf.random.normal(shape=(n_out, n_reps, input_dim, input_dim), dtype=dtype) + Q, _ = tf.linalg.qr(W) # throw away R; shape [P, K, D, D] + + s = _sample_chi( + nu=input_dim, shape=(n_out, n_reps, input_dim), dtype=dtype + ) # shape [P, K, D] + U = tf.expand_dims(s, axis=-1) * Q # equiv: S @ Q where S = diag(s); shape [P, K, D, D] + V = tf.reshape(U, shape=(n_out, -1, input_dim)) # shape [P, K*D, D] + + return V[:, : self.n_components, :] # shape [P, M, D] (throw away K*D - M rows) diff --git a/gpflux/layers/basis_functions/fourier_features/utils.py b/gpflux/layers/basis_functions/fourier_features/utils.py index 71372e9e..fc4ee205 100644 --- a/gpflux/layers/basis_functions/fourier_features/utils.py +++ b/gpflux/layers/basis_functions/fourier_features/utils.py @@ -40,8 +40,8 @@ def _bases_cosine(X: TensorType, W: TensorType, b: TensorType) -> TensorType: by Rahimi & Recht, 2007 :cite:p:`rahimi2007random`. See also :cite:p:`sutherland2015error` for additional details. """ - proj = tf.matmul(X, W, transpose_b=True) + b # [N, M] - return tf.cos(proj) # [N, M] + proj = tf.matmul(X, W, transpose_b=True) + b # [N, M] or [P, N, M] + return tf.cos(proj) # [N, M] or [P, N, M] def _bases_concat(X: TensorType, W: TensorType) -> TensorType: @@ -50,5 +50,5 @@ def _bases_concat(X: TensorType, W: TensorType) -> TensorType: by Rahimi & Recht, 2007 :cite:p:`rahimi2007random`. See also :cite:p:`sutherland2015error` for additional details. """ - proj = tf.matmul(X, W, transpose_b=True) # [N, M] - return tf.concat([tf.sin(proj), tf.cos(proj)], axis=-1) # [N, 2M] + proj = tf.matmul(X, W, transpose_b=True) # [N, M] or [P, N, M] + return tf.concat([tf.sin(proj), tf.cos(proj)], axis=-1) # [N, 2M] or [P, N, 2M] diff --git a/gpflux/sampling/__init__.py b/gpflux/sampling/__init__.py index 2c19c792..4c082df2 100644 --- a/gpflux/sampling/__init__.py +++ b/gpflux/sampling/__init__.py @@ -16,5 +16,12 @@ """ This module enables you to sample from (Deep) GPs efficiently and consistently. """ -from gpflux.sampling.kernel_with_feature_decomposition import KernelWithFeatureDecomposition -from gpflux.sampling.sample import efficient_sample + +from . import multioutput, sample +from .dispatch import efficient_sample + +__all__ = [ + "efficient_sample", + "sample", + "multioutput", +] diff --git a/gpflux/sampling/dispatch.py b/gpflux/sampling/dispatch.py new file mode 100644 index 00000000..02146ead --- /dev/null +++ b/gpflux/sampling/dispatch.py @@ -0,0 +1,3 @@ +from gpflow.utilities import Dispatcher + +efficient_sample = Dispatcher("efficient_sample") diff --git a/gpflux/sampling/multioutput/__init__.py b/gpflux/sampling/multioutput/__init__.py new file mode 100644 index 00000000..8cf871f2 --- /dev/null +++ b/gpflux/sampling/multioutput/__init__.py @@ -0,0 +1,5 @@ +from . import sample + +__all__ = [ + "sample", +] diff --git a/gpflux/sampling/multioutput/sample.py b/gpflux/sampling/multioutput/sample.py new file mode 100644 index 00000000..3cc1bab5 --- /dev/null +++ b/gpflux/sampling/multioutput/sample.py @@ -0,0 +1,177 @@ +# +# Copyright (c) 2021 The GPflux Contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" This module enables you to sample from (Deep) GPs using different approaches. """ + +from typing import Optional, Union + +import tensorflow as tf + +from gpflow.base import TensorType +from gpflow.config import default_float, default_jitter +from gpflow.covariances import Kuf, Kuu +from gpflow.inducing_variables import ( + MultioutputInducingVariables, + SeparateIndependentInducingVariables, + SharedIndependentInducingVariables, +) +from gpflow.kernels import SeparateIndependent, SharedIndependent + +from gpflux.feature_decomposition_kernels import ( + SeparateMultiOutputKernelWithFeatureDecomposition, + SharedMultiOutputKernelWithFeatureDecomposition, +) +from gpflux.math import compute_A_inv_b + +from ..dispatch import efficient_sample +from ..sample import Sample + + +@efficient_sample.register( + MultioutputInducingVariables, + ( + SharedMultiOutputKernelWithFeatureDecomposition, + SeparateMultiOutputKernelWithFeatureDecomposition, + ), + object, +) +def _efficient_multi_output_sample_matheron_rule( + inducing_variable: MultioutputInducingVariables, + kernel: Union[ + SharedMultiOutputKernelWithFeatureDecomposition, + SeparateMultiOutputKernelWithFeatureDecomposition, + ], + q_mu: tf.Tensor, + *, + q_sqrt: Optional[TensorType] = None, + whiten: bool = False, +) -> Sample: + """ + Implements the efficient sampling rule from :cite:t:`wilson2020efficiently` using + the Matheron rule. To use this sampling scheme, the GP has to have a + ``kernel`` of the :class:`KernelWithFeatureDecomposition` type . + + :param kernel: A kernel of the :class:`KernelWithFeatureDecomposition` type, which + holds the covariance function and the kernel's features and + coefficients. + :param q_mu: A tensor with the shape ``[M, P]``. + :param q_sqrt: A tensor with the shape ``[P, M, M]``. + :param whiten: Determines the parameterisation of the inducing variables. + """ + + # TODO -- this makes sense just in the case of SeparateMultiOutputKernelWithFeatureDecomposition + # NOTE -- in the case of SharedMdultiOutputKernelWithFeatureDecomposition we can just rely on + # gpflux.sampliung.sample._efficient_sample_matheron_rule + + # Reshape kernel.feature_coefficients + _feature_coefficients = tf.transpose(kernel.feature_coefficients[..., 0]) # [L,P] + + L = tf.shape(_feature_coefficients)[0] # num eigenfunctions # noqa: F841 + M, P = tf.shape(q_mu)[0], tf.shape(q_mu)[1] # num inducing, num output heads + + prior_weights = tf.sqrt(_feature_coefficients) * tf.random.normal( + (L, P), dtype=default_float() # [L, P], [L,P] + ) # [L, P] + + u_sample_noise = tf.matmul( + q_sqrt, + tf.random.normal((P, M, 1), dtype=default_float()), # [P, M, M] # [P, M, 1] + ) # [P, M, 1] + tf.debugging.assert_equal(tf.shape(u_sample_noise), [P, M, 1]) + + if isinstance(kernel, SharedIndependent): + Kmm = tf.tile( + Kuu(inducing_variable, kernel, jitter=default_jitter())[None, ...], + [P, 1, 1], + ) # [P,M,M] + tf.debugging.assert_equal(tf.shape(Kmm), [P, M, M]) + elif isinstance(kernel, SeparateIndependent): + Kmm = Kuu(inducing_variable, kernel, jitter=default_jitter()) # [P,M,M] + tf.debugging.assert_equal(tf.shape(Kmm), [P, M, M]) + else: + raise ValueError( + "kernel not supported. Must be either SharedIndependent or SeparateIndependent" + ) + + tf.debugging.assert_equal(tf.shape(Kmm), [P, M, M]) + + u_sample = q_mu + tf.linalg.matrix_transpose(u_sample_noise[..., 0]) # [M, P] + tf.debugging.assert_equal(tf.shape(u_sample), [M, P]) + + if whiten: + Luu = tf.linalg.cholesky(Kmm) # [P,M,M] + tf.debugging.assert_equal(tf.shape(Kmm), [P, M, M]) + + u_sample = tf.transpose( + tf.matmul(Luu, tf.transpose(u_sample)[..., None])[..., 0] # [P, M, M] # [P, M, 1] + ) # [M, P] + tf.debugging.assert_equal(tf.shape(u_sample), [M, P]) + + if isinstance(inducing_variable, SeparateIndependentInducingVariables): + + _inducing_variable_list = [] + for ind_var in inducing_variable.inducing_variable_list: + _inducing_variable_list.append(ind_var.Z) + _inducing_variable_list = tf.stack(_inducing_variable_list, axis=0) + + phi_Z = kernel.feature_functions(_inducing_variable_list) # [P, M, L] + tf.debugging.assert_equal(tf.shape(phi_Z), [P, M, L]) + + elif isinstance(inducing_variable, SharedIndependentInducingVariables): + + phi_Z = kernel.feature_functions(inducing_variable.inducing_variable.Z) # [P, M, L] + tf.debugging.assert_equal(tf.shape(phi_Z), [P, M, L]) + else: + raise ValueError("inducing variable is not supported.") + + weight_space_prior_Z = tf.matmul(phi_Z, tf.transpose(prior_weights)[..., None]) # [P, M, 1] + weight_space_prior_Z = tf.transpose(weight_space_prior_Z[..., 0]) # [M, P] + + diff = tf.transpose(u_sample - weight_space_prior_Z)[..., None] # [P, M, 1] + v = tf.transpose(compute_A_inv_b(Kmm, diff)[..., 0]) # [P, M, M] # [P, M, 1] # [M, P] + + tf.debugging.assert_equal(tf.shape(v), [M, P]) + + class WilsonSample(Sample): + def __call__(self, X: TensorType) -> tf.Tensor: + """ + :param X: evaluation points [N, D] + :return: function value of sample [N, P] + """ + N = tf.shape(X)[0] + phi_X = kernel.feature_functions(X) # [P, N, L] + + weight_space_prior_X = tf.transpose( + tf.matmul(phi_X, tf.transpose(prior_weights)[..., None],)[ # [P, N, L] # [P, L, 1] + ..., 0 + ] + ) # [N, P] + + Knm = tf.linalg.matrix_transpose( + Kuf(inducing_variable, kernel, X) + ) # [P, N, M] or [N,M] + if isinstance(inducing_variable, SharedIndependentInducingVariables): + Knm = tf.tile(Knm[None, ...], [P, 1, 1]) + tf.debugging.assert_equal(tf.shape(Knm), [P, N, M]) + function_space_update_X = tf.transpose( + tf.matmul(Knm, tf.transpose(v)[..., None])[..., 0] # [P, N, M] # [P, M, 1] + ) # [N, P] + + tf.debugging.assert_equal(tf.shape(weight_space_prior_X), [N, P]) + tf.debugging.assert_equal(tf.shape(function_space_update_X), [N, P]) + + return weight_space_prior_X + function_space_update_X # [N, P] + + return WilsonSample() diff --git a/gpflux/sampling/sample.py b/gpflux/sampling/sample.py index 5b7b8c36..e09818e5 100644 --- a/gpflux/sampling/sample.py +++ b/gpflux/sampling/sample.py @@ -26,13 +26,13 @@ from gpflow.covariances import Kuf, Kuu from gpflow.inducing_variables import InducingVariables from gpflow.kernels import Kernel -from gpflow.utilities import Dispatcher +from gpflux.feature_decomposition_kernels import KernelWithFeatureDecomposition from gpflux.math import compute_A_inv_b -from gpflux.sampling.kernel_with_feature_decomposition import KernelWithFeatureDecomposition from gpflux.sampling.utils import draw_conditional_sample -efficient_sample = Dispatcher("efficient_sample") +from .dispatch import efficient_sample + """ A function that returns a :class:`Sample` of a GP posterior. """ @@ -155,11 +155,12 @@ def _efficient_sample_matheron_rule( :param q_sqrt: A tensor with the shape ``[P, M, M]``. :param whiten: Determines the parameterisation of the inducing variables. """ + L = tf.shape(kernel.feature_coefficients)[0] # num eigenfunctions # noqa: F841 M, P = tf.shape(q_mu)[0], tf.shape(q_mu)[1] # num inducing, num output heads prior_weights = tf.sqrt(kernel.feature_coefficients) * tf.random.normal( - (L, P), dtype=default_float() + (L, P), dtype=default_float() # [L, 1], [L,P] ) # [L, P] u_sample_noise = tf.matmul( @@ -167,6 +168,7 @@ def _efficient_sample_matheron_rule( tf.random.normal((P, M, 1), dtype=default_float()), # [P, M, M] # [P, M, 1] ) # [P, M, 1] Kmm = Kuu(inducing_variable, kernel, jitter=default_jitter()) # [M, M] + tf.debugging.assert_equal(tf.shape(Kmm), [M, M]) u_sample = q_mu + tf.linalg.matrix_transpose(u_sample_noise[..., 0]) # [M, P] @@ -175,9 +177,12 @@ def _efficient_sample_matheron_rule( u_sample = tf.matmul(Luu, u_sample) # [M, P] phi_Z = kernel.feature_functions(inducing_variable.Z) # [M, L] - weight_space_prior_Z = phi_Z @ prior_weights # [M, P] + + weight_space_prior_Z = tf.matmul(phi_Z, prior_weights) # [M, L] # [L, P] # [M, P] + diff = u_sample - weight_space_prior_Z # [M, P] v = compute_A_inv_b(Kmm, diff) # [M, P] + tf.debugging.assert_equal(tf.shape(v), [M, P]) class WilsonSample(Sample): @@ -188,9 +193,12 @@ def __call__(self, X: TensorType) -> tf.Tensor: """ N = tf.shape(X)[0] phi_X = kernel.feature_functions(X) # [N, L] - weight_space_prior_X = phi_X @ prior_weights # [N, P] + + weight_space_prior_X = tf.matmul(phi_X, prior_weights) # [N, L] # [L, P] # [N, P] + Knm = tf.linalg.matrix_transpose(Kuf(inducing_variable, kernel, X)) # [N, M] - function_space_update_X = Knm @ v # [N, P] + + function_space_update_X = tf.matmul(Knm, v) # [N, M] # [M, P] # [N, P] tf.debugging.assert_equal(tf.shape(weight_space_prior_X), [N, P]) tf.debugging.assert_equal(tf.shape(function_space_update_X), [N, P]) diff --git a/tests/gpflux/layers/basis_functions/fourier_features/test_multioutput_orthogonal_ff.py b/tests/gpflux/layers/basis_functions/fourier_features/test_multioutput_orthogonal_ff.py new file mode 100644 index 00000000..fcfc3970 --- /dev/null +++ b/tests/gpflux/layers/basis_functions/fourier_features/test_multioutput_orthogonal_ff.py @@ -0,0 +1,184 @@ +# +# Copyright (c) 2021 The GPflux Contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import numpy as np +import pytest +import tensorflow as tf +from tensorflow.python.keras.testing_utils import layer_test +from tensorflow.python.keras.utils.kernelized_utils import inner_product + +import gpflow +from gpflow.kernels import SeparateIndependent, SharedIndependent + +from gpflux import feature_decomposition_kernels +from gpflux.feature_decomposition_kernels.multioutput import ( + SeparateMultiOutputKernelWithFeatureDecomposition, + SharedMultiOutputKernelWithFeatureDecomposition, +) +from gpflux.layers.basis_functions.fourier_features.multioutput.random import ( + MultiOutputOrthogonalRandomFeatures, +) +from gpflux.layers.basis_functions.fourier_features.random.orthogonal import ORF_SUPPORTED_KERNELS +from tests.conftest import skip_serialization_tests + + +@pytest.fixture(name="n_dims", params=[1, 2, 3, 5, 10, 20]) +def _n_dims_fixture(request): + return request.param + + +@pytest.fixture(name="variance", params=[0.5, 1.0, 2.0]) +def _variance_fixture(request): + return request.param + + +@pytest.fixture(name="lengthscale", params=[0.1, 1.0, 5.0]) +def _lengthscale_fixture(request): + return request.param + + +@pytest.fixture(name="batch_size", params=[1, 10]) +def _batch_size_fixture(request): + return request.param + + +@pytest.fixture(name="n_components", params=[1, 2, 4, 20, 100]) +def _n_features_fixture(request): + return request.param + + +@pytest.fixture(name="base_kernel_cls", params=list(ORF_SUPPORTED_KERNELS)) +def _base_kernel_cls_fixture(request): + return request.param + + +def test_orthogonal_fourier_features_can_approximate_multi_output_separate_kernel_multidim( + base_kernel_cls, variance, lengthscale, n_dims +): + n_components = 40000 + + x_rows = 20 + y_rows = 30 + # ARD + lengthscales = np.random.rand((n_dims)) * lengthscale + + print("size of sampled lengthscales") + print(lengthscales.shape) + + base_kernel = base_kernel_cls(variance=variance, lengthscales=lengthscales) + + kernel = SeparateIndependent(kernels=[base_kernel, base_kernel]) + + x = tf.random.uniform((x_rows, n_dims), dtype=tf.float64) + y = tf.random.uniform((y_rows, n_dims), dtype=tf.float64) + + actual_kernel_matrix = kernel.K(x, y, full_output_cov=False).numpy() + + # fourier_features = random_basis_func_cls(kernel, n_components, dtype=tf.float64) + fourier_features = MultiOutputOrthogonalRandomFeatures(kernel, n_components, dtype=tf.float64) + + feature_coefficients = np.ones((2, 2 * n_components, 1), dtype=np.float64) + + kernel = SeparateMultiOutputKernelWithFeatureDecomposition( + kernel=None, + feature_functions=fourier_features, + feature_coefficients=feature_coefficients, + output_dim=2, + ) + + approx_kernel_matrix = kernel(x, y).numpy() + + np.testing.assert_allclose(approx_kernel_matrix, actual_kernel_matrix, atol=5e-2) + + +def test_orthogonal_fourier_features_can_approximate_multi_output_shared_kernel_multidim( + base_kernel_cls, variance, lengthscale, n_dims +): + n_components = 40000 + + x_rows = 20 + y_rows = 30 + # ARD + lengthscales = np.random.rand((n_dims)) * lengthscale + + print("size of sampled lengthscales") + print(lengthscales.shape) + + base_kernel = base_kernel_cls(variance=variance, lengthscales=lengthscales) + + kernel = SharedIndependent(kernel=base_kernel, output_dim=2) + + x = tf.random.uniform((x_rows, n_dims), dtype=tf.float64) + y = tf.random.uniform((y_rows, n_dims), dtype=tf.float64) + + actual_kernel_matrix = kernel.K(x, y, full_output_cov=False).numpy() + + # fourier_features = random_basis_func_cls(kernel, n_components, dtype=tf.float64) + fourier_features = MultiOutputOrthogonalRandomFeatures(kernel, n_components, dtype=tf.float64) + + feature_coefficients = np.ones((2, 2 * n_components, 1), dtype=np.float64) + + kernel = SharedMultiOutputKernelWithFeatureDecomposition( + kernel=None, + feature_functions=fourier_features, + feature_coefficients=feature_coefficients, + output_dim=2, + ) + + approx_kernel_matrix = kernel(x, y).numpy() + + np.testing.assert_allclose(approx_kernel_matrix, actual_kernel_matrix, atol=5e-2) + + +""" +#TODO -- have a look at what layer_test actually does +@skip_serialization_tests +def test_keras_testing_util_layer_test_1D(kernel_cls, batch_size, n_components): + kernel = kernel_cls() + + tf.keras.utils.get_custom_objects()["RandomFourierFeatures"] = RandomFourierFeatures + layer_test( + RandomFourierFeatures, + kwargs={ + "kernel": kernel, + "n_components": n_components, + "input_dim": 1, + "dtype": "float64", + "dynamic": True, + }, + input_shape=(batch_size, 1), + input_dtype="float64", + ) + + +@skip_serialization_tests +def test_keras_testing_util_layer_test_multidim(kernel_cls, batch_size, n_dims, n_components): + kernel = kernel_cls() + + tf.keras.utils.get_custom_objects()["RandomFourierFeatures"] = RandomFourierFeatures + layer_test( + RandomFourierFeatures, + kwargs={ + "kernel": kernel, + "n_components": n_components, + "input_dim": n_dims, + "dtype": "float64", + "dynamic": True, + }, + input_shape=(batch_size, n_dims), + input_dtype="float64", + ) + +""" diff --git a/tests/gpflux/layers/basis_functions/fourier_features/test_multioutput_rff.py b/tests/gpflux/layers/basis_functions/fourier_features/test_multioutput_rff.py new file mode 100644 index 00000000..f11b6f05 --- /dev/null +++ b/tests/gpflux/layers/basis_functions/fourier_features/test_multioutput_rff.py @@ -0,0 +1,324 @@ +# +# Copyright (c) 2021 The GPflux Contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import numpy as np +import pytest +import tensorflow as tf +from tensorflow.python.keras.testing_utils import layer_test +from tensorflow.python.keras.utils.kernelized_utils import inner_product + +import gpflow +from gpflow.kernels import SeparateIndependent, SharedIndependent + +from gpflux import feature_decomposition_kernels +from gpflux.feature_decomposition_kernels.multioutput import ( + SeparateMultiOutputKernelWithFeatureDecomposition, + SharedMultiOutputKernelWithFeatureDecomposition, +) +from gpflux.layers.basis_functions.fourier_features import ( + MultiOutputRandomFourierFeatures, + MultiOutputRandomFourierFeaturesCosine, +) +from gpflux.layers.basis_functions.fourier_features.random.base import RFF_SUPPORTED_KERNELS +from tests.conftest import skip_serialization_tests + + +@pytest.fixture(name="n_dims", params=[1, 2, 3, 5, 10, 20]) +def _n_dims_fixture(request): + return request.param + + +@pytest.fixture(name="variance", params=[0.5, 1.0, 2.0]) +def _variance_fixture(request): + return request.param + + +@pytest.fixture(name="lengthscale", params=[0.1, 1.0, 5.0]) +def _lengthscale_fixture(request): + return request.param + + +@pytest.fixture(name="batch_size", params=[1, 10]) +def _batch_size_fixture(request): + return request.param + + +@pytest.fixture(name="n_components", params=[1, 2, 4, 20, 100]) +def _n_features_fixture(request): + return request.param + + +@pytest.fixture(name="base_kernel_cls", params=list(RFF_SUPPORTED_KERNELS)) +def _base_kernel_cls_fixture(request): + return request.param + + +@pytest.fixture( + name="random_basis_func_cls", + params=[MultiOutputRandomFourierFeatures], +) +def _random_basis_func_cls_fixture(request): + return request.param + + +@pytest.fixture( + name="basis_func_cls", + params=[MultiOutputRandomFourierFeatures], +) +def _basis_func_cls_fixture(request): + return request.param + + +def test_throw_for_unsupported_separate_kernel(basis_func_cls): + base_kernel = gpflow.kernels.Constant() + kernel = gpflow.kernels.SeparateIndependent(kernels=[base_kernel]) + with pytest.raises(AssertionError) as excinfo: + basis_func_cls(kernel, n_components=1) + assert "Unsupported Kernel" in str(excinfo.value) + + +def test_throw_for_unsupported_shared_kernel(basis_func_cls): + base_kernel = gpflow.kernels.Constant() + kernel = SharedIndependent(kernel=base_kernel, output_dim=1) + with pytest.raises(AssertionError) as excinfo: + basis_func_cls(kernel, n_components=1) + assert "Unsupported Kernel" in str(excinfo.value) + + +def test_random_fourier_features_can_approximate_multi_output_separate_kernel_multidim( + random_basis_func_cls, base_kernel_cls, variance, lengthscale, n_dims +): + n_components = 40000 + + x_rows = 20 + y_rows = 30 + # ARD + lengthscales = np.random.rand((n_dims)) * lengthscale + + print("size of sampled lengthscales") + print(lengthscales.shape) + + base_kernel = base_kernel_cls(variance=variance, lengthscales=lengthscales) + + kernel = SeparateIndependent(kernels=[base_kernel, base_kernel]) + + x = tf.random.uniform((x_rows, n_dims), dtype=tf.float64) + y = tf.random.uniform((y_rows, n_dims), dtype=tf.float64) + + actual_kernel_matrix = kernel.K(x, y, full_output_cov=False).numpy() + + fourier_features = random_basis_func_cls(kernel, n_components, dtype=tf.float64) + + feature_coefficients = np.ones((2, 2 * n_components, 1), dtype=np.float64) + + kernel = SeparateMultiOutputKernelWithFeatureDecomposition( + kernel=None, + feature_functions=fourier_features, + feature_coefficients=feature_coefficients, + output_dim=2, + ) + + approx_kernel_matrix = kernel(x, y).numpy() + + np.testing.assert_allclose(approx_kernel_matrix, actual_kernel_matrix, atol=5e-2) + + +def test_random_fourier_features_can_approximate_multi_output_shared_kernel_multidim( + random_basis_func_cls, base_kernel_cls, variance, lengthscale, n_dims +): + n_components = 40000 + + x_rows = 20 + y_rows = 30 + # ARD + lengthscales = np.random.rand((n_dims)) * lengthscale + + print("size of sampled lengthscales") + print(lengthscales.shape) + + base_kernel = base_kernel_cls(variance=variance, lengthscales=lengthscales) + + kernel = SharedIndependent(kernel=base_kernel, output_dim=2) + + x = tf.random.uniform((x_rows, n_dims), dtype=tf.float64) + y = tf.random.uniform((y_rows, n_dims), dtype=tf.float64) + + actual_kernel_matrix = kernel.K(x, y, full_output_cov=False).numpy() + + fourier_features = random_basis_func_cls(kernel, n_components, dtype=tf.float64) + + feature_coefficients = np.ones((2, 2 * n_components, 1), dtype=np.float64) + + kernel = SharedMultiOutputKernelWithFeatureDecomposition( + kernel=None, + feature_functions=fourier_features, + feature_coefficients=feature_coefficients, + output_dim=2, + ) + + approx_kernel_matrix = kernel(x, y).numpy() + + np.testing.assert_allclose(approx_kernel_matrix, actual_kernel_matrix, atol=5e-2) + + +""" +#TODO -- still need to implement the orthogonal version +def test_orthogonal_random_features_can_approximate_kernel_multidim(variance, lengthscale, n_dims): + n_components = 20000 + + x_rows = 20 + y_rows = 30 + # ARD + lengthscales = np.random.rand((n_dims)) * lengthscale + + kernel = gpflow.kernels.SquaredExponential(variance=variance, lengthscales=lengthscales) + fourier_features = OrthogonalRandomFeatures(kernel, n_components, dtype=tf.float64) + + x = tf.random.uniform((x_rows, n_dims), dtype=tf.float64) + y = tf.random.uniform((y_rows, n_dims), dtype=tf.float64) + + u = fourier_features(x) + v = fourier_features(y) + approx_kernel_matrix = inner_product(u, v) + + actual_kernel_matrix = kernel.K(x, y) + + np.testing.assert_allclose(approx_kernel_matrix, actual_kernel_matrix, atol=5e-2) + +""" + + +def test_random_multi_output_fourier_feature_layer_compute_covariance_of_shared_inducing_variables( + basis_func_cls, batch_size +): + """ + Ensure that the random fourier feature map can be used to approximate the covariance matrix + between the inducing point vectors of a sparse GP, with the condition that the number of latent + GP models is greater than one. + """ + n_components = 10000 + + base_kernel = gpflow.kernels.SquaredExponential() + kernel = SharedIndependent(kernel=base_kernel, output_dim=2) + fourier_features = basis_func_cls(kernel, n_components, dtype=tf.float64) + feature_coefficients = np.ones((2, 2 * n_components, 1), dtype=np.float64) + + feature_decomposition_kernel = SharedMultiOutputKernelWithFeatureDecomposition( + kernel=None, + feature_functions=fourier_features, + feature_coefficients=feature_coefficients, + output_dim=2, + ) + x_new = tf.ones(shape=(2 * batch_size + 1, 1), dtype=tf.float64) + + approx_kernel_matrix = feature_decomposition_kernel(x_new, x_new) + actual_kernel_matrix = kernel.K(x_new, x_new, full_output_cov=False) + + np.testing.assert_allclose(approx_kernel_matrix, actual_kernel_matrix, atol=5e-2) + + +def test_random_multi_output_fourier_feature_layer_compute_covariance_of_separate_inducing_variables( + basis_func_cls, batch_size +): + """ + Ensure that the random fourier feature map can be used to approximate the covariance matrix + between the inducing point vectors of a sparse GP, with the condition that the number of latent + GP models is greater than one. + """ + n_components = 10000 + + base_kernel = gpflow.kernels.SquaredExponential() + kernel = SeparateIndependent(kernels=[base_kernel, base_kernel]) + fourier_features = basis_func_cls(kernel, n_components, dtype=tf.float64) + feature_coefficients = np.ones((2, 2 * n_components, 1), dtype=np.float64) + + feature_decomposition_kernel = SharedMultiOutputKernelWithFeatureDecomposition( + kernel=None, + feature_functions=fourier_features, + feature_coefficients=feature_coefficients, + output_dim=2, + ) + x_new = tf.ones(shape=(2 * batch_size + 1, 1), dtype=tf.float64) + + approx_kernel_matrix = feature_decomposition_kernel(x_new, x_new) + actual_kernel_matrix = kernel.K(x_new, x_new, full_output_cov=False) + + np.testing.assert_allclose(approx_kernel_matrix, actual_kernel_matrix, atol=5e-2) + + +def test_separate_multi_output_fourier_features_shapes( + basis_func_cls, n_components, n_dims, batch_size +): + input_shape = (2, batch_size, n_dims) + base_kernel = gpflow.kernels.SquaredExponential(lengthscales=[1.0] * n_dims) + kernel = SeparateIndependent(kernels=[base_kernel, base_kernel]) + feature_functions = basis_func_cls(kernel, n_components, dtype=tf.float64) + output_shape = feature_functions.compute_output_shape(input_shape) + features = feature_functions(tf.ones(shape=(batch_size, n_dims))) + np.testing.assert_equal(features.shape, output_shape) + + +def test_shared_multi_output_fourier_features_shapes( + basis_func_cls, n_components, n_dims, batch_size +): + input_shape = (2, batch_size, n_dims) + base_kernel = gpflow.kernels.SquaredExponential(lengthscales=[1.0] * n_dims) + kernel = SharedIndependent(kernel=base_kernel, output_dim=2) + feature_functions = basis_func_cls(kernel, n_components, dtype=tf.float64) + output_shape = feature_functions.compute_output_shape(input_shape) + features = feature_functions(tf.ones(shape=(batch_size, n_dims))) + np.testing.assert_equal(features.shape, output_shape) + + +""" +#TODO -- have a look at what layer_test actually does +@skip_serialization_tests +def test_keras_testing_util_layer_test_1D(kernel_cls, batch_size, n_components): + kernel = kernel_cls() + + tf.keras.utils.get_custom_objects()["RandomFourierFeatures"] = RandomFourierFeatures + layer_test( + RandomFourierFeatures, + kwargs={ + "kernel": kernel, + "n_components": n_components, + "input_dim": 1, + "dtype": "float64", + "dynamic": True, + }, + input_shape=(batch_size, 1), + input_dtype="float64", + ) + + +@skip_serialization_tests +def test_keras_testing_util_layer_test_multidim(kernel_cls, batch_size, n_dims, n_components): + kernel = kernel_cls() + + tf.keras.utils.get_custom_objects()["RandomFourierFeatures"] = RandomFourierFeatures + layer_test( + RandomFourierFeatures, + kwargs={ + "kernel": kernel, + "n_components": n_components, + "input_dim": n_dims, + "dtype": "float64", + "dynamic": True, + }, + input_shape=(batch_size, n_dims), + input_dtype="float64", + ) + +""" diff --git a/tests/gpflux/layers/basis_functions/fourier_features/test_multioutput_rff_cosine.py b/tests/gpflux/layers/basis_functions/fourier_features/test_multioutput_rff_cosine.py new file mode 100644 index 00000000..e4bd8e65 --- /dev/null +++ b/tests/gpflux/layers/basis_functions/fourier_features/test_multioutput_rff_cosine.py @@ -0,0 +1,312 @@ +# +# Copyright (c) 2021 The GPflux Contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import numpy as np +import pytest +import tensorflow as tf +from tensorflow.python.keras.testing_utils import layer_test +from tensorflow.python.keras.utils.kernelized_utils import inner_product + +import gpflow +from gpflow.kernels import SeparateIndependent, SharedIndependent + +from gpflux import feature_decomposition_kernels +from gpflux.feature_decomposition_kernels.multioutput import ( + SeparateMultiOutputKernelWithFeatureDecomposition, + SharedMultiOutputKernelWithFeatureDecomposition, +) +from gpflux.helpers import construct_basic_kernel +from gpflux.layers.basis_functions.fourier_features import ( + MultiOutputRandomFourierFeatures, + MultiOutputRandomFourierFeaturesCosine, +) +from gpflux.layers.basis_functions.fourier_features.random.base import RFF_SUPPORTED_KERNELS +from tests.conftest import skip_serialization_tests + + +@pytest.fixture(name="n_dims", params=[1, 2, 3, 5, 10, 20]) +def _n_dims_fixture(request): + return request.param + + +# @pytest.fixture(name="variance", params=[0.5, 1.0, 2.0]) +@pytest.fixture(name="variance", params=[0.5]) +def _variance_fixture(request): + return request.param + + +# @pytest.fixture(name="lengthscale", params=[0.1, 1.0, 5.0]) +@pytest.fixture(name="lengthscale", params=[0.1]) +def _lengthscale_fixture(request): + return request.param + + +@pytest.fixture(name="batch_size", params=[1, 10]) +def _batch_size_fixture(request): + return request.param + + +@pytest.fixture(name="n_components", params=[1, 2, 4, 20, 100]) +def _n_features_fixture(request): + return request.param + + +@pytest.fixture(name="base_kernel_cls", params=list(RFF_SUPPORTED_KERNELS)) +def _base_kernel_cls_fixture(request): + return request.param + + +@pytest.fixture( + name="random_basis_func_cls", + params=[MultiOutputRandomFourierFeaturesCosine], +) +def _random_basis_func_cls_fixture(request): + return request.param + + +@pytest.fixture( + name="basis_func_cls", + params=[MultiOutputRandomFourierFeaturesCosine], +) +def _basis_func_cls_fixture(request): + return request.param + + +@pytest.mark.skip +def test_throw_for_unsupported_separate_kernel(basis_func_cls): + base_kernel = gpflow.kernels.Constant() + kernel = gpflow.kernels.SeparateIndependent(kernels=[base_kernel]) + with pytest.raises(AssertionError) as excinfo: + basis_func_cls(kernel, n_components=1) + assert "Unsupported Kernel" in str(excinfo.value) + + +@pytest.mark.skip +def test_throw_for_unsupported_shared_kernel(basis_func_cls): + base_kernel = gpflow.kernels.Constant() + kernel = SharedIndependent(kernel=base_kernel, output_dim=1) + with pytest.raises(AssertionError) as excinfo: + basis_func_cls(kernel, n_components=1) + assert "Unsupported Kernel" in str(excinfo.value) + + +@pytest.mark.parametrize("output_dim", [2]) +@pytest.mark.parametrize("n_components", [100]) +@pytest.mark.parametrize("size_dataset", [10]) +def test_separate_kernel_multioutput_rff_cosine( + n_components: int, output_dim: int, size_dataset: int, variance: float, lengthscale: float +) -> None: + x = tf.random.uniform((size_dataset, output_dim), dtype=tf.float64) + + lengthscales = np.random.rand((output_dim)) * lengthscale + lengthscales = tf.cast(lengthscales, dtype=tf.float64) + + base_kernel = gpflow.kernels.SquaredExponential(variance=variance, lengthscales=lengthscales) + + # kernel = construct_basic_kernel( + # [base_kernel for _ in range(output_dim)], share_hyperparams=False + # ) + + kernel = SeparateIndependent(kernels=[base_kernel for _ in range(output_dim)]) + + rff = MultiOutputRandomFourierFeaturesCosine( + kernel=kernel, n_components=n_components, dtype=tf.float64 + ) + output = rff(inputs=x) + + tf.debugging.assert_shapes([(output, [output_dim, size_dataset, n_components])]) + + +@pytest.mark.skip +def test_random_fourier_features_can_approximate_multi_output_separate_kernel_multidim( + random_basis_func_cls, base_kernel_cls, variance, lengthscale, n_dims +): + n_components = 40000 + x_rows = 20 + y_rows = 30 + + lengthscales = np.random.rand((n_dims)) * lengthscale + + base_kernel = base_kernel_cls(variance=variance, lengthscales=lengthscales) + + kernel = SeparateIndependent(kernels=[base_kernel, base_kernel]) + + x = tf.random.uniform((x_rows, n_dims), dtype=tf.float64) + y = tf.random.uniform((y_rows, n_dims), dtype=tf.float64) + + actual_kernel_matrix = kernel.K(x, y, full_output_cov=False).numpy() + + fourier_features = random_basis_func_cls(kernel, n_components, dtype=tf.float64) + + feature_coefficients = np.ones((2, n_components, 1), dtype=np.float64) + + kernel = SeparateMultiOutputKernelWithFeatureDecomposition( + kernel=None, + feature_functions=fourier_features, + feature_coefficients=feature_coefficients, + output_dim=2, + ) + + approx_kernel_matrix = kernel(x, y).numpy() + + np.testing.assert_allclose(approx_kernel_matrix, actual_kernel_matrix, atol=5e-2) + + +@pytest.mark.skip +def test_random_fourier_features_can_approximate_multi_output_shared_kernel_multidim( + random_basis_func_cls, base_kernel_cls, variance, lengthscale, n_dims +): + n_components = 40000 + x_rows = 20 + y_rows = 30 + + lengthscales = np.random.rand((n_dims)) * lengthscale + + base_kernel = base_kernel_cls(variance=variance, lengthscales=lengthscales) + + kernel = SharedIndependent(kernel=base_kernel, output_dim=2) + + x = tf.random.uniform((x_rows, n_dims), dtype=tf.float64) + y = tf.random.uniform((y_rows, n_dims), dtype=tf.float64) + + actual_kernel_matrix = kernel.K(x, y, full_output_cov=False).numpy() + + fourier_features = random_basis_func_cls(kernel, n_components, dtype=tf.float64) + + feature_coefficients = np.ones((2, n_components, 1), dtype=np.float64) + + kernel = SharedMultiOutputKernelWithFeatureDecomposition( + kernel=None, + feature_functions=fourier_features, + feature_coefficients=feature_coefficients, + output_dim=2, + ) + + approx_kernel_matrix = kernel(x, y).numpy() + + np.testing.assert_allclose(approx_kernel_matrix, actual_kernel_matrix, atol=5e-2) + + +@pytest.mark.skip +def test_random_multi_output_fourier_feature_layer_compute_covariance_of_shared_inducing_variables( + basis_func_cls, batch_size +): + """ + Ensure that the random fourier feature map can be used to approximate the covariance matrix + between the inducing point vectors of a sparse GP, with the condition that the number of latent + GP models is greater than one. + """ + n_components = 10000 + + base_kernel = gpflow.kernels.SquaredExponential() + kernel = SharedIndependent(kernel=base_kernel, output_dim=2) + fourier_features = basis_func_cls(kernel, n_components, dtype=tf.float64) + feature_coefficients = np.ones((2, n_components, 1), dtype=np.float64) + + feature_decomposition_kernel = SharedMultiOutputKernelWithFeatureDecomposition( + kernel=None, + feature_functions=fourier_features, + feature_coefficients=feature_coefficients, + output_dim=2, + ) + x_new = tf.ones(shape=(2 * batch_size + 1, 1), dtype=tf.float64) + + approx_kernel_matrix = feature_decomposition_kernel(x_new, x_new) + actual_kernel_matrix = kernel.K(x_new, x_new, full_output_cov=False) + + np.testing.assert_allclose(approx_kernel_matrix, actual_kernel_matrix, atol=5e-2) + + +@pytest.mark.skip +def test_random_multi_output_fourier_feature_layer_compute_covariance_of_separate_inducing_variables( + basis_func_cls, batch_size +): + """ + Ensure that the random fourier feature map can be used to approximate the covariance matrix + between the inducing point vectors of a sparse GP, with the condition that the number of latent + GP models is greater than one. + """ + n_components = 10000 + + base_kernel = gpflow.kernels.SquaredExponential() + kernel = SeparateIndependent(kernels=[base_kernel, base_kernel]) + fourier_features = basis_func_cls(kernel, n_components, dtype=tf.float64) + feature_coefficients = np.ones((2, n_components, 1), dtype=np.float64) + + feature_decomposition_kernel = SharedMultiOutputKernelWithFeatureDecomposition( + kernel=None, + feature_functions=fourier_features, + feature_coefficients=feature_coefficients, + output_dim=2, + ) + x_new = tf.ones(shape=(2 * batch_size + 1, 1), dtype=tf.float64) + + approx_kernel_matrix = feature_decomposition_kernel(x_new, x_new) + actual_kernel_matrix = kernel.K(x_new, x_new, full_output_cov=False) + + np.testing.assert_allclose(approx_kernel_matrix, actual_kernel_matrix, atol=5e-2) + + +@pytest.mark.skip +def test_multi_output_fourier_features_shapes(basis_func_cls, n_components, n_dims, batch_size): + input_shape = (2, batch_size, n_dims) + base_kernel = gpflow.kernels.SquaredExponential(lengthscales=[1.0] * n_dims) + kernel = SeparateIndependent(kernels=[base_kernel, base_kernel]) + feature_functions = basis_func_cls(kernel, n_components, dtype=tf.float64) + output_shape = feature_functions.compute_output_shape(input_shape) + features = feature_functions(tf.ones(shape=(batch_size, n_dims))) + np.testing.assert_equal(features.shape, output_shape) + + +""" +#TODO -- have a look at what layer_test actually does +@skip_serialization_tests +def test_keras_testing_util_layer_test_1D(kernel_cls, batch_size, n_components): + kernel = kernel_cls() + + tf.keras.utils.get_custom_objects()["RandomFourierFeatures"] = RandomFourierFeatures + layer_test( + RandomFourierFeatures, + kwargs={ + "kernel": kernel, + "n_components": n_components, + "input_dim": 1, + "dtype": "float64", + "dynamic": True, + }, + input_shape=(batch_size, 1), + input_dtype="float64", + ) + + +@skip_serialization_tests +def test_keras_testing_util_layer_test_multidim(kernel_cls, batch_size, n_dims, n_components): + kernel = kernel_cls() + + tf.keras.utils.get_custom_objects()["RandomFourierFeatures"] = RandomFourierFeatures + layer_test( + RandomFourierFeatures, + kwargs={ + "kernel": kernel, + "n_components": n_components, + "input_dim": n_dims, + "dtype": "float64", + "dynamic": True, + }, + input_shape=(batch_size, n_dims), + input_dtype="float64", + ) + +""" diff --git a/tests/gpflux/sampling/test_multioutput_sample.py b/tests/gpflux/sampling/test_multioutput_sample.py new file mode 100644 index 00000000..9db16344 --- /dev/null +++ b/tests/gpflux/sampling/test_multioutput_sample.py @@ -0,0 +1,225 @@ +# +# Copyright (c) 2021 The GPflux Contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import numpy as np +import pytest +import tensorflow as tf + +import gpflow +from gpflow.config import default_float, default_jitter + +from gpflux.feature_decomposition_kernels import ( + SeparateMultiOutputKernelWithFeatureDecomposition, + SharedMultiOutputKernelWithFeatureDecomposition, +) +from gpflux.layers.basis_functions.fourier_features import MultiOutputRandomFourierFeaturesCosine +from gpflux.sampling.sample import Sample, efficient_sample + + +@pytest.fixture(name="base_kernel") +def _base_kernel_fixture(): + + return gpflow.kernels.SquaredExponential() + + +def _get_shared_kernel(base_kernel): + + return gpflow.kernels.SharedIndependent(kernel=base_kernel, output_dim=2) + + +def _get_separate_kernel(base_kernel): + + return gpflow.kernels.SeparateIndependent(kernels=[base_kernel, base_kernel]) + + +@pytest.fixture(name="shared_inducing_variable") +def _shared_inducing_variable_fixture(): + Z = np.linspace(-1, 1, 10).reshape(-1, 1) + + ind_var = gpflow.inducing_variables.InducingPoints(Z) + + return gpflow.inducing_variables.SharedIndependentInducingVariables(inducing_variable=ind_var) + + +@pytest.fixture(name="separate_inducing_variable") +def _separate_inducing_variable_fixture(): + Z = np.linspace(-1, 1, 10).reshape(-1, 1) + + ind_var = gpflow.inducing_variables.InducingPoints(Z) + + return gpflow.inducing_variables.SeparateIndependentInducingVariables( + inducing_variable_list=[ind_var, ind_var] + ) + + +@pytest.fixture(name="whiten", params=[True, False]) +def _whiten_fixture(request): + return request.param + + +def _get_shared_qmu_qsqrt(kernel, inducing_variable): + """Returns q_mu and q_sqrt for a kernel and inducing_variable""" + Z = inducing_variable.inducing_variable.Z.numpy() + Kzz = kernel(Z, full_cov=True, full_output_cov=False).numpy() + + q_sqrt = np.linalg.cholesky(Kzz) + default_jitter() * np.eye(Z.shape[0])[None, ...] + q_mu = q_sqrt @ np.random.randn(2, Z.shape[0], 1) + + return np.transpose(q_mu[..., 0]), q_sqrt + + +def _get_separate_qmu_qsqrt(kernel, inducing_variable): + """Returns q_mu and q_sqrt for a kernel and inducing_variable""" + Z = inducing_variable.inducing_variable_list[0].Z.numpy() + Kzz = kernel(Z, full_cov=True, full_output_cov=False).numpy() + + q_sqrt = np.linalg.cholesky(Kzz) + default_jitter() * np.eye(Z.shape[0])[None, ...] + q_mu = q_sqrt @ np.random.randn(2, Z.shape[0], 1) + + return np.transpose(q_mu[..., 0]), q_sqrt + + +def test_shared_conditional_sample(base_kernel, shared_inducing_variable, whiten): + """Smoke and consistency test for efficient sampling using MVN Conditioning""" + kernel = _get_shared_kernel(base_kernel) + q_mu, q_sqrt = _get_shared_qmu_qsqrt(kernel, shared_inducing_variable) + + sample_func = efficient_sample( + shared_inducing_variable, + kernel, + q_mu, + q_sqrt=1e-3 * tf.convert_to_tensor(q_sqrt), + whiten=whiten, + ) + + X = np.linspace(-1, 1, 100).reshape(-1, 1) + # Check for consistency - i.e. evaluating the sample at the + # same locations (X) returns the same value + np.testing.assert_array_almost_equal( + sample_func(X), + sample_func(X), + # MVN conditioning is numerically unstable. + # Notice how in the Wilson sampling we can use the default + # of decimal=7. + decimal=2, + ) + + +def test_separate_conditional_sample(base_kernel, separate_inducing_variable, whiten): + """Smoke and consistency test for efficient sampling using MVN Conditioning""" + kernel = _get_separate_kernel(base_kernel) + q_mu, q_sqrt = _get_separate_qmu_qsqrt(kernel, separate_inducing_variable) + + sample_func = efficient_sample( + separate_inducing_variable, + kernel, + q_mu, + q_sqrt=1e-3 * tf.convert_to_tensor(q_sqrt), + whiten=whiten, + ) + + X = np.linspace(-1, 1, 100).reshape(-1, 1) + # Check for consistency - i.e. evaluating the sample at the + # same locations (X) returns the same value + np.testing.assert_array_almost_equal( + sample_func(X), + sample_func(X), + # MVN conditioning is numerically unstable. + # Notice how in the Wilson sampling we can use the default + # of decimal=7. + decimal=2, + ) + + +def test_shared_wilson_efficient_sample(base_kernel, shared_inducing_variable, whiten): + """Smoke and consistency test for efficient sampling using Wilson""" + kernel = _get_shared_kernel(base_kernel) + + eigenfunctions = MultiOutputRandomFourierFeaturesCosine(kernel, 100, dtype=default_float()) + eigenvalues = np.ones((2, 100, 1), dtype=default_float()) + # To apply Wilson sampling we require the features and eigenvalues of the kernel + kernel2 = SharedMultiOutputKernelWithFeatureDecomposition(kernel, eigenfunctions, eigenvalues) + q_mu, q_sqrt = _get_shared_qmu_qsqrt(kernel, shared_inducing_variable) + + sample_func = efficient_sample( + shared_inducing_variable, + kernel2, + tf.convert_to_tensor(q_mu), + q_sqrt=1e-3 * tf.convert_to_tensor(q_sqrt), + whiten=whiten, + ) + + X = np.linspace(-1, 0, 100).reshape(-1, 1) + # Check for consistency - i.e. evaluating the sample at the + # same locations (X) returns the same value + np.testing.assert_array_almost_equal( + sample_func(X), + sample_func(X), + ) + + +def test_separate_wilson_efficient_sample(base_kernel, separate_inducing_variable, whiten): + """Smoke and consistency test for efficient sampling using Wilson""" + kernel = _get_separate_kernel(base_kernel) + + eigenfunctions = MultiOutputRandomFourierFeaturesCosine(kernel, 100, dtype=default_float()) + eigenvalues = np.ones((2, 100, 1), dtype=default_float()) + # To apply Wilson sampling we require the features and eigenvalues of the kernel + kernel2 = SeparateMultiOutputKernelWithFeatureDecomposition(kernel, eigenfunctions, eigenvalues) + q_mu, q_sqrt = _get_separate_qmu_qsqrt(kernel, separate_inducing_variable) + + sample_func = efficient_sample( + separate_inducing_variable, + kernel2, + tf.convert_to_tensor(q_mu), + q_sqrt=1e-3 * tf.convert_to_tensor(q_sqrt), + whiten=whiten, + ) + + X = np.linspace(-1, 0, 100).reshape(-1, 1) + # Check for consistency - i.e. evaluating the sample at the + # same locations (X) returns the same value + np.testing.assert_array_almost_equal( + sample_func(X), + sample_func(X), + ) + + +class SampleMock(Sample): + def __init__(self, a): + self.a = a + + def __call__(self, X): + return self.a * X + + +def test_adding_samples(): + X = np.random.randn(100, 2) + + sample1 = SampleMock(1.0) + sample2 = SampleMock(2.0) + sample3 = sample1 + sample2 + np.testing.assert_array_almost_equal(sample3(X), sample1(X) + sample2(X)) + + +def test_adding_sample_and_mean_function(): + X = np.random.randn(100, 2) + + mean_function = gpflow.mean_functions.Identity() + sample = SampleMock(3.0) + + sample_and_mean_function = sample + mean_function + + np.testing.assert_array_almost_equal(sample_and_mean_function(X), sample(X) + mean_function(X)) diff --git a/tests/gpflux/sampling/test_sample.py b/tests/gpflux/sampling/test_sample.py index 70b27588..d98e8c81 100644 --- a/tests/gpflux/sampling/test_sample.py +++ b/tests/gpflux/sampling/test_sample.py @@ -20,8 +20,10 @@ import gpflow from gpflow.config import default_float, default_jitter +from gpflux.feature_decomposition_kernels.kernel_with_feature_decomposition import ( + KernelWithFeatureDecomposition, +) from gpflux.layers.basis_functions.fourier_features import RandomFourierFeaturesCosine -from gpflux.sampling.kernel_with_feature_decomposition import KernelWithFeatureDecomposition from gpflux.sampling.sample import Sample, efficient_sample