Skip to content

Commit

Permalink
merge resolutions
Browse files Browse the repository at this point in the history
  • Loading branch information
adrn committed Mar 13, 2024
1 parent 101817a commit d6d0191
Show file tree
Hide file tree
Showing 11 changed files with 872 additions and 59 deletions.
13 changes: 13 additions & 0 deletions thejoker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
"phase_coverage_per_period",
]

# SB2:
from .thejoker_sb2 import *
from .prior_sb2 import JokerSB2Prior


__bibtex__ = __citation__ = """@ARTICLE{thejoker,
author = {{Price-Whelan}, Adrian M. and {Hogg}, David W. and
Expand All @@ -55,3 +59,12 @@
adsnote = {Provided by the SAO/NASA Astrophysics Data System}
}
"""

# Public API of the package, controlling what `from thejoker import *`
# exposes. NOTE(review): this assignment sits at the end of __init__.py and
# replaces any __all__ defined earlier in the module — confirm that is the
# intent of the merge resolution.
__all__ = [
    'TheJoker',
    'RVData',
    'JokerSamples',
    'JokerPrior',
    'plot_rv_curves',
    'TheJokerSB2',
    # Imported explicitly above (from .prior_sb2); was missing from __all__,
    # which would have hidden it from star-imports:
    'JokerSB2Prior',
]
23 changes: 13 additions & 10 deletions thejoker/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@ class RVData:
(days). Set to ``False`` to disable subtracting the reference time.
clean : bool (optional)
Filter out any NaN or Inf data points.
sort : bool (optional)
Whether or not to sort on time.
"""

@u.quantity_input(rv=u.km / u.s, rv_err=[u.km / u.s, (u.km / u.s) ** 2])
def __init__(self, t, rv, rv_err, t_ref=None, clean=True):
def __init__(self, t, rv, rv_err, t_ref=None, clean=True, sort=True):
# For speed, time is saved internally as BMJD:
if isinstance(t, Time):
_t_bmjd = t.tcb.mjd
Expand Down Expand Up @@ -94,15 +96,16 @@ def __init__(self, t, rv, rv_err, t_ref=None, clean=True):
else:
self.rv_err = self.rv_err[idx]

# sort on times
idx = self._t_bmjd.argsort()
self._t_bmjd = self._t_bmjd[idx]
self.rv = self.rv[idx]
if self._has_cov:
self.rv_err = self.rv_err[idx]
self.rv_err = self.rv_err[:, idx]
else:
self.rv_err = self.rv_err[idx]
if sort:
# sort on times
idx = self._t_bmjd.argsort()
self._t_bmjd = self._t_bmjd[idx]
self.rv = self.rv[idx]
if self._has_cov:
self.rv_err = self.rv_err[idx]
self.rv_err = self.rv_err[:, idx]
else:
self.rv_err = self.rv_err[idx]

if t_ref is False:
self.t_ref = None
Expand Down
15 changes: 12 additions & 3 deletions thejoker/likelihood_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,13 @@ def marginal_ln_likelihood_inmem(joker_helper, prior_samples_batch):
return np.array(ll)


def make_full_samples_inmem(joker_helper, prior_samples_batch, rng, n_linear_samples=1):
from .samples import JokerSamples
def make_full_samples_inmem(
joker_helper, prior_samples_batch, rng, n_linear_samples=1, SamplesCls=None
):
if SamplesCls is None:
from .samples import JokerSamples

SamplesCls = JokerSamples

if prior_samples_batch.dtype != np.float64:
prior_samples_batch = prior_samples_batch.astype(np.float64)
Expand All @@ -77,7 +82,7 @@ def make_full_samples_inmem(joker_helper, prior_samples_batch, rng, n_linear_sam
)

# unpack the raw samples
samples = JokerSamples.unpack(
samples = SamplesCls.unpack(
raw_samples,
joker_helper.internal_units,
t_ref=joker_helper.data.t_ref,
Expand All @@ -96,6 +101,7 @@ def rejection_sample_inmem(
max_posterior_samples=None,
n_linear_samples=1,
return_all_logprobs=False,
SamplesCls=None,
):
if max_posterior_samples is None:
max_posterior_samples = len(prior_samples_batch)
Expand All @@ -114,6 +120,7 @@ def rejection_sample_inmem(
prior_samples_batch[good_samples_idx],
rng,
n_linear_samples=n_linear_samples,
SamplesCls=SamplesCls,
)

if ln_prior is not None and ln_prior is not False:
Expand All @@ -136,6 +143,7 @@ def iterative_rejection_inmem(
init_batch_size=None,
growth_factor=128,
n_linear_samples=1,
SamplesCls=None,
):
n_total_samples = len(prior_samples_batch)

Expand Down Expand Up @@ -219,6 +227,7 @@ def iterative_rejection_inmem(
prior_samples_batch[full_samples_idx],
rng,
n_linear_samples=n_linear_samples,
SamplesCls=SamplesCls,
)

# FIXME: copy-pasted from function above
Expand Down
9 changes: 7 additions & 2 deletions thejoker/multiproc_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def make_full_samples(
samples_idx,
n_linear_samples=1,
n_batches=None,
SamplesCls=JokerSamples,
):
task_args = (prior_samples_file, joker_helper, n_linear_samples)
results = run_worker(
Expand All @@ -164,14 +165,14 @@ def make_full_samples(
task_args=task_args,
n_batches=n_batches,
samples_idx=samples_idx,
rng=rng,
random_state=rng,
)

# Concatenate all of the raw samples arrays
raw_samples = np.concatenate(results)

# unpack the raw samples
samples = JokerSamples.unpack(
samples = SamplesCls.unpack(
raw_samples,
joker_helper.internal_units,
t_ref=joker_helper.data.t_ref,
Expand All @@ -195,6 +196,7 @@ def rejection_sample_helper(
n_batches=None,
randomize_prior_order=False,
return_all_logprobs=False,
SamplesCls=None,
):
# Total number of samples in the cache:
with tb.open_file(prior_samples_file, mode="r") as f:
Expand Down Expand Up @@ -271,6 +273,7 @@ def rejection_sample_helper(
full_samples_idx,
n_linear_samples=n_linear_samples,
n_batches=n_batches,
SamplesCls=SamplesCls,
)

if return_logprobs:
Expand Down Expand Up @@ -300,6 +303,7 @@ def iterative_rejection_helper(
return_logprobs=False,
n_batches=None,
randomize_prior_order=False,
SamplesCls=None,
):
# Total number of samples in the cache:
with tb.open_file(prior_samples_file, mode="r") as f:
Expand Down Expand Up @@ -412,6 +416,7 @@ def iterative_rejection_helper(
full_samples_idx,
n_linear_samples=n_linear_samples,
n_batches=n_batches,
SamplesCls=SamplesCls,
)

# FIXME: copy-pasted from function above
Expand Down
100 changes: 64 additions & 36 deletions thejoker/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def _validate_model(model):


class JokerPrior:
_sb2 = False

def __init__(self, pars=None, poly_trend=1, v0_offsets=None, model=None):
"""
This class controls the prior probability distributions for the
Expand Down Expand Up @@ -121,7 +123,9 @@ def __init__(self, pars=None, poly_trend=1, v0_offsets=None, model=None):
# are only used to validate that the units for each parameter are
# equivalent to these
self._nonlinear_equiv_units = get_nonlinear_equiv_units()
self._linear_equiv_units = get_linear_equiv_units(self.poly_trend)
self._linear_equiv_units = get_linear_equiv_units(
self.poly_trend, sb2=self._sb2
)
self._v0_offsets_equiv_units = get_v0_offsets_equiv_units(self.n_offsets)
self._all_par_unit_equiv = {
**self._nonlinear_equiv_units,
Expand Down Expand Up @@ -291,10 +295,7 @@ def __repr__(self):
def __str__(self):
return ", ".join(self.par_names)

@deprecated_renamed_argument(
"random_state", "rng", since="v1.3", warning_type=DeprecationWarning
)
def sample(
def _get_raw_samples(
self,
size=1,
generate_linear=False,
Expand All @@ -303,29 +304,6 @@ def sample(
dtype=None,
**kwargs,
):
"""
Generate random samples from the prior.
Parameters
----------
size : int (optional)
The number of samples to generate.
generate_linear : bool (optional)
Also generate samples in the linear parameters.
return_logprobs : bool (optional)
Generate the log-prior probability at the position of each sample.
**kwargs
Additional keyword arguments are passed to the
`~thejoker.JokerSamples` initializer.
Returns
-------
samples : `thejoker.Jokersamples`
The random samples.
"""
from .samples import JokerSamples

if dtype is None:
dtype = np.float64

Expand All @@ -339,11 +317,6 @@ def sample(
)
}

if generate_linear:
par_names = self.par_names
else:
par_names = list(self._nonlinear_equiv_units.keys())

# MAJOR HACK RELATED TO UPSTREAM ISSUES WITH pymc3:
# init_shapes = {}
# for name, par in sub_pars.items():
Expand Down Expand Up @@ -374,12 +347,68 @@ def sample(

logp.append(_logp)
log_prior = np.sum(logp, axis=0)
else:
log_prior = None

# CONTINUED MAJOR HACK RELATED TO UPSTREAM ISSUES WITH pymc3:
# for name, par in sub_pars.items():
# if hasattr(par, "distribution"):
# par.distribution.shape = init_shapes[name]

return raw_samples, sub_pars, log_prior

@deprecated_renamed_argument(
"random_state", "rng", since="v1.3", warning_type=DeprecationWarning
)
def sample(
self,
size=1,
generate_linear=False,
return_logprobs=False,
rng=None,
dtype=None,
**kwargs,
):
"""
Generate random samples from the prior.
.. note::
Right now, generating samples with the prior values is slow (i.e.
with ``return_logprobs=True``) because of pymc3 issues (see
discussion here:
https://discourse.pymc.io/t/draw-values-speed-scaling-with-transformed-variables/4076).
This will hopefully be resolved in the future...
Parameters
----------
size : int (optional)
The number of samples to generate.
generate_linear : bool (optional)
Also generate samples in the linear parameters.
return_logprobs : bool (optional)
Generate the log-prior probability at the position of each sample.
**kwargs
Additional keyword arguments are passed to the
`~thejoker.JokerSamples` initializer.
Returns
-------
samples : `thejoker.Jokersamples`
The random samples.
"""
from thejoker.samples import JokerSamples

raw_samples, sub_pars, log_prior = self._get_raw_samples(
size, generate_linear, return_logprobs, rng, dtype, **kwargs
)

if generate_linear:
par_names = self.par_names
else:
par_names = list(self._nonlinear_equiv_units.keys())

# Apply units if they are specified:
prior_samples = JokerSamples(
poly_trend=self.poly_trend, n_offsets=self.n_offsets, **kwargs
Expand Down Expand Up @@ -448,9 +477,8 @@ def default_nonlinear_prior(P_min=None, P_max=None, s=None, model=None, pars=Non

if isinstance(s, pt.TensorVariable):
pars["s"] = pars.get("s", s)
else:
if not hasattr(s, "unit") or not s.unit.is_equivalent(u.km / u.s):
raise u.UnitsError("Invalid unit for s: must be equivalent to km/s")
elif not hasattr(s, "unit") or not s.unit.is_equivalent(u.km / u.s):
raise u.UnitsError("Invalid unit for s: must be equivalent to km/s")

# dictionary of parameters to return
out_pars = {}
Expand Down
17 changes: 12 additions & 5 deletions thejoker/prior_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,19 @@ def validate_poly_trend(poly_trend):
return poly_trend, vtrend_names


def get_linear_equiv_units(poly_trend):
def get_linear_equiv_units(poly_trend, sb2=False):
    """Return the expected physical units of the linear parameters.

    Parameters
    ----------
    poly_trend : int
        Polynomial trend order; validated with ``validate_poly_trend``,
        which also supplies the trend parameter names (v0, v1, ...).
    sb2 : bool (optional)
        If ``True``, include one velocity semi-amplitude per component of
        a double-lined (SB2) binary (``K1`` and ``K2``) instead of the
        single ``K``.

    Returns
    -------
    dict
        Mapping from linear parameter name to its equivalent unit.
    """
    poly_trend, v_names = validate_poly_trend(poly_trend)

    # Semi-amplitude key(s): two components for SB2 systems, one otherwise.
    amp_names = ('K1', 'K2') if sb2 else ('K',)

    units = {name: u.m / u.s for name in amp_names}
    # Trend terms pick up an extra 1/day per polynomial order:
    # v0 [m/s], v1 [m/s/day], v2 [m/s/day^2], ...
    units.update(
        {name: u.m / u.s / u.day ** i for i, name in enumerate(v_names)}
    )
    return units


def validate_sigma_v(sigma_v, poly_trend, v_names):
Expand Down
Loading

0 comments on commit d6d0191

Please sign in to comment.