Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle hierarchical smoothness for Matern #239

Draft
wants to merge 1 commit into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion MuyGPyS/gp/deformation/isotropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class Isotropy(DeformationFn):
def __init__(
self,
metric: MetricFn,
length_scale: ScalarParam,
length_scale: Union[ScalarParam, HierarchicalParam],
):
# This is brittle and should be refactored
if isinstance(length_scale, ScalarParam):
Expand Down
13 changes: 12 additions & 1 deletion MuyGPyS/gp/hyperparameter/experimental/hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def filter_kwargs(self, **kwargs) -> Tuple[Dict, Dict]:
lower[self._name] = self(kwargs["batch_features"], **params)
return lower, kwargs

def apply_fn(self, fn: Callable, name: str) -> Callable:
def apply_fn(self, fn: Callable) -> Callable:
def applied_fn(*args, **kwargs):
lower, kwargs = self.filter_kwargs(**kwargs)
return fn(*args, **lower, **kwargs)
Expand Down Expand Up @@ -157,6 +157,17 @@ def append_lists(
def populate(self, hyperparameters: Dict) -> None:
self._params.populate(hyperparameters)

def fixed(self) -> bool:
"""
Report whether the parameter is fixed, and is to be ignored during
optimization.

Returns:
`True` if fixed, `False` otherwise.
"""
# return self._params.fixed()
return False


class NamedHierarchicalVectorParameter(NamedVectorParam):
def __init__(self, name: str, param: VectorParam):
Expand Down
2 changes: 1 addition & 1 deletion MuyGPyS/gp/hyperparameter/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def fixed(self) -> bool:
Returns:
`True` if fixed, `False` otherwise.
"""
return mm.all(param._fixed for param in self._params)
return mm.all([param._fixed for param in self._params])


class NamedVectorParameter(VectorParameter):
Expand Down
24 changes: 18 additions & 6 deletions MuyGPyS/gp/kernels/matern.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
>>> Kcross = kern(crosswise_diffs)
"""

from typing import Callable, List, Tuple
from typing import Callable, List, Tuple, Union

import MuyGPyS._src.math as mm
from MuyGPyS._src.gp.kernels import (
Expand All @@ -55,18 +55,22 @@
l2,
)
from MuyGPyS.gp.hyperparameter import ScalarParam, NamedParam
from MuyGPyS.gp.hyperparameter.experimental import (
HierarchicalParam,
NamedHierarchicalParam,
)
from MuyGPyS.gp.kernels import KernelFn


def _set_matern_fn(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Part of the problem is that hierarchical parameters require batch_features to be evaluated, and we don't have them at init time. I'm not sure if we should just default to _backend_gen_fn or if we need to evaluate and map _set_matern_fn to each smoothness value.

smoothness: ScalarParam,
smoothness: Union[NamedParam, NamedHierarchicalParam],
_backend_05_fn: Callable = _matern_05_fn,
_backend_15_fn: Callable = _matern_15_fn,
_backend_25_fn: Callable = _matern_25_fn,
_backend_inf_fn: Callable = _matern_inf_fn,
_backend_gen_fn: Callable = _matern_gen_fn,
):
if smoothness.fixed() is True:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was crazy. Numpy functions return np.bool_ so comparisons with == work but is fail. That took me a while to debug.

if smoothness.fixed():
if smoothness() == 0.5:
return _backend_05_fn
elif smoothness() == 1.5:
Expand Down Expand Up @@ -119,17 +123,25 @@ class Matern(KernelFn):

def __init__(
self,
smoothness: ScalarParam = ScalarParam(0.5),
smoothness: Union[ScalarParam, HierarchicalParam] = ScalarParam(0.5),
deformation: DeformationFn = Isotropy(
l2, length_scale=ScalarParam(1.0)
),
_backend_ones: Callable = mm.ones,
_backend_zeros: Callable = mm.zeros,
_backend_squeeze: Callable = mm.squeeze,
**_backend_fns
**_backend_fns,
):
super().__init__(deformation=deformation)
self.smoothness = NamedParam("smoothness", smoothness)
if isinstance(smoothness, ScalarParam):
self.smoothness = NamedParam("smoothness", smoothness)
elif isinstance(smoothness, HierarchicalParam):
self.smoothness = NamedHierarchicalParam("smoothness", smoothness)
else:
raise ValueError(
"Expected ScalarParam type for smoothness, not "
f"{type(smoothness)}"
)
self._backend_ones = _backend_ones
self._backend_zeros = _backend_zeros
self._backend_squeeze = _backend_squeeze
Expand Down
92 changes: 92 additions & 0 deletions tests/experimental/nonstationary.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,98 @@ def test_hierarchical_nonstationary_rbf(
shape=(batch_count, nn_count, nn_count),
)

@parameterized.parameters(
(
(
feature_count,
type(high_level_kernel).__name__,
smoothness,
deformation,
)
for feature_count in [2, 17]
for knot_count in [10]
for knot_features in [
sample_knots(feature_count=feature_count, knot_count=knot_count)
]
for knot_values in [
VectorParameter(*[Parameter(i) for i in range(knot_count)]),
]
for high_level_kernel in [RBF(), Matern()]
for smoothness, deformation in [
(
Parameter(1.5),
Isotropy(
l2,
length_scale=Parameter(1),
),
),
(
HierarchicalParameter(
knot_features, knot_values, high_level_kernel
),
Isotropy(
l2,
length_scale=Parameter(1),
),
),
(
Parameter(1.5),
Isotropy(
l2,
length_scale=HierarchicalParameter(
knot_features, knot_values, high_level_kernel
),
),
),
]
)
)
def test_hierarchical_nonstationary_matern(
self,
feature_count,
high_level_kernel,
smoothness,
deformation,
):
muygps = MuyGPS(
kernel=Matern(smoothness=smoothness, deformation=deformation),
)

# prepare data
data_count = 1000
data = _make_gaussian_dict(
data_count=data_count,
feature_count=feature_count,
response_count=1,
)

# neighbors and differences
nn_count = 30
nbrs_lookup = NN_Wrapper(
data["input"], nn_count, nn_method="exact", algorithm="ball_tree"
)
batch_count = 200
batch_indices, batch_nn_indices = sample_batch(
nbrs_lookup, batch_count, data_count
)
(_, pairwise_diffs, _, _) = muygps.make_train_tensors(
batch_indices,
batch_nn_indices,
data["input"],
data["output"],
)

batch_features = batch_features_tensor(data["input"], batch_indices)

Kin = muygps.kernel(pairwise_diffs, batch_features=batch_features)

_check_ndarray(
self.assertEqual,
Kin,
mm.ftype,
shape=(batch_count, nn_count, nn_count),
)


if __name__ == "__main__":
absltest.main()
Loading