Skip to content

Commit

Permalink
optional
Browse files Browse the repository at this point in the history
  • Loading branch information
dweindl committed Dec 11, 2024
1 parent 2fd6da8 commit 65ef80f
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 9 deletions.
25 changes: 19 additions & 6 deletions petab/v1/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@

__all__ = ["priors_to_measurements"]

# TODO: does anybody really rely on the old behavior?
USE_PROPER_TRUNCATION = True


class Prior:
"""A PEtab parameter prior.
Expand All @@ -61,6 +58,15 @@ class Prior:
on the `parameter_scale` scale).
:param bounds: The untransformed bounds of the sample (lower, upper).
:param transformation: The transformation of the distribution.
:param bounds_truncate: Whether the generated prior will be truncated
at the bounds.
If ``True``, the probability density will be rescaled
accordingly and the sample is generated from the truncated
distribution.
If ``False``, the probability density will not account for the
bounds, but any parameter samples outside the bounds will be set to
the value of the closest bound. In this case, the PDF might not match
the sample.
"""

def __init__(
Expand All @@ -69,6 +75,7 @@ def __init__(
parameters: tuple,
bounds: tuple = None,
transformation: str = C.LIN,
bounds_truncate: bool = True,
):
if transformation not in C.PARAMETER_SCALES:
raise ValueError(
Expand All @@ -90,8 +97,9 @@ def __init__(
self._parameters = parameters
self._bounds = bounds
self._transformation = transformation
self._bounds_truncate = bounds_truncate

truncation = bounds if USE_PROPER_TRUNCATION else None
truncation = bounds if bounds_truncate else None
if truncation is not None:
# for uniform, we don't want to implement truncation and just
# adapt the distribution parameters
Expand Down Expand Up @@ -184,7 +192,7 @@ def _clip_to_bounds(self, x):
:param x: The values to clip. Assumed to be on the parameter scale.
"""
if self.bounds is None or USE_PROPER_TRUNCATION:
if self.bounds is None or self._bounds_truncate:
return x

return np.maximum(
Expand Down Expand Up @@ -235,12 +243,16 @@ def neglogprior(self, x):

@staticmethod
def from_par_dict(
d, type_=Literal["initialization", "objective"]
d,
type_=Literal["initialization", "objective"],
bounds_truncate: bool = True,
) -> Prior:
"""Create a distribution from a row of the parameter table.
:param d: A dictionary representing a row of the parameter table.
:param type_: The type of the distribution.
:param bounds_truncate: Whether the generated prior will be truncated
at the bounds.
:return: A distribution object.
"""
dist_type = d.get(f"{type_}PriorType", C.PARAMETER_SCALE_UNIFORM)
Expand Down Expand Up @@ -268,6 +280,7 @@ def from_par_dict(
parameters=params,
bounds=(d[C.LOWER_BOUND], d[C.UPPER_BOUND]),
transformation=pscale,
bounds_truncate=bounds_truncate,
)


Expand Down
10 changes: 8 additions & 2 deletions petab/v1/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@ def sample_from_prior(
# unpack info
p_type, p_params, scaling, bounds = prior
prior = Prior(
p_type, tuple(p_params), bounds=tuple(bounds), transformation=scaling
p_type,
tuple(p_params),
bounds=tuple(bounds),
transformation=scaling,
bounds_truncate=True,
)
return prior.sample(shape=(n_starts,))

Expand Down Expand Up @@ -74,7 +78,9 @@ def sample_parameter_startpoints(
# get types and parameters of priors from dataframe
return np.array(
[
Prior.from_par_dict(row, type_="initialization").sample(n_starts)
Prior.from_par_dict(
row, type_="initialization", bounds_truncate=True
).sample(n_starts)
for row in par_to_estimate.to_dict("records")
]
).T
4 changes: 3 additions & 1 deletion tests/v1/test_priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,9 @@ def apply_parameter_values(row):
]
priors = [
Prior.from_par_dict(
petab_problem_priors.parameter_df.loc[par_id], type_="objective"
petab_problem_priors.parameter_df.loc[par_id],
type_="objective",
bounds_truncate=False,
)
for par_id in parameter_ids
]
Expand Down

0 comments on commit 65ef80f

Please sign in to comment.