Skip to content

Commit

Permalink
Merge branch 'main' into interrogate-workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
slinderman authored Jan 27, 2025
2 parents 01142af + dfc2fa2 commit b7fc12a
Show file tree
Hide file tree
Showing 48 changed files with 1,184 additions and 746 deletions.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
dynamax/_version.py export-subst
*.ipynb linguist-documentation
4 changes: 2 additions & 2 deletions .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:

- name: Install dependencies
run: |
pip install -e '.[dev]'
pip install -e '.[test]'
- name: Run tests
run: pytest
run: pytest --cov=./
3 changes: 2 additions & 1 deletion dynamax/generalized_gaussian_ssm/dekf/README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
The diagonal EKF code (used in our paper https://openreview.net/pdf?id=asgeEt25kk)
has moved to https://github.com/probml/dynamax/tree/main/dynamax/rebayes.
has moved to https://github.com/probml/dynamax

7 changes: 3 additions & 4 deletions dynamax/generalized_gaussian_ssm/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from jax import jacfwd, vmap, lax
import jax.numpy as jnp
from jax import lax
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
from jaxtyping import Array, Float
from typing import NamedTuple, Optional, Union, Callable

Expand Down Expand Up @@ -83,7 +82,7 @@ def compute_weights_and_sigmas(self, m, P):


def _predict(m, P, f, Q, u, g_ev, g_cov):
"""Predict next mean and covariance under an additive-noise Gaussian filter
r"""Predict next mean and covariance under an additive-noise Gaussian filter
p(x_{t+1}) = N(x_{t+1} | mu_pred, Sigma_pred)
where
Expand Down Expand Up @@ -117,7 +116,7 @@ def _predict(m, P, f, Q, u, g_ev, g_cov):


def _condition_on(m, P, y_cond_mean, y_cond_cov, u, y, g_ev, g_cov, num_iter, emission_dist):
"""Condition a Gaussian potential on a new observation with arbitrary
r"""Condition a Gaussian potential on a new observation with arbitrary
likelihood with given functions for conditional moments and make a
Gaussian approximation.
p(x_t | y_t, u_t, y_{1:t-1}, u_{1:t-1})
Expand Down Expand Up @@ -172,7 +171,7 @@ def _step(carry, _):


def _statistical_linear_regression(mu, Sigma, m, S, C):
"""Return moment-matching affine coefficients and approximation noise variance
r"""Return moment-matching affine coefficients and approximation noise variance
given joint moments.
g(x) \approx Ax + b + e where e ~ N(0, Omega)
Expand Down
20 changes: 10 additions & 10 deletions dynamax/generalized_gaussian_ssm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
from dynamax.nonlinear_gaussian_ssm.models import FnStateToState, FnStateAndInputToState
from dynamax.nonlinear_gaussian_ssm.models import FnStateToEmission, FnStateAndInputToEmission

FnStateToEmission2 = Callable[[Float[Array, "state_dim"]], Float[Array, "emission_dim emission_dim"]]
FnStateAndInputToEmission2 = Callable[[Float[Array, "state_dim"], Float[Array, "input_dim"]], Float[Array, "emission_dim emission_dim"]]
FnStateToEmission2 = Callable[[Float[Array, " state_dim"]], Float[Array, "emission_dim emission_dim"]]
FnStateAndInputToEmission2 = Callable[[Float[Array, " state_dim"], Float[Array, " input_dim"]], Float[Array, "emission_dim emission_dim"]]

# emission distribution takes a mean vector and covariance matrix and returns a distribution
EmissionDistFn = Callable[ [Float[Array, "state_dim"], Float[Array, "state_dim state_dim"]], tfd.Distribution]
EmissionDistFn = Callable[ [Float[Array, " state_dim"], Float[Array, "state_dim state_dim"]], tfd.Distribution]


class ParamsGGSSM(NamedTuple):
Expand All @@ -42,7 +42,7 @@ class ParamsGGSSM(NamedTuple):
"""

initial_mean: Float[Array, "state_dim"]
initial_mean: Float[Array, " state_dim"]
initial_covariance: Float[Array, "state_dim state_dim"]
dynamics_function: Union[FnStateToState, FnStateAndInputToState]
dynamics_covariance: Float[Array, "state_dim state_dim"]
Expand Down Expand Up @@ -97,15 +97,15 @@ def covariates_shape(self):
def initial_distribution(
self,
params: ParamsGGSSM,
inputs: Optional[Float[Array, "input_dim"]]=None
inputs: Optional[Float[Array, " input_dim"]]=None
) -> tfd.Distribution:
return MVN(params.initial_mean, params.initial_covariance)

def transition_distribution(
self,
params: ParamsGGSSM,
state: Float[Array, "state_dim"],
inputs: Optional[Float[Array, "input_dim"]]=None
state: Float[Array, " state_dim"],
inputs: Optional[Float[Array, " input_dim"]]=None
) -> tfd.Distribution:
f = params.dynamics_function
if inputs is None:
Expand All @@ -117,8 +117,8 @@ def transition_distribution(
def emission_distribution(
self,
params: ParamsGGSSM,
state: Float[Array, "state_dim"],
inputs: Optional[Float[Array, "input_dim"]]=None
state: Float[Array, " state_dim"],
inputs: Optional[Float[Array, " input_dim"]]=None
) -> tfd.Distribution:
h = params.emission_mean_function
R = params.emission_cov_function
Expand All @@ -128,4 +128,4 @@ def emission_distribution(
else:
mean = h(state, inputs)
cov = R(state, inputs)
return params.emission_dist(mean, cov)
return params.emission_dist(mean, cov)
4 changes: 2 additions & 2 deletions dynamax/generalized_gaussian_ssm/models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_poisson_emission(key, kwargs):
keys = jr.split(key, 3)
state_dim = kwargs['state_dim']
emission_dim = 1 # Univariate Poisson
poisson_weights = jr.normal(keys[0], shape=(emission_dim, state_dim))
poisson_weights = jr.normal(keys[0], shape=(emission_dim, state_dim)) / jnp.sqrt(state_dim)
model = GeneralizedGaussianSSM(state_dim, emission_dim)

# Define model parameters with Poisson emission
Expand Down Expand Up @@ -51,7 +51,7 @@ def test_poisson_emission(key, kwargs):

# Fit model with Gaussian emission
gaussian_marginal_lls = conditional_moments_gaussian_filter(gaussian_params, EKFIntegrals(), emissions).marginal_loglik

# Check that the marginal log-likelihoods under Poisson emission are higher
assert pois_marginal_lls > gaussian_marginal_lls

Expand Down
4 changes: 2 additions & 2 deletions dynamax/hidden_markov_model/demos/categorical_glm_hmm_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@
plt.figure()
plt.imshow(jnp.vstack((states[None, :], most_likely_states[None, :])),
aspect="auto", interpolation='none', cmap="Greys")
plt.yticks([0.0, 1.0], ["$z$", "$\hat{z}$"])
plt.yticks([0.0, 1.0], ["$z$", r"$\hat{z}$"])
plt.xlabel("time")
plt.xlim(0, 500)


print("true log prob: ", hmm.marginal_log_prob(true_params, emissions, inputs=inputs))
print("test log prob: ", test_hmm.marginal_log_prob(params, emissions, inputs=inputs))

plt.show()
plt.show()
Loading

0 comments on commit b7fc12a

Please sign in to comment.