Skip to content

Commit

Permalink
Fix hierarchical tests
Browse files Browse the repository at this point in the history
  • Loading branch information
igoumiri committed May 15, 2024
1 parent d1e82ce commit 3d61023
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 44 deletions.
1 change: 1 addition & 0 deletions .github/workflows/develop-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ jobs:
python tests/precompute/fast_posterior_mean.py
python tests/scale_opt.py
python tests/experimental/shear.py
python tests/experimental/hierarchical.py
- name: Optimize Tests
if: matrix.test-group == 'optimize'
run: python tests/optimize.py
Expand Down
27 changes: 24 additions & 3 deletions MuyGPyS/gp/deformation/anisotropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
from MuyGPyS._src.util import auto_str
from MuyGPyS.gp.deformation.deformation_fn import DeformationFn
from MuyGPyS.gp.deformation.metric import MetricFn
from MuyGPyS.gp.hyperparameter import VectorParam, NamedVectorParam
from MuyGPyS.gp.hyperparameter import ScalarParam, VectorParam, NamedVectorParam
from MuyGPyS.gp.hyperparameter.experimental import (
HierarchicalParam,
NamedHierarchicalVectorParam,
)


@auto_str
Expand Down Expand Up @@ -37,8 +41,18 @@ def __init__(
metric: MetricFn,
length_scale: VectorParam,
):
name = "length_scale"
params = length_scale._params
# This is brittle and should be refactored
if all(isinstance(p, ScalarParam) for p in params):
self.length_scale = NamedVectorParam(name, length_scale)
elif all(isinstance(p, HierarchicalParam) for p in params):
self.length_scale = NamedHierarchicalVectorParam(name, length_scale)
else:
raise ValueError(
"Expected uniform vector of ScalarParam or HierarchicalParam type for length_scale"
)
self.metric = metric
self.length_scale = NamedVectorParam("length_scale", length_scale)

def __call__(self, dists: mm.ndarray, **length_scales) -> mm.ndarray:
"""
Expand Down Expand Up @@ -70,7 +84,14 @@ def __call__(self, dists: mm.ndarray, **length_scales) -> mm.ndarray:
f"Difference tensor of shape {dists.shape} must have final "
f"dimension size of {len(self.length_scale)}"
)
return self.metric(dists / self.length_scale(**length_scales))
length_scale = self.length_scale(**length_scales)
# This is brittle and similar to what we do in Isotropy.
if isinstance(length_scale, mm.ndarray) and len(length_scale.shape) > 0:
shape = [None] * dists.ndim
shape[0] = slice(None)
shape[-1] = slice(None)
length_scale = length_scale.T[tuple(shape)]
return self.metric(dists / length_scale)

@mpi_chunk(return_count=1)
def pairwise_tensor(
Expand Down
1 change: 1 addition & 0 deletions MuyGPyS/gp/hyperparameter/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@
HierarchicalParameter,
HierarchicalParameter as HierarchicalParam,
NamedHierarchicalParameter as NamedHierarchicalParam,
NamedHierarchicalVectorParameter as NamedHierarchicalVectorParam,
sample_knots,
)
30 changes: 27 additions & 3 deletions MuyGPyS/gp/hyperparameter/experimental/hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,9 @@ class HierarchicalParameter:
knot_features:
Tensor of floats of shape `(knot_count, feature_count)`
containing the feature vectors for each knot.
knot_values:
knot_params:
List of scalar hyperparameters of length `knot_count`
containing the initial values and optimization bounds for each knot.
Float values will be converted to fixed scalar hyperparameters.
kernel:
Initialized higher-level GP kernel.
"""
Expand Down Expand Up @@ -108,7 +107,9 @@ def name(self) -> str:
def knot_values(self) -> mm.ndarray:
return self._params()

