Commit

Revert conditional Aesara imports
brandonwillard committed Apr 1, 2022
1 parent d16c55d commit b8335d8
Showing 8 changed files with 140 additions and 210 deletions.
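
For context, the pattern this commit reverts is a conditional import shim: each module tried to import from Aesara and fell back to Theano under the same alias when Aesara was not installed. A minimal before/after sketch, assembled from the removed and added import lines in the diff below (an illustration only, not the full import list):

# Before: prefer Aesara, fall back to Theano under the same names.
try:  # pragma: no cover
    import aesara.tensor as at
    from aesara.tensor.extra_ops import broadcast_to as at_broadcast_to
except ImportError:  # pragma: no cover
    import theano.tensor as at
    from theano.tensor.extra_ops import broadcast_to as at_broadcast_to

# After this commit: import Theano directly and use the `tt` alias throughout.
import theano
import theano.tensor as tt
from theano.tensor.extra_ops import broadcast_to as tt_broadcast_to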
65 changes: 28 additions & 37 deletions pymc3_hmm/distributions.py
@@ -1,28 +1,19 @@
import warnings

import numpy as np

try: # pragma: no cover
import aesara
import aesara.tensor as at
from aesara.graph.op import get_test_value
from aesara.scalar import upcast
from aesara.tensor.extra_ops import broadcast_to as at_broadcast_to
except ImportError: # pragma: no cover
import theano as aesara
import theano.tensor as at
from theano.graph.op import get_test_value
from theano.scalar import upcast
from theano.tensor.extra_ops import broadcast_to as at_broadcast_to

import pymc3 as pm
import theano
import theano.tensor as tt
from pymc3.distributions.distribution import (
Distribution,
_DrawValuesContext,
draw_values,
generate_samples,
)
from pymc3.distributions.mixture import _conversion_map, all_discrete
from theano.graph.op import get_test_value
from theano.scalar import upcast
from theano.tensor.extra_ops import broadcast_to as tt_broadcast_to

from pymc3_hmm.utils import tt_broadcast_arrays, vsearchsorted

@@ -62,7 +53,7 @@ def distribution_subset_args(dist, shape, idx):
res = dict()
for param in dist_param_names:

bcast_res = at_broadcast_to(getattr(dist, param), shape)
bcast_res = tt_broadcast_to(getattr(dist, param), shape)

res[param] = bcast_res[idx]

@@ -110,7 +101,7 @@ def __init__(self, comp_dists, states, *args, **kwargs):
equal to the size of `comp_dists`.
"""
self.states = at.as_tensor_variable(pm.intX(states))
self.states = tt.as_tensor_variable(pm.intX(states))

if len(comp_dists) > 31:
warnings.warn(
@@ -136,7 +127,7 @@ def __init__(self, comp_dists, states, *args, **kwargs):
bcast_means = tt_broadcast_arrays(
*([self.states] + [d.mean.astype(dtype) for d in self.comp_dists])
)
self.mean = at.choose(self.states, bcast_means[1:])
self.mean = tt.choose(self.states, bcast_means[1:])

if "mean" not in defaults:
defaults.append("mean")
@@ -148,7 +139,7 @@ def __init__(self, comp_dists, states, *args, **kwargs):
bcast_modes = tt_broadcast_arrays(
*([self.states] + [d.mode.astype(dtype) for d in self.comp_dists])
)
self.mode = at.choose(self.states, bcast_modes[1:])
self.mode = tt.choose(self.states, bcast_modes[1:])

if "mode" not in defaults:
defaults.append("mode")
@@ -161,16 +152,16 @@ def __init__(self, comp_dists, states, *args, **kwargs):
def logp(self, obs):
"""Return the Theano log-likelihood at a point."""

obs_tt = at.as_tensor_variable(obs)
obs_tt = tt.as_tensor_variable(obs)

logp_val = at.alloc(-np.inf, *obs.shape)
logp_val = tt.alloc(-np.inf, *obs.shape)

for i, dist in enumerate(self.comp_dists):
i_mask = at.eq(self.states, i)
i_mask = tt.eq(self.states, i)
obs_i = obs_tt[i_mask]
subset_dist_kwargs = distribution_subset_args(dist, obs.shape, i_mask)
subset_dist = dist.dist(**subset_dist_kwargs)
logp_val = at.set_subtensor(logp_val[i_mask], subset_dist.logp(obs_i))
logp_val = tt.set_subtensor(logp_val[i_mask], subset_dist.logp(obs_i))

return logp_val

@@ -255,8 +246,8 @@ def __init__(self, mu=None, states=None, **kwargs):
A vector of integer 0-1 states that indicate which component of
the mixture is active at each point/time.
"""
self.mu = at.as_tensor_variable(pm.floatX(mu))
self.states = at.as_tensor_variable(states)
self.mu = tt.as_tensor_variable(pm.floatX(mu))
self.states = tt.as_tensor_variable(states)

super().__init__(
[Constant.dist(np.array(0, dtype=np.int64)), pm.Poisson.dist(mu)],
@@ -292,15 +283,15 @@ def __init__(self, Gammas, gamma_0, shape, **kwargs):
Shape of the state sequence. The last dimension is `N`, i.e. the
length of the state sequence(s).
"""
self.gamma_0 = at.as_tensor_variable(pm.floatX(gamma_0))
self.gamma_0 = tt.as_tensor_variable(pm.floatX(gamma_0))

assert Gammas.ndim >= 3

self.Gammas = at.as_tensor_variable(pm.floatX(Gammas))
self.Gammas = tt.as_tensor_variable(pm.floatX(Gammas))

shape = np.atleast_1d(shape)

dtype = _conversion_map[aesara.config.floatX]
dtype = _conversion_map[theano.config.floatX]
self.mode = np.zeros(tuple(shape), dtype=dtype)

super().__init__(shape=shape, **kwargs)
@@ -324,23 +315,23 @@ def logp(self, states):
""" # noqa: E501

states_tt = at.as_tensor(states)
states_tt = tt.as_tensor(states)

if states.ndim > 1 or self.Gammas.ndim > 3 or self.gamma_0.ndim > 1:
raise NotImplementedError("Broadcasting not supported.")

Gammas_tt = at_broadcast_to(
Gammas_tt = tt_broadcast_to(
self.Gammas, (states.shape[0],) + tuple(self.Gammas.shape)[-2:]
)
gamma_0_tt = self.gamma_0

Gamma_1_tt = Gammas_tt[0]
P_S_1_tt = at.dot(gamma_0_tt, Gamma_1_tt)[states_tt[0]]
P_S_1_tt = tt.dot(gamma_0_tt, Gamma_1_tt)[states_tt[0]]

# def S_logp_fn(S_tm1, S_t, Gamma):
# return at.log(Gamma[..., S_tm1, S_t])
# return tt.log(Gamma[..., S_tm1, S_t])
#
# P_S_2T_tt, _ = aesara.scan(
# P_S_2T_tt, _ = theano.scan(
# S_logp_fn,
# sequences=[
# {
@@ -350,10 +341,10 @@ def logp(self, states):
# Gammas_tt,
# ],
# )
P_S_2T_tt = Gammas_tt[at.arange(1, states.shape[0]), states[:-1], states[1:]]
P_S_2T_tt = Gammas_tt[tt.arange(1, states.shape[0]), states[:-1], states[1:]]

log_P_S_1T_tt = at.concatenate(
[at.shape_padright(at.log(P_S_1_tt)), at.log(P_S_2T_tt)]
log_P_S_1T_tt = tt.concatenate(
[tt.shape_padright(tt.log(P_S_1_tt)), tt.log(P_S_2T_tt)]
)

res = log_P_S_1T_tt.sum()
@@ -424,7 +415,7 @@ class Constant(Distribution):

def __init__(self, c, shape=(), defaults=("mode",), **kwargs):

c = at.as_tensor_variable(c)
c = tt.as_tensor_variable(c)

dtype = c.dtype

@@ -448,7 +439,7 @@ def _random(c, dtype=self.dtype, size=None):
)

def logp(self, value):
return at.switch(at.eq(value, self.c), 0.0, -np.inf)
return tt.switch(tt.eq(value, self.c), 0.0, -np.inf)

def _distr_parameters_for_repr(self):
return ["c"]
75 changes: 28 additions & 47 deletions pymc3_hmm/step_methods.py
@@ -1,47 +1,28 @@
from functools import singledispatch
from itertools import chain
from typing import Callable, Tuple

import numpy as np

try: # pragma: no cover
import aesara.scalar as aes
import aesara.tensor as at
from aesara import config
from aesara.compile import optdb
from aesara.graph.basic import Variable, graph_inputs
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value as test_value
from aesara.graph.opt import OpRemove, pre_greedy_local_optimizer
from aesara.graph.optdb import Query
from aesara.scalar.basic import Dot
from aesara.sparse.basic import StructuredDot
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.subtensor import AdvancedIncSubtensor1
from aesara.tensor.var import TensorConstant
except ImportError: # pragma: no cover
import theano.scalar as aes
import theano.tensor as at
from theano.compile import optdb
from theano.graph.basic import Variable, graph_inputs
from theano.graph.fg import FunctionGraph
from theano.graph.op import get_test_value as test_value
from theano.graph.opt import OpRemove, pre_greedy_local_optimizer
from theano.graph.optdb import Query
from theano.tensor.elemwise import DimShuffle, Elemwise
from theano.tensor.subtensor import AdvancedIncSubtensor1
from theano.tensor.var import TensorConstant
from theano.tensor.basic import Dot
from theano.sparse.basic import StructuredDot
from theano import config

from functools import singledispatch

import pymc3 as pm
import scipy
import theano.scalar as ts
import theano.tensor as tt
from pymc3.distributions.distribution import draw_values
from pymc3.step_methods.arraystep import ArrayStep, BlockedStep, Competence
from pymc3.util import get_untransformed_name
from scipy.stats import invgamma
from theano import config
from theano.compile import optdb
from theano.graph.basic import Variable, graph_inputs
from theano.graph.fg import FunctionGraph
from theano.graph.op import get_test_value as test_value
from theano.graph.opt import OpRemove, pre_greedy_local_optimizer
from theano.graph.optdb import Query
from theano.sparse.basic import StructuredDot
from theano.tensor.basic import Dot
from theano.tensor.elemwise import DimShuffle, Elemwise
from theano.tensor.subtensor import AdvancedIncSubtensor1
from theano.tensor.var import TensorConstant

from pymc3_hmm.distributions import DiscreteMarkovChain, HorseShoe, SwitchingProcess
from pymc3_hmm.utils import compute_trans_freqs
@@ -185,15 +166,15 @@ def __init__(self, vars, values=None, model=None):
for comp_dist in dependent_rv.distribution.comp_dists:
comp_logps.append(comp_dist.logp(dependent_rv))

comp_logp_stacked = at.stack(comp_logps)
comp_logp_stacked = tt.stack(comp_logps)
else:
raise TypeError(
"This sampler only supports `SwitchingProcess` observations"
)

dep_comps_logp_stacked.append(comp_logp_stacked)

comp_logp_stacked = at.sum(dep_comps_logp_stacked, axis=0)
comp_logp_stacked = tt.sum(dep_comps_logp_stacked, axis=0)

(M,) = draw_values([var.distribution.gamma_0.shape[-1]], point=model.test_point)
N = model.test_point[var.name].shape[-1]
@@ -352,9 +333,9 @@ def _set_row_mappings(self, Gamma, dir_priors, model):
Gamma = pre_greedy_local_optimizer(
FunctionGraph([], []),
[
OpRemove(Elemwise(aes.Cast(aes.float32))),
OpRemove(Elemwise(aes.Cast(aes.float64))),
OpRemove(Elemwise(aes.identity)),
OpRemove(Elemwise(ts.Cast(ts.float32))),
OpRemove(Elemwise(ts.Cast(ts.float64))),
OpRemove(Elemwise(ts.identity)),
],
Gamma,
)
@@ -378,7 +359,7 @@ def _set_row_mappings(self, Gamma, dir_priors, model):

Gamma_Join = Gamma_DimShuffle.inputs[0].owner

if not (isinstance(Gamma_Join.op, at.basic.Join)):
if not (isinstance(Gamma_Join.op, tt.basic.Join)):
raise TypeError(
"The transition matrix should be comprised of stacked row vectors"
)
@@ -546,7 +527,7 @@ def hs_regression_model_Normal(dist, rv, model):
mu = dist.mu
y_X_fn = None
if hasattr(rv, "observations"):
obs = at.as_tensor_variable(rv.observations)
obs = tt.as_tensor_variable(rv.observations)
obs_fn = model.fn(obs)

def y_X_fn(points, X):
@@ -558,22 +539,22 @@ def y_X_fn(points, X):
@hs_regression_model.register(pm.NegativeBinomial)
def hs_regression_model_NegativeBinomial(dist, rv, model):

mu = at.as_tensor_variable(dist.mu)
mu = tt.as_tensor_variable(dist.mu)

if mu.owner and mu.owner.op == at.exp:
if mu.owner and mu.owner.op == tt.exp:
eta = mu.owner.inputs[0]
else:
eta = mu

alpha = at.as_tensor_variable(dist.alpha)
alpha = tt.as_tensor_variable(dist.alpha)
if hasattr(rv, "observations"):
from polyagamma import random_polyagamma

obs = at.as_tensor_variable(rv.observations)
obs = tt.as_tensor_variable(rv.observations)
h_z_alpha_fn = model.fn(
[
alpha + obs,
eta.squeeze() - at.log(alpha),
eta.squeeze() - tt.log(alpha),
alpha,
obs,
]
@@ -617,7 +598,7 @@ def find_dot(node, beta, model, y_fn):
return node, X_fn, y_fn
else:
# if exp transformation
if isinstance(node.owner.op, at.elemwise.Elemwise):
if isinstance(node.owner.op, tt.elemwise.Elemwise):
res = find_dot(node.owner.inputs[0], beta, model, y_fn)
if res:
node, X_fn, _ = res