Skip to content

Commit

Permalink
implementation of experimental shear kernel (#188)
Browse files Browse the repository at this point in the history
* moved tests of experimental components into their own subdirectory

* Added 'prod' to library of wrapped math functions

* Added shear kernel API to MuyGPyS.gp.kernels.experimental

* added initial implementation of shear kernels

* added experimental/shear.ipynb notebook validating shear kernel. Timings suggest a >2 OOM improvement in runtime.

* added shear kernel test script referencing Bob's implementation

* Added JAX and torch implementations of the analytic shear kernel
  • Loading branch information
bwpriest authored Sep 12, 2023
1 parent 8939344 commit 635f18c
Show file tree
Hide file tree
Showing 17 changed files with 1,139 additions and 3 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/develop-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ jobs:
python tests/gp.py
python tests/batch.py
python tests/predict.py
python tests/nonstationary.py
python tests/experimental/nonstationary.py
python tests/precompute/fast_posterior_mean.py
- name: Optimize Tests
if: matrix.test-group == 'optimize'
run: python tests/optimize.py
- name: Optimize Tests - experimental
if: matrix.test-group == 'optimize-experimental'
run: python tests/mini_batch.py
run: python tests/experimental/mini_batch.py
- name: Multivariate Tests
if: matrix.test-group == 'multivariate'
run: python tests/multivariate.py
Expand Down
11 changes: 11 additions & 0 deletions MuyGPyS/_src/gp/kernels/shear/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other
# MuyGPyS Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: MIT

from MuyGPyS._src.util import _collect_implementation

# _collect_implementation returns a tuple, so need to subscript to get singleton
_shear_fn = _collect_implementation(
"MuyGPyS._src.gp.kernels.shear", "_shear_fn"
)[0]
226 changes: 226 additions & 0 deletions MuyGPyS/_src/gp/kernels/shear/jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other
# MuyGPyS Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: MIT


from jax import jit

import MuyGPyS._src.math.jax as jnp
import MuyGPyS._src.math.numpy as np


@jit
def _kk_fn(
exp_inv_scaled_sum_sq_diffs,
sum_sq_diffs,
prod_sq_diffs,
sum_quad_diffs,
a=1,
length_scale=1,
):
return (
1
/ 4
* (
a
* (
8 * length_scale**2
- 8 * length_scale * sum_sq_diffs
+ 2 * prod_sq_diffs
+ sum_quad_diffs
)
* exp_inv_scaled_sum_sq_diffs
/ length_scale**4
)
)


@jit
def _kg1_fn(
exp_inv_scaled_sum_sq_diffs,
diff_xy_quad_diffs,
diff_yx_sq_diffs,
a=1,
length_scale=1,
):
return (
1
/ 4
* (
a
* (6 * length_scale * diff_yx_sq_diffs + diff_xy_quad_diffs)
* exp_inv_scaled_sum_sq_diffs
/ length_scale**4
)
)


@jit
def _kg2_fn(
exp_inv_scaled_sum_sq_diffs, sum_sq_diffs, prod_diffs, a=1, length_scale=1
):
return (
1
/ 4
* (
2
* a
* prod_diffs
* (-6 * length_scale + sum_sq_diffs)
* exp_inv_scaled_sum_sq_diffs
/ length_scale**4
)
)


@jit
def _g1g1_fn(
exp_inv_scaled_sum_sq_diffs,
sum_sq_diffs,
sum_quad_diffs,
prod_sq_diffs,
a=1,
length_scale=1,
):
return (
1
/ 4
* (
a
* (
4 * length_scale**2
- 4 * length_scale * sum_sq_diffs
- 2 * prod_sq_diffs
+ sum_quad_diffs
)
* exp_inv_scaled_sum_sq_diffs
/ length_scale**4
)
)


@jit
def _g1g2_fn(
exp_inv_scaled_sum_sq_diffs,
diff_xy_sq_diffs,
prod_diffs,
a=1,
length_scale=1,
):
return (
1
/ 4
* (
2
* a
* prod_diffs
* diff_xy_sq_diffs
* exp_inv_scaled_sum_sq_diffs
/ length_scale**4
)
)


@jit
def _g2g2_fn(
exp_inv_scaled_sum_sq_diffs,
sum_sq_diffs,
prod_sq_diffs,
a=1,
length_scale=1,
):
return (
1
/ 4
* (
4
* a
* (length_scale**2 - length_scale * sum_sq_diffs + prod_sq_diffs)
* exp_inv_scaled_sum_sq_diffs
/ length_scale**4
)
)


# compute the full covariance matrix
@jit
def _shear_fn(diffs, a=1, length_scale=1):
shape = np.array(diffs.shape[:-1], dtype=int)
shape[-1] *= 3
shape[-2] *= 3
full_m = jnp.zeros(shape)

# compute intermediate difference tensors once here
prod_diffs = jnp.prod(diffs, axis=-1)
sq_diffs = diffs**2
quad_diffs = sq_diffs**2
sum_sq_diffs = jnp.sum(sq_diffs, axis=-1)
prod_sq_diffs = jnp.prod(sq_diffs, axis=-1)
sum_quad_diffs = jnp.sum(quad_diffs, axis=-1)
diff_yx_sq_diffs = sq_diffs[..., 1] - sq_diffs[..., 0]
diff_xy_sq_diffs = sq_diffs[..., 0] - sq_diffs[..., 1]
diff_xy_quad_diffs = quad_diffs[..., 0] - quad_diffs[..., 1]
exp_inv_scaled_sum_sq_diffs = jnp.exp(-sum_sq_diffs / (2 * length_scale))

full_m = full_m.at[..., 0::3, 0::3].set(
_kk_fn(
exp_inv_scaled_sum_sq_diffs,
sum_sq_diffs,
prod_sq_diffs,
sum_quad_diffs,
a,
length_scale,
)
)
full_m = full_m.at[..., 0::3, 1::3].set(
_kg1_fn(
exp_inv_scaled_sum_sq_diffs,
diff_xy_quad_diffs,
diff_yx_sq_diffs,
a,
length_scale,
)
)
full_m = full_m.at[..., 1::3, 0::3].set(full_m[..., 0::3, 1::3])
full_m = full_m.at[..., 0::3, 2::3].set(
_kg2_fn(
exp_inv_scaled_sum_sq_diffs,
sum_sq_diffs,
prod_diffs,
a,
length_scale,
)
)
full_m = full_m.at[..., 2::3, 0::3].set(full_m[..., 0::3, 2::3])
full_m = full_m.at[..., 1::3, 1::3].set(
_g1g1_fn(
exp_inv_scaled_sum_sq_diffs,
sum_sq_diffs,
sum_quad_diffs,
prod_sq_diffs,
a,
length_scale,
)
)
full_m = full_m.at[..., 1::3, 2::3].set(
_g1g2_fn(
exp_inv_scaled_sum_sq_diffs,
diff_xy_sq_diffs,
prod_diffs,
a,
length_scale,
)
)
full_m = full_m.at[..., 2::3, 1::3].set(full_m[..., 1::3, 2::3])
full_m = full_m.at[..., 2::3, 2::3].set(
_g2g2_fn(
exp_inv_scaled_sum_sq_diffs,
sum_sq_diffs,
prod_sq_diffs,
a,
length_scale,
)
)

return full_m
6 changes: 6 additions & 0 deletions MuyGPyS/_src/gp/kernels/shear/mpi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other
# MuyGPyS Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: MIT

from MuyGPyS._src.gp.kernels.shear import _shear_fn
Loading

0 comments on commit 635f18c

Please sign in to comment.