Skip to content

Commit

Permalink
reattaching nonstationary hyperparameters to optimization chassis (#229)
Browse files Browse the repository at this point in the history
* removed overdense kriging weight visualization from the neighborhood visualization tutorial.

* added NamedParameter backend interface so that optimization parameters know their names. 

* added VectorParameter and NamedVectorParameter backend interface for the anisotropic deformation.

* simplified optimization hyperparameter handling.

* removed embed_fn abstraction from DeformationFn and placed it inside of the hyperparameter classes

* reimplemented hierarchical hyperparameter and hooked it back into the optimizer

* updated tests with new Anisotropy api

* fix for torch backend
  • Loading branch information
bwpriest authored May 3, 2024
1 parent 222b851 commit d1e82ce
Show file tree
Hide file tree
Showing 29 changed files with 748 additions and 695 deletions.
20 changes: 10 additions & 10 deletions MuyGPyS/_test/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def plot_sample(self):
plt.show()

def plot_kriging_weights(self, idx, nbrs_lookup):
fig, axes = plt.subplots(1, 3, figsize=(19, 4))
fig, axes = plt.subplots(1, 2, figsize=(10, 4))

singleton = [i for i, n in enumerate(self._test_mask) if n][idx]
x_coord = singleton % self.points_per_dim
Expand Down Expand Up @@ -343,23 +343,23 @@ def plot_kriging_weights(self, idx, nbrs_lookup):

vnorm = colors.LogNorm(vmin=vmin, vmax=vmax)

self._label_ax(axes[0], "Kriging Weights (all)")
im0 = axes[0].imshow(self._make_im(kriging_weights_all), norm=vnorm)
axes[0].plot(x_coord, y_coord, "r+")
# self._label_ax(axes[0], "Kriging Weights (all)")
# im0 = axes[0].imshow(self._make_im(kriging_weights_all), norm=vnorm)
# axes[0].plot(x_coord, y_coord, "r+")

self._label_ax(axes[1], "Kriging Weights (train)")
axes[1].imshow(
self._label_ax(axes[0], "Kriging Weights (train)")
im0 = axes[0].imshow(
self._make_im(kriging_weights_train, mask=self._train_mask),
norm=vnorm,
)
axes[1].plot(x_coord, y_coord, "r+")
axes[0].plot(x_coord, y_coord, "r+")

self._label_ax(axes[2], "Kriging Weights (nearest)")
axes[2].imshow(
self._label_ax(axes[1], "Kriging Weights (nearest)")
axes[1].imshow(
self._make_im(kriging_weights_nbrs, mask=nn_mask),
norm=vnorm,
)
axes[2].plot(x_coord, y_coord, "r+")
axes[1].plot(x_coord, y_coord, "r+")
fig.colorbar(im0, ax=axes.ravel().tolist())

plt.show()
Expand Down
116 changes: 6 additions & 110 deletions MuyGPyS/gp/deformation/anisotropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
#
# SPDX-License-Identifier: MIT

from typing import List, Tuple, Callable, Dict

import MuyGPyS._src.math as mm
from MuyGPyS._src.mpi_utils import mpi_chunk
from MuyGPyS._src.util import auto_str
from MuyGPyS.gp.deformation.deformation_fn import DeformationFn
from MuyGPyS.gp.deformation.metric import MetricFn
from MuyGPyS.gp.hyperparameter import ScalarParam
from MuyGPyS.gp.hyperparameter import VectorParam, NamedVectorParam


@auto_str
Expand All @@ -37,27 +35,10 @@ class Anisotropy(DeformationFn):
def __init__(
self,
metric: MetricFn,
**length_scales,
length_scale: VectorParam,
):
self.metric = metric
for i, key in enumerate(length_scales.keys()):
if key != "length_scale" + str(i):
raise ValueError(
"Anisotropic model expects one keyword argument for each "
"feature in the dataset labeled length_scale{i} for the "
"ith feature with indexing beginning at zero."
)
if not (
all(
isinstance(param, ScalarParam)
for param in length_scales.values()
)
):
raise ValueError(
"Anisotropic model expects all values for the length_scale{i} "
"keyword arguments to be of type ScalarParam."
)
self.length_scale = length_scales
self.length_scale = NamedVectorParam("length_scale", length_scale)

def __call__(self, dists: mm.ndarray, **length_scales) -> mm.ndarray:
"""
Expand All @@ -84,97 +65,12 @@ def __call__(self, dists: mm.ndarray, **length_scales) -> mm.ndarray:
`(data_count, nn_count, nn_count)` whose last two dimensions are
pairwise distance matrices.
"""
length_scale_array = self._length_scale_array(
dists.shape, **length_scales
)
return self.metric(dists / length_scale_array)

def _length_scale_array(
self, shape: mm.ndarray, **length_scales
) -> mm.ndarray:
if shape[-1] != len(self.length_scale):
if dists.shape[-1] != len(self.length_scale):
raise ValueError(
f"Difference tensor of shape {shape} must have final "
f"Difference tensor of shape {dists.shape} must have final "
f"dimension size of {len(self.length_scale)}"
)
return mm.array(
[
(
length_scales[key]
if key in length_scales.keys()
else self.length_scale[key]()
)
for key in self.length_scale
]
)

def get_opt_params(
self,
) -> Tuple[List[str], List[float], List[Tuple[float, float]]]:
"""
Report lists of unfixed hyperparameter names, values, and bounds.
Returns
-------
names:
A list of unfixed hyperparameter names.
params:
A list of unfixed hyperparameter values.
bounds:
A list of unfixed hyperparameter bound tuples.
"""
names: List[str] = []
params: List[float] = []
bounds: List[Tuple[float, float]] = []
for name, param in self.length_scale.items():
param.append_lists(name, names, params, bounds)
return names, params, bounds

def populate_length_scale(self, hyperparameters: Dict) -> None:
"""
Populates the hyperparameter dictionary of a KernelFn object with
`self.length_scales` of the Anisotropy object.
Args:
hyperparameters:
A dict containing the hyperparameters of a KernelFn object.
"""
for key, param in self.length_scale.items():
hyperparameters[key] = param

def embed_fn(self, fn: Callable) -> Callable:
"""
Augments a function to automatically apply the deformation to a
difference tensor.
Args:
fn:
A Callable with signature
`(diffs, *args, **kwargs) -> mm.ndarray` taking a difference
tensor `diffs` with shape `(..., feature_count)`.
Returns:
A new Callable that applies the deformation to `diffs`, removing
the last tensor dimension by collapsing the feature-wise differences
into scalar distances. Propagates any `length_scaleN` kwargs to the
deformation fn, making the function drivable by keyword
optimization.
"""

def embedded_fn(diffs, *args, length_scale=None, **kwargs):
length_scales = {
key: kwargs[key]
for key in kwargs
if key.startswith("length_scale")
}
kwargs = {
key: kwargs[key]
for key in kwargs
if not key.startswith("length_scale")
}
return fn(self(diffs, **length_scales), *args, **kwargs)

return embedded_fn
return self.metric(dists / self.length_scale(**length_scales))

@mpi_chunk(return_count=1)
def pairwise_tensor(
Expand Down
34 changes: 1 addition & 33 deletions MuyGPyS/gp/deformation/deformation_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# SPDX-License-Identifier: MIT


from typing import Callable, Dict, List, Tuple
from typing import List, Tuple

import MuyGPyS._src.math as mm
from MuyGPyS.gp.deformation.metric import MetricFn
Expand Down Expand Up @@ -48,38 +48,6 @@ def get_opt_params(
"Cannot call DeformationFn base class functions!"
)

def populate_length_scale(self, hyperparameters: Dict) -> None:
"""
Populates the hyperparameter dictionary of a KernelFn object with any
parameters of the DeformationFn object.
Args:
hyperparameters:
A dict containing the hyperparameters of a KernelFn object.
"""
raise NotImplementedError(
"Cannot call DeformationFn base class functions!"
)

def embed_fn(self, fn: Callable) -> Callable:
"""
Augments a function to automatically apply the deformation to a
difference tensor.
Args:
fn:
A Callable with signature
`(diffs, *args, **kwargs) -> mm.ndarray` taking a difference
tensor `diffs`.
Returns:
A new Callable that applies the deformation to `diffs`, possibly
changing its tensor dimensionality.
"""
raise NotImplementedError(
"Cannot call DeformationFn base class functions!"
)

def pairwise_tensor(
self,
data: mm.ndarray,
Expand Down
85 changes: 24 additions & 61 deletions MuyGPyS/gp/deformation/isotropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@
# SPDX-License-Identifier: MIT


from typing import Callable, Dict, List, Optional, Tuple
from typing import Optional, Union

import MuyGPyS._src.math as mm
from MuyGPyS._src.mpi_utils import mpi_chunk
from MuyGPyS._src.util import auto_str
from MuyGPyS.gp.deformation.deformation_fn import DeformationFn
from MuyGPyS.gp.deformation.metric import MetricFn
from MuyGPyS.gp.hyperparameter import ScalarParam
from MuyGPyS.gp.hyperparameter import ScalarParam, NamedParam
from MuyGPyS.gp.hyperparameter.experimental import (
HierarchicalParam,
NamedHierarchicalParam,
)


@auto_str
Expand Down Expand Up @@ -39,16 +43,25 @@ def __init__(
metric: MetricFn,
length_scale: ScalarParam,
):
if not isinstance(length_scale, ScalarParam):
# This is brittle and should be refactored
if isinstance(length_scale, ScalarParam):
self.length_scale = NamedParam("length_scale", length_scale)
elif isinstance(length_scale, HierarchicalParam):
self.length_scale = NamedHierarchicalParam(
"length_scale", length_scale
)
else:
raise ValueError(
"Expected ScalarParam type for length_scale, not "
f"{type(length_scale)}"
)
self.length_scale = length_scale
self.metric = metric

def __call__(
self, dists: mm.ndarray, length_scale: Optional[float] = None, **kwargs
self,
dists: mm.ndarray,
length_scale: Optional[Union[float, mm.ndarray]] = None,
**kwargs,
) -> mm.ndarray:
"""
Apply isotropic deformation to an elementwise difference tensor.
Expand All @@ -69,64 +82,14 @@ def __call__(
pairwise distance matrices.
"""
if length_scale is None:
length_scale = self.length_scale()
length_scale = self.length_scale(**kwargs)
# This is brittle and I hate it. I'm not sure where to put this logic.
if isinstance(length_scale, mm.ndarray) and len(length_scale.shape) > 0:
shape = [None] * dists.ndim
shape[0] = slice(None)
length_scale = length_scale[tuple(shape)]
return self.metric.apply_length_scale(dists, length_scale)

def get_opt_params(
self,
) -> Tuple[List[str], List[float], List[Tuple[float, float]]]:
"""
Report lists of unfixed hyperparameter names, values, and bounds.
Returns
-------
names:
A list of unfixed hyperparameter names.
params:
A list of unfixed hyperparameter values.
bounds:
A list of unfixed hyperparameter bound tuples.
"""
names: List[str] = []
params: List[float] = []
bounds: List[Tuple[float, float]] = []
self.length_scale.append_lists("length_scale", names, params, bounds)
return names, params, bounds

def populate_length_scale(self, hyperparameters: Dict) -> None:
"""
Populates the hyperparameter dictionary of a KernelFn object with
`self.length_scale` of the Isotropy object.
Args:
hyperparameters:
A dict containing the hyperparameters of a KernelFn object.
"""
hyperparameters["length_scale"] = self.length_scale

def embed_fn(self, fn: Callable) -> Callable:
"""
Augments a function to automatically apply the deformation to a
difference tensor.
Args:
fn:
A Callable with signature
`(diffs, *args, **kwargs) -> mm.ndarray` taking a difference
tensor `diffs` with shape `(..., feature_count)`.
Returns:
A new Callable that applies the deformation to `diffs`, removing
the last tensor dimension by collapsing the feature-wise differences
into scalar distances. Also adds a `length_scale` kwarg, making the
function drivable by keyword optimization.
"""

def embedded_fn(dists, *args, length_scale=None, **kwargs):
return fn(self(dists, length_scale=length_scale), *args, **kwargs)

return embedded_fn

@mpi_chunk(return_count=1)
def pairwise_tensor(
self,
Expand Down
11 changes: 10 additions & 1 deletion MuyGPyS/gp/hyperparameter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@
#
# SPDX-License-Identifier: MIT

from .scalar import Parameter, Parameter as ScalarParam
from .scalar import (
Parameter,
Parameter as ScalarParam,
NamedParameter as NamedParam,
)
from .vector import (
VectorParameter,
VectorParameter as VectorParam,
NamedVectorParameter as NamedVectorParam,
)
from .tensor import TensorParam
from .scale import AnalyticScale, DownSampleScale, FixedScale, ScaleFn
6 changes: 4 additions & 2 deletions MuyGPyS/gp/hyperparameter/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
#
# SPDX-License-Identifier: MIT

from .hierarchical_nonstationary import (
HierarchicalNonstationaryHyperparameter,
from .hierarchical import (
HierarchicalParameter,
HierarchicalParameter as HierarchicalParam,
NamedHierarchicalParameter as NamedHierarchicalParam,
sample_knots,
)
Loading

0 comments on commit d1e82ce

Please sign in to comment.