Add batch_shape property to GP model class
Implements #2301.

TODO: Verify compatibility with the botorch setup of other models
Balandat committed Jun 4, 2023
1 parent f73fa7d commit 14bc8b0
Showing 9 changed files with 97 additions and 19 deletions.
18 changes: 18 additions & 0 deletions gpytorch/models/approximate_gp.py
@@ -1,5 +1,7 @@
#!/usr/bin/env python3

import torch

from .gp import GP
from .pyro import _PyroMixin # This will only contain functions if Pyro is installed

@@ -40,6 +42,11 @@ class ApproximateGP(GP, _PyroMixin):
>>> # test_x = ...;
>>> model(test_x) # Returns the approximate GP latent function at test_x
>>> likelihood(model(test_x)) # Returns the (approximate) predictive posterior distribution at test_x
:ivar torch.Size batch_shape: The batch shape of the model. This is a batch shape from an I/O perspective,
independent of the internal representation of the model. For a model with `(m)` outputs, a
`test_batch_shape x q x d`-shaped input to the model in eval mode returns a
distribution of shape `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.
"""

def __init__(self, variational_strategy):
@@ -49,6 +56,17 @@ def __init__(self, variational_strategy):
def forward(self, x):
raise NotImplementedError

@property
def batch_shape(self) -> torch.Size:
r"""The batch shape of the model.
This is a batch shape from an I/O perspective, independent of the internal
representation of the model. For a model with `(m)` outputs, a
`test_batch_shape x q x d`-shaped input to the model in eval mode returns a
distribution of shape `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.
"""
return self.variational_strategy.batch_shape

def pyro_guide(self, input, beta=1.0, name_prefix=""):
r"""
(For Pyro integration only). The component of a `pyro.guide` that
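For illustration only (not part of this commit): a minimal sketch of how the new property behaves for a batched approximate GP. The model class, shapes, and hyperparameters below are assumptions made for the example.

import torch
import gpytorch
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy

class BatchedSVGP(ApproximateGP):
    def __init__(self, inducing_points):
        batch_shape = inducing_points.shape[:-2]
        variational_distribution = CholeskyVariationalDistribution(inducing_points.size(-2), batch_shape=batch_shape)
        variational_strategy = VariationalStrategy(self, inducing_points, variational_distribution, learn_inducing_locations=True)
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(batch_shape=batch_shape), batch_shape=batch_shape)

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

# A batch of 3 independent GPs over 16 inducing points in a 2-d input space.
model = BatchedSVGP(torch.randn(3, 16, 2))
print(model.batch_shape)  # torch.Size([3]) -- delegated to the variational strategy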
30 changes: 22 additions & 8 deletions gpytorch/models/exact_gp.py
@@ -50,6 +50,11 @@ class ExactGP(GP):
>>> # test_x = ...;
>>> model(test_x) # Returns the GP latent function at test_x
>>> likelihood(model(test_x)) # Returns the (approximate) predictive posterior distribution at test_x
:ivar torch.Size batch_shape: The batch shape of the model. This is a batch shape from an I/O perspective,
independent of the internal representation of the model. For a model with `(m)` outputs, a
`test_batch_shape x q x d`-shaped input to the model in eval mode returns a
distribution of shape `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.
"""

def __init__(self, train_inputs, train_targets, likelihood):
@@ -71,6 +76,17 @@ def __init__(self, train_inputs, train_targets, likelihood):

self.prediction_strategy = None

@property
def batch_shape(self) -> torch.Size:
r"""The batch shape of the model.
This is a batch shape from an I/O perspective, independent of the internal
representation of the model. For a model with `(m)` outputs, a
`test_batch_shape x q x d`-shaped input to the model in eval mode returns a
distribution of shape `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.
"""
return self.train_inputs[0].shape[:-2]

@property
def train_targets(self):
return self._train_targets
@@ -160,8 +176,6 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
"all test independent caches exist. Call the model on some data first!"
)

- model_batch_shape = self.train_inputs[0].shape[:-2]

if not isinstance(inputs, list):
inputs = [inputs]

@@ -184,17 +198,17 @@ def get_fantasy_model(self, inputs, targets, **kwargs):

# Check whether we can properly broadcast batch dimensions
try:
- torch.broadcast_shapes(model_batch_shape, target_batch_shape)
+ torch.broadcast_shapes(self.batch_shape, target_batch_shape)
except RuntimeError:
raise RuntimeError(
- f"Model batch shape ({model_batch_shape}) and target batch shape "
+ f"Model batch shape ({self.batch_shape}) and target batch shape "
f"({target_batch_shape}) are not broadcastable."
)

- if len(model_batch_shape) > len(input_batch_shape):
- input_batch_shape = model_batch_shape
- if len(model_batch_shape) > len(target_batch_shape):
- target_batch_shape = model_batch_shape
+ if len(self.batch_shape) > len(input_batch_shape):
+ input_batch_shape = self.batch_shape
+ if len(self.batch_shape) > len(target_batch_shape):
+ target_batch_shape = self.batch_shape

# If input has no fantasy batch dimension but target does, we can save memory and computation by not
# computing the covariance for each element of the batch. Therefore we don't expand the inputs to the
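For illustration only (not part of this commit): a sketch of the I/O semantics described in the docstring for a batched ExactGP. The model class and all shapes are assumptions; the expected output shapes follow the broadcast rule stated above.

import torch
import gpytorch

class BatchedExactGP(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        batch_shape = train_x.shape[:-2]
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(batch_shape=batch_shape), batch_shape=batch_shape)

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

train_x = torch.randn(5, 20, 3)  # model batch_shape = (5,), n = 20, d = 3
train_y = torch.randn(5, 20)
likelihood = gpytorch.likelihoods.GaussianLikelihood(batch_shape=torch.Size([5]))
model = BatchedExactGP(train_x, train_y, likelihood)
print(model.batch_shape)  # torch.Size([5]) -- i.e. train_inputs[0].shape[:-2]

# A `test_batch_shape x q x d` input broadcasts against model.batch_shape in eval mode.
model.eval()
with torch.no_grad():
    posterior = model(torch.randn(7, 1, 10, 3))  # test_batch_shape = (7, 1), q = 10
print(posterior.batch_shape, posterior.event_shape)  # expected: torch.Size([7, 5]) torch.Size([10])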
14 changes: 13 additions & 1 deletion gpytorch/models/gp.py
@@ -1,7 +1,19 @@
#!/usr/bin/env python3

import torch

from ..module import Module


class GP(Module):
pass
@property
def batch_shape(self) -> torch.Size:
r"""The batch shape of the model.
This is a batch shape from an I/O perspective, independent of the internal
representation of the model. For a model with `(m)` outputs, a
`test_batch_shape x q x d`-shaped input to the model in eval mode returns a
distribution of shape `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.
"""
cls_name = self.__class__.__name__
raise NotImplementedError(f"{cls_name} does not define batch_shape property")
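For illustration only (not part of this commit): the base class now fails loudly rather than guessing, so a model that does not override the property raises immediately.

import gpytorch

# ExactGP and ApproximateGP override batch_shape; the bare base class does not.
try:
    gpytorch.models.GP().batch_shape
except NotImplementedError as err:
    print(err)  # GP does not define batch_shape property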
18 changes: 18 additions & 0 deletions gpytorch/models/model_list.py
@@ -31,6 +31,24 @@ def __init__(self, *models):
)
self.likelihood = LikelihoodList(*[m.likelihood for m in models])

@property
def batch_shape(self) -> torch.Size:
r"""The batch shape of the model.
This is a batch shape from an I/O perspective, independent of the internal
representation of the model. For a model with `(m)` outputs, a
`test_batch_shape x q x d`-shaped input to the model in eval mode returns a
distribution of shape `broadcast(test_batch_shape, model.batch_shape) x q x (m)`.
"""
batch_shape = self.models[0].batch_shape
if all(batch_shape == m.batch_shape for m in self.models[1:]):
return batch_shape
# TODO: Allow broadcasting of model batch shapes
raise NotImplementedError(
f"`{self.__class__.__name__}.batch_shape` is only supported if all "
"constituent models have the same `batch_shape`."
)

def forward_i(self, i, *args, **kwargs):
return self.models[i].forward(*args, **kwargs)

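For illustration only (not part of this commit): a sketch of the model-list behavior, assuming a minimal batched ExactGP submodel defined just for this example.

import torch
import gpytorch

class TinyGP(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y):
        batch_shape = train_x.shape[:-2]
        super().__init__(train_x, train_y, gpytorch.likelihoods.GaussianLikelihood(batch_shape=batch_shape))
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)
        self.covar_module = gpytorch.kernels.RBFKernel(batch_shape=batch_shape)

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

m1 = TinyGP(torch.randn(4, 10, 2), torch.randn(4, 10))
m2 = TinyGP(torch.randn(4, 15, 2), torch.randn(4, 15))
print(gpytorch.models.IndependentModelList(m1, m2).batch_shape)  # torch.Size([4])

m3 = TinyGP(torch.randn(2, 10, 2), torch.randn(2, 10))
try:
    gpytorch.models.IndependentModelList(m1, m3).batch_shape  # shapes (4,) vs (2,) do not match
except NotImplementedError as err:
    print(err)  # mismatched shapes are rejected rather than broadcast (see the TODO above)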
3 changes: 3 additions & 0 deletions gpytorch/test/model_test_case.py
@@ -32,6 +32,7 @@ def test_forward_train(self):
data = self.create_test_data()
likelihood, labels = self.create_likelihood_and_labels()
model = self.create_model(data, labels, likelihood)
self.assertEqual(model.batch_shape, data.shape[:-2]) # test batch_shape property
model.train()
output = model(data)
self.assertTrue(output.lazy_covariance_matrix.dim() == 2)
@@ -42,6 +43,7 @@ def test_batch_forward_train(self):
batch_data = self.create_batch_test_data()
likelihood, labels = self.create_batch_likelihood_and_labels()
model = self.create_model(batch_data, labels, likelihood)
self.assertEqual(model.batch_shape, batch_data.shape[:-2]) # test batch_shape property
model.train()
output = model(batch_data)
self.assertTrue(output.lazy_covariance_matrix.dim() == 3)
@@ -52,6 +54,7 @@ def test_multi_batch_forward_train(self):
batch_data = self.create_batch_test_data(batch_shape=torch.Size([2, 3]))
likelihood, labels = self.create_batch_likelihood_and_labels(batch_shape=torch.Size([2, 3]))
model = self.create_model(batch_data, labels, likelihood)
self.assertEqual(model.batch_shape, batch_data.shape[:-2]) # test batch_shape property
model.train()
output = model(batch_data)
self.assertTrue(output.lazy_covariance_matrix.dim() == 4)
7 changes: 6 additions & 1 deletion gpytorch/variational/_variational_strategy.py
@@ -90,11 +90,16 @@ def _expand_inputs(self, x: Tensor, inducing_points: Tensor) -> Tuple[Tensor, Te
"""
Pre-processing step in __call__ to make x the same batch_shape as the inducing points
"""
- batch_shape = torch.broadcast_shapes(inducing_points.shape[:-2], x.shape[:-2])
+ batch_shape = torch.broadcast_shapes(self.batch_shape, x.shape[:-2])
inducing_points = inducing_points.expand(*batch_shape, *inducing_points.shape[-2:])
x = x.expand(*batch_shape, *x.shape[-2:])
return x, inducing_points

@property
def batch_shape(self) -> torch.Size:
r"""The batch shape of the variational strategy."""
return self.inducing_points.shape[:-2]

@property
def jitter_val(self) -> float:
if self._jitter_val is None:
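For illustration only (not part of this commit): a standalone sketch of the broadcasting that `_expand_inputs` performs against the strategy's batch shape (the strategy itself is replaced by plain tensors here).

import torch

inducing_points = torch.randn(3, 16, 2)  # strategy batch_shape = inducing_points.shape[:-2] = (3,)
x = torch.randn(7, 1, 10, 2)             # inputs carrying extra leading batch dimensions

batch_shape = torch.broadcast_shapes(inducing_points.shape[:-2], x.shape[:-2])
inducing_points = inducing_points.expand(*batch_shape, *inducing_points.shape[-2:])
x = x.expand(*batch_shape, *x.shape[-2:])
print(batch_shape, inducing_points.shape, x.shape)
# torch.Size([7, 3]) torch.Size([7, 3, 16, 2]) torch.Size([7, 3, 10, 2])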
17 changes: 10 additions & 7 deletions gpytorch/variational/lmc_variational_strategy.py
@@ -116,26 +116,24 @@ def __init__(
Module.__init__(self)
self.base_variational_strategy = base_variational_strategy
self.num_tasks = num_tasks
- batch_shape = self.base_variational_strategy._variational_distribution.batch_shape
+ vdist_batch_shape = self.base_variational_strategy._variational_distribution.batch_shape

# Check if no functions
if latent_dim >= 0:
raise RuntimeError(f"latent_dim must be a negative indexed batch dimension: got {latent_dim}.")
- if not (batch_shape[latent_dim] == num_latents or batch_shape[latent_dim] == 1):
+ if not (vdist_batch_shape[latent_dim] == num_latents or vdist_batch_shape[latent_dim] == 1):
raise RuntimeError(
- f"Mismatch in num_latents: got a variational distribution of batch shape {batch_shape}, "
+ f"Mismatch in num_latents: got a variational distribution of batch shape {vdist_batch_shape}, "
f"expected the function dim {latent_dim} to be {num_latents}."
)
self.num_latents = num_latents
self.latent_dim = latent_dim

# Make the batch_shape
- self.batch_shape = list(batch_shape)
- del self.batch_shape[self.latent_dim]
- self.batch_shape = torch.Size(self.batch_shape)
+ self._batch_shape = vdist_batch_shape[: self.latent_dim] + vdist_batch_shape[self.latent_dim :][1:]

# LMC coefficients
- lmc_coefficients = torch.randn(*batch_shape, self.num_tasks)
+ lmc_coefficients = torch.randn(*vdist_batch_shape, self.num_tasks)
self.register_parameter("lmc_coefficients", torch.nn.Parameter(lmc_coefficients))

if jitter_val is None:
Expand All @@ -145,6 +143,11 @@ def __init__(
else:
self.jitter_val = jitter_val

@property
def batch_shape(self) -> torch.Size:
r"""The batch shape of the variational strategy."""
return self._batch_shape

@property
def prior_distribution(self) -> MultivariateNormal:
return self.base_variational_strategy.prior_distribution
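For illustration only (not part of this commit): how the latent dimension is dropped from the variational distribution's batch shape to give the LMC model's I/O batch shape (the shapes are assumptions).

import torch

vdist_batch_shape = torch.Size([2, 4])  # 2 model batches x 4 latent functions
latent_dim = -1                         # the latent functions live in the last batch dimension

# The latent dimension is mixed away by the LMC coefficients, so it is removed here.
model_batch_shape = torch.Size(vdist_batch_shape[:latent_dim] + vdist_batch_shape[latent_dim:][1:])
print(model_batch_shape)  # torch.Size([2])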
4 changes: 4 additions & 0 deletions test/models/test_exact_gp.py
@@ -106,6 +106,10 @@ def test_batch_forward_then_nonbatch_forward_eval(self):
batch_data = self.create_batch_test_data()
likelihood, labels = self.create_batch_likelihood_and_labels()
model = self.create_model(batch_data, labels, likelihood)

# test batch_shape property
self.assertEqual(model.batch_shape, batch_data.shape[:-2])

model.eval()
output = model(batch_data)

5 changes: 3 additions & 2 deletions test/models/test_variational_gp.py
@@ -12,8 +12,9 @@

class GPClassificationModel(ApproximateGP):
def __init__(self, train_x, use_inducing=False):
- variational_distribution = CholeskyVariationalDistribution(train_x.size(-2), batch_shape=train_x.shape[:-2])
- inducing_points = torch.randn(50, train_x.size(-1)) if use_inducing else train_x
+ batch_shape = train_x.shape[:-2]
+ variational_distribution = CholeskyVariationalDistribution(train_x.size(-2), batch_shape=batch_shape)
+ inducing_points = torch.randn(*batch_shape, 50, train_x.size(-1)) if use_inducing else train_x
strategy_cls = VariationalStrategy
variational_strategy = strategy_cls(
self, inducing_points, variational_distribution, learn_inducing_locations=use_inducing
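For illustration only (not part of this commit): with the imports already at the top of this test file in scope, the updated test model can be exercised as follows (the shapes are assumptions).

train_x = torch.randn(2, 50, 1)  # batch_shape = (2,)
model = GPClassificationModel(train_x, use_inducing=True)
print(model.batch_shape)  # torch.Size([2]) -- taken from the now-batched inducing points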
