Skip to content

Commit

Permalink
Handle hierarchical smoothness for Matern
Browse files Browse the repository at this point in the history
  • Loading branch information
igoumiri committed Sep 4, 2024
1 parent 6f9669e commit d2e2dd9
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 9 deletions.
2 changes: 1 addition & 1 deletion MuyGPyS/gp/deformation/isotropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class Isotropy(DeformationFn):
def __init__(
self,
metric: MetricFn,
length_scale: ScalarParam,
length_scale: Union[ScalarParam, HierarchicalParam],
):
# This is brittle and should be refactored
if isinstance(length_scale, ScalarParam):
Expand Down
13 changes: 12 additions & 1 deletion MuyGPyS/gp/hyperparameter/experimental/hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def filter_kwargs(self, **kwargs) -> Tuple[Dict, Dict]:
lower[self._name] = self(kwargs["batch_features"], **params)
return lower, kwargs

def apply_fn(self, fn: Callable, name: str) -> Callable:
def apply_fn(self, fn: Callable) -> Callable:
def applied_fn(*args, **kwargs):
lower, kwargs = self.filter_kwargs(**kwargs)
return fn(*args, **lower, **kwargs)
Expand Down Expand Up @@ -157,6 +157,17 @@ def append_lists(
def populate(self, hyperparameters: Dict) -> None:
self._params.populate(hyperparameters)

def fixed(self) -> bool:
"""
Report whether the parameter is fixed, and is to be ignored during
optimization.
Returns:
`True` if fixed, `False` otherwise.
"""
# return self._params.fixed()
return False


class NamedHierarchicalVectorParameter(NamedVectorParam):
def __init__(self, name: str, param: VectorParam):
Expand Down
2 changes: 1 addition & 1 deletion MuyGPyS/gp/hyperparameter/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def fixed(self) -> bool:
Returns:
`True` if fixed, `False` otherwise.
"""
return mm.all(param._fixed for param in self._params)
return mm.all([param._fixed for param in self._params])


class NamedVectorParameter(VectorParameter):
Expand Down
24 changes: 18 additions & 6 deletions MuyGPyS/gp/kernels/matern.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
>>> Kcross = kern(crosswise_diffs)
"""

from typing import Callable, List, Tuple
from typing import Callable, List, Tuple, Union

import MuyGPyS._src.math as mm
from MuyGPyS._src.gp.kernels import (
Expand All @@ -55,18 +55,22 @@
l2,
)
from MuyGPyS.gp.hyperparameter import ScalarParam, NamedParam
from MuyGPyS.gp.hyperparameter.experimental import (
HierarchicalParam,
NamedHierarchicalParam,
)
from MuyGPyS.gp.kernels import KernelFn


def _set_matern_fn(
smoothness: ScalarParam,
smoothness: Union[NamedParam, NamedHierarchicalParam],
_backend_05_fn: Callable = _matern_05_fn,
_backend_15_fn: Callable = _matern_15_fn,
_backend_25_fn: Callable = _matern_25_fn,
_backend_inf_fn: Callable = _matern_inf_fn,
_backend_gen_fn: Callable = _matern_gen_fn,
):
if smoothness.fixed() is True:
if smoothness.fixed():
if smoothness() == 0.5:
return _backend_05_fn
elif smoothness() == 1.5:
Expand Down Expand Up @@ -119,17 +123,25 @@ class Matern(KernelFn):

def __init__(
self,
smoothness: ScalarParam = ScalarParam(0.5),
smoothness: Union[ScalarParam, HierarchicalParam] = ScalarParam(0.5),
deformation: DeformationFn = Isotropy(
l2, length_scale=ScalarParam(1.0)
),
_backend_ones: Callable = mm.ones,
_backend_zeros: Callable = mm.zeros,
_backend_squeeze: Callable = mm.squeeze,
**_backend_fns
**_backend_fns,
):
super().__init__(deformation=deformation)
self.smoothness = NamedParam("smoothness", smoothness)
if isinstance(smoothness, ScalarParam):
self.smoothness = NamedParam("smoothness", smoothness)
elif isinstance(smoothness, HierarchicalParam):
self.smoothness = NamedHierarchicalParam("smoothness", smoothness)
else:
raise ValueError(
"Expected ScalarParam type for smoothness, not "
f"{type(smoothness)}"
)
self._backend_ones = _backend_ones
self._backend_zeros = _backend_zeros
self._backend_squeeze = _backend_squeeze
Expand Down
92 changes: 92 additions & 0 deletions tests/experimental/nonstationary.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,98 @@ def test_hierarchical_nonstationary_rbf(
shape=(batch_count, nn_count, nn_count),
)

@parameterized.parameters(
(
(
feature_count,
type(high_level_kernel).__name__,
smoothness,
deformation,
)
for feature_count in [2, 17]
for knot_count in [10]
for knot_features in [
sample_knots(feature_count=feature_count, knot_count=knot_count)
]
for knot_values in [
VectorParameter(*[Parameter(i) for i in range(knot_count)]),
]
for high_level_kernel in [RBF(), Matern()]
for smoothness, deformation in [
(
Parameter(1.5),
Isotropy(
l2,
length_scale=Parameter(1),
),
),
(
HierarchicalParameter(
knot_features, knot_values, high_level_kernel
),
Isotropy(
l2,
length_scale=Parameter(1),
),
),
(
Parameter(1.5),
Isotropy(
l2,
length_scale=HierarchicalParameter(
knot_features, knot_values, high_level_kernel
),
),
),
]
)
)
def test_hierarchical_nonstationary_matern(
self,
feature_count,
high_level_kernel,
smoothness,
deformation,
):
muygps = MuyGPS(
kernel=Matern(smoothness=smoothness, deformation=deformation),
)

# prepare data
data_count = 1000
data = _make_gaussian_dict(
data_count=data_count,
feature_count=feature_count,
response_count=1,
)

# neighbors and differences
nn_count = 30
nbrs_lookup = NN_Wrapper(
data["input"], nn_count, nn_method="exact", algorithm="ball_tree"
)
batch_count = 200
batch_indices, batch_nn_indices = sample_batch(
nbrs_lookup, batch_count, data_count
)
(_, pairwise_diffs, _, _) = muygps.make_train_tensors(
batch_indices,
batch_nn_indices,
data["input"],
data["output"],
)

batch_features = batch_features_tensor(data["input"], batch_indices)

Kin = muygps.kernel(pairwise_diffs, batch_features=batch_features)

_check_ndarray(
self.assertEqual,
Kin,
mm.ftype,
shape=(batch_count, nn_count, nn_count),
)


if __name__ == "__main__":
absltest.main()

0 comments on commit d2e2dd9

Please sign in to comment.