Skip to content

Commit

Permalink
final updates for v0.7.0 (#185)
Browse files Browse the repository at this point in the history
* fixed symmetric log scaling, colormaps and equation formatting of loss tutorial

* improved readability and tidied torch tutorial

* Tidied up test sampler and vis function interfaces

* Tidied up the univariate regression tutorial

* Improved the presentation and prose of the anisotropic tutorial

* rewrote the fast regression tutorial for readability

* Updated docs for tensor maker functions.

* Updated docs for MuyGPyS.gp submodule

* Added pandas to docs requirements.

* Increment version to v0.7.0
  • Loading branch information
bwpriest authored Aug 25, 2023
1 parent d54b7db commit 8939344
Show file tree
Hide file tree
Showing 21 changed files with 911 additions and 852 deletions.
2 changes: 1 addition & 1 deletion MuyGPyS/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

"""Public MuyGPyS modules and functions."""

__version__ = "0.6.6"
__version__ = "0.7.0"

from MuyGPyS._src.config import (
config as config,
Expand Down
6 changes: 3 additions & 3 deletions MuyGPyS/_src/gp/tensors/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,13 @@ def _l2(diffs: jnp.ndarray) -> jnp.ndarray:

@jit
def _fast_nn_update(
nn_indices: jnp.ndarray,
train_nn_indices: jnp.ndarray,
) -> jnp.ndarray:
train_count, _ = nn_indices.shape
train_count, _ = train_nn_indices.shape
new_nn_indices = jnp.concatenate(
(
jnp.expand_dims(jnp.arange(0, train_count), axis=1),
nn_indices[:, :-1],
train_nn_indices[:, :-1],
),
axis=1,
)
Expand Down
2 changes: 1 addition & 1 deletion MuyGPyS/_src/gp/tensors/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _pairwise_differences(points: np.ndarray) -> np.ndarray:


def _fast_nn_update(
nn_indices: np.ndarray,
train_nn_indices: np.ndarray,
) -> np.ndarray:
raise NotImplementedError(
'Function "muygps_fast_nn_update" does not support mpi!'
Expand Down
6 changes: 3 additions & 3 deletions MuyGPyS/_src/gp/tensors/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,13 @@ def _l2(diffs: np.ndarray) -> np.ndarray:


def _fast_nn_update(
nn_indices: np.ndarray,
train_nn_indices: np.ndarray,
) -> np.ndarray:
train_count, _ = nn_indices.shape
train_count, _ = train_nn_indices.shape
new_nn_indices = np.concatenate(
(
np.expand_dims(np.arange(0, train_count), axis=1),
nn_indices[:, :-1],
train_nn_indices[:, :-1],
),
axis=1,
)
Expand Down
6 changes: 3 additions & 3 deletions MuyGPyS/_src/gp/tensors/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,13 @@ def _l2(diffs: torch.ndarray) -> torch.ndarray:


def _fast_nn_update(
nn_indices: torch.ndarray,
train_nn_indices: torch.ndarray,
) -> torch.ndarray:
train_count, _ = nn_indices.shape
train_count, _ = train_nn_indices.shape
new_nn_indices = torch.cat(
(
torch.unsqueeze(torch.arange(0, train_count), dim=1),
nn_indices[:, :-1],
train_nn_indices[:, :-1],
),
dim=1,
)
Expand Down
Loading

0 comments on commit 8939344

Please sign in to comment.