def __call__(self, batch_features, **kwargs) -> float:
def __call__(self, batch_features=None, **kwargs) -> float:
if batch_features is None:
raise TypeError("batch_features keyword argument is required")
params, kwargs = self._params.filter_kwargs(**kwargs)
solve = mm.linalg.solve(
self._Kin_higher + self._noise() * mm.eye(self._knot_count),
Expand Down Expand Up @@ -159,6 +160,29 @@ def populate(self, hyperparameters: Dict) -> None:
self._params.populate(hyperparameters)


class NamedHierarchicalVectorParameter(NamedVectorParam):
def __init__(self, name: str, param: VectorParam):
self._params = [
NamedHierarchicalParameter(name + str(i), p)
for i, p in enumerate(param._params)
]
self._name = name

def filter_kwargs(self, **kwargs) -> Tuple[Dict, Dict]:
params = {
key: kwargs[key] for key in kwargs if key.startswith(self._name)
}
kwargs = {
key: kwargs[key] for key in kwargs if not key.startswith(self._name)
}
if "batch_features" in kwargs:
for p in self._params:
params.setdefault(
p.name(), p(kwargs["batch_features"], **params)
)
return params, kwargs


def sample_knots(feature_count: int, knot_count: int) -> mm.ndarray:
"""
Samples knots from feature matrix.
Expand Down
28 changes: 18 additions & 10 deletions experimental/nonstationary_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"For simplicity, we start with an isotropic distortion so we only need to use a single `HierarchicalNonstationaryHyperparameter`.\n",
"For simplicity, we start with an isotropic distortion so we only need to use a single `HierarchicalParameter`.\n",
"Let's also build a GP with a fixed length scale for comparison."
]
},
Expand Down Expand Up @@ -234,7 +234,7 @@
"metadata": {},
"source": [
"We can visualize the knots and the resulting `length_scale` surface over the domain of the function.\n",
"Unlike `ScalarHyperparameter`, `HierarchicalNonstationaryHyperparameter` takes an array of feature vectors for each point where you would like to evaluate the local value of the hyperparameter."
"Unlike `Parameter`, `HierarchicalParameter` takes an array of feature vectors for each point where you would like to evaluate the local value of the hyperparameter."
]
},
{
Expand Down Expand Up @@ -438,7 +438,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The knot values of hierarchical nonstationary hyperparameters can be optimized using like any other hyperparameters, using the `optimize_from_tensors` utility. But first, we need to initialize them as `ScalarHyperparameter`s with bounds rather than as fixed values."
"The knot values of hierarchical nonstationary hyperparameters can be optimized like any other hyperparameters, using the `optimize_from_tensors` utility. But first, we need to initialize them as `Parameter`s with bounds rather than as fixed values."
]
},
{
Expand Down Expand Up @@ -547,7 +547,7 @@
"outputs": [],
"source": [
"from MuyGPyS.optimize import Bayes_optimize\n",
"from MuyGPyS.optimize.loss import lool_fn, mse_fn"
"from MuyGPyS.optimize.loss import mse_fn"
]
},
{
Expand Down Expand Up @@ -740,9 +740,9 @@
"for axi in axes:\n",
" for ax in axi:\n",
" ax.set_ylim([-1, 1.5])\n",
"# ax.plot(xs, ys, label=\"truth\")\n",
" for knot in knot_features:\n",
" ax.axvline(x=knot)\n",
" ax.axvline(x=knot, lw=0.5, c='gray')\n",
"\n",
"axes[0, 0].set_title(\"flat fixed\")\n",
"axes[0, 0].plot(test_features, mean_flat_fixed, \"-\", label=\"flat fixed\")\n",
"axes[0, 0].fill_between(\n",
Expand All @@ -752,6 +752,7 @@
" facecolor=\"C1\",\n",
" alpha=0.2,\n",
")\n",
"\n",
"axes[0, 1].set_title(\"hierarchical fixed\")\n",
"axes[0, 1].plot(test_features, mean_hierarchical_fixed, \"-\", label=\"hierarchical fixed\")\n",
"axes[0, 1].fill_between(\n",
Expand All @@ -761,6 +762,7 @@
" facecolor=\"C2\",\n",
" alpha=0.2,\n",
")\n",
"\n",
"axes[1, 0].set_title(\"flat optimized\")\n",
"axes[1, 0].plot(test_features, mean_flat_opt, \"-\", label=\"flat optimized\")\n",
"axes[1, 0].fill_between(\n",
Expand All @@ -770,6 +772,7 @@
" facecolor=\"C3\",\n",
" alpha=0.2,\n",
")\n",
"\n",
"axes[1, 1].set_title(\"hierarchical optimized\")\n",
"axes[1, 1].plot(test_features, mean_hierarchical_opt, \"-\", label=\"hierarchical optimized\")\n",
"axes[1, 1].fill_between(\n",
Expand All @@ -779,11 +782,16 @@
" facecolor=\"C3\",\n",
" alpha=0.2,\n",
")\n",
"for knot in knot_features:\n",
" ax.axvline(x=knot)\n",
"plt.legend()\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -802,7 +810,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.9.6"
}
},
"nbformat": 4,
Expand Down
62 changes: 34 additions & 28 deletions tests/experimental/nonstationary.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,17 @@
from MuyGPyS.gp import MuyGPS
from MuyGPyS.gp.kernels import Matern, RBF
from MuyGPyS.gp.deformation import l2, Isotropy, Anisotropy
from MuyGPyS.gp.hyperparameter import ScalarParam
from MuyGPyS.gp.hyperparameter import (
Parameter,
VectorParameter,
)
from MuyGPyS.gp.hyperparameter.experimental import (
HierarchicalNonstationaryHyperparameter,
HierarchicalParameter,
NamedHierarchicalParam,
NamedHierarchicalVectorParam,
sample_knots,
)
from MuyGPyS.gp.tensors import (
make_train_tensors,
batch_features_tensor,
)
from MuyGPyS.gp.tensors import batch_features_tensor
from MuyGPyS.neighbors import NN_Wrapper
from MuyGPyS.optimize.batch import sample_batch

Expand Down Expand Up @@ -54,24 +56,28 @@ def test_hierarchical_nonstationary_hyperparameter(
response_count=1,
)
knot_features = train["input"]
knot_values = train["output"]
knot_values = VectorParameter(
*[Parameter(x) for x in np.squeeze(train["output"])]
)
batch_features = test["input"]
hyp = HierarchicalNonstationaryHyperparameter(
knot_features,
knot_values,
kernel,
hyp = NamedHierarchicalParam(
"custom_param_name",
HierarchicalParameter(
knot_features,
knot_values,
kernel,
),
)
hyperparameters = hyp(batch_features)
_check_ndarray(
self.assertEqual, hyperparameters, mm.ftype, shape=(batch_count, 1)
self.assertEqual, hyperparameters, mm.ftype, shape=(batch_count,)
)

@parameterized.parameters(
(
(
feature_count,
type(knot_values[0]),
high_level_kernel,
type(high_level_kernel).__name__,
deformation,
)
for feature_count in [2, 17]
Expand All @@ -80,35 +86,35 @@ def test_hierarchical_nonstationary_hyperparameter(
sample_knots(feature_count=feature_count, knot_count=knot_count)
]
for knot_values in [
np.random.uniform(size=knot_count),
[ScalarParam(i) for i in range(knot_count)],
VectorParameter(*[Parameter(i) for i in range(knot_count)]),
]
for high_level_kernel in [RBF(), Matern()]
for deformation in [
Isotropy(
l2,
length_scale=HierarchicalNonstationaryHyperparameter(
length_scale=HierarchicalParameter(
knot_features, knot_values, high_level_kernel
),
),
Anisotropy(
l2,
**{
f"length_scale{i}": HierarchicalNonstationaryHyperparameter(
knot_features,
knot_values,
high_level_kernel,
)
for i in range(feature_count)
},
VectorParameter(
*[
HierarchicalParameter(
knot_features,
knot_values,
high_level_kernel,
)
for _ in range(feature_count)
]
),
),
]
)
)
def test_hierarchical_nonstationary_rbf(
self,
feature_count,
knot_values_type,
high_level_kernel,
deformation,
):
Expand All @@ -133,7 +139,7 @@ def test_hierarchical_nonstationary_rbf(
batch_indices, batch_nn_indices = sample_batch(
nbrs_lookup, batch_count, data_count
)
(_, pairwise_diffs, _, _) = make_train_tensors(
(_, pairwise_diffs, _, _) = muygps.make_train_tensors(
batch_indices,
batch_nn_indices,
data["input"],
Expand All @@ -142,7 +148,7 @@ def test_hierarchical_nonstationary_rbf(

batch_features = batch_features_tensor(data["input"], batch_indices)

Kin = muygps.kernel(pairwise_diffs, batch_features)
Kin = muygps.kernel(pairwise_diffs, batch_features=batch_features)

_check_ndarray(
self.assertEqual,
Expand Down

0 comments on commit 3d61023

Please sign in to comment.