-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement proper truncation for prior distributions #335
base: develop
Are you sure you want to change the base?
Changes from 2 commits
2fd6da8
65ef80f
c01f2fb
d3b4e7f
057457f
1425d9c
155853f
2484a7f
a17aa62
6f005b8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,7 +23,10 @@ | |
}, | ||
{ | ||
"metadata": { | ||
"collapsed": true | ||
"collapsed": true, | ||
"jupyter": { | ||
"is_executing": true | ||
} | ||
}, | ||
"cell_type": "code", | ||
"source": [ | ||
|
@@ -42,7 +45,7 @@ | |
" if ax is None:\n", | ||
" fig, ax = plt.subplots()\n", | ||
"\n", | ||
" sample = prior.sample(10000)\n", | ||
" sample = prior.sample(20_000)\n", | ||
"\n", | ||
" # pdf\n", | ||
" xmin = min(sample.min(), prior.lb_scaled if prior.bounds is not None else sample.min())\n", | ||
|
@@ -138,11 +141,13 @@ | |
"metadata": {}, | ||
"cell_type": "code", | ||
"source": [ | ||
"# different, because transformation!=LIN\n", | ||
"plot(Prior(UNIFORM, (0.01, 2), transformation=LOG10))\n", | ||
"plot(Prior(PARAMETER_SCALE_UNIFORM, (0.01, 2), transformation=LOG10))\n", | ||
"\n", | ||
"# same, because transformation=LIN\n", | ||
"plot(Prior(UNIFORM, (0.01, 2), transformation=LIN))\n", | ||
"plot(Prior(PARAMETER_SCALE_UNIFORM, (0.01, 2), transformation=LIN))\n" | ||
"plot(Prior(PARAMETER_SCALE_UNIFORM, (0.01, 2), transformation=LIN))" | ||
], | ||
"id": "5ca940bc24312fc6", | ||
"outputs": [], | ||
|
@@ -151,15 +156,18 @@ | |
{ | ||
"metadata": {}, | ||
"cell_type": "markdown", | ||
"source": "To prevent the sampled parameters from exceeding the bounds, the sampled parameters are clipped to the bounds. The bounds are defined in the parameter table. Note that the current implementation does not support sampling from a truncated distribution. Instead, the samples are clipped to the bounds. This may introduce unwanted bias, and thus, should only be used with caution (i.e., the bounds should be chosen wide enough):", | ||
"source": "The given distributions are truncated at the bounds defined in the parameter table:", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add something like "This results in a constant shift in the probability density, compared to the non-truncated version (https://en.wikipedia.org/wiki/Truncated_distribution), such that the probability density still sums to 1." |
||
"id": "b1a8b17d765db826" | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"source": [ | ||
"plot(Prior(NORMAL, (0, 1), bounds=(-4, 4))) # negligible clipping-bias at 4 sigma\n", | ||
"plot(Prior(UNIFORM, (0, 1), bounds=(0.1, 0.9))) # significant clipping-bias" | ||
"plot(Prior(NORMAL, (0, 1), bounds=(-2, 2)))\n", | ||
"plot(Prior(UNIFORM, (0, 1), bounds=(0.1, 0.9)))\n", | ||
"plot(Prior(UNIFORM, (1e-8, 1), bounds=(0.1, 0.9), transformation=LOG10))\n", | ||
"plot(Prior(LAPLACE, (0, 1), bounds=(-0.5, 0.5)))\n", | ||
"plot(Prior(PARAMETER_SCALE_UNIFORM, (-3, 1), bounds=(1e-2, 1), transformation=LOG10))\n" | ||
], | ||
"id": "4ac42b1eed759bdd", | ||
"outputs": [], | ||
|
@@ -175,9 +183,11 @@ | |
"metadata": {}, | ||
"cell_type": "code", | ||
"source": [ | ||
"plot(Prior(NORMAL, (10, 1), bounds=(6, 14), transformation=\"log10\"))\n", | ||
"plot(Prior(PARAMETER_SCALE_NORMAL, (10, 1), bounds=(10**6, 10**14), transformation=\"log10\"))\n", | ||
"plot(Prior(LAPLACE, (10, 2), bounds=(6, 14)))" | ||
"plot(Prior(NORMAL, (10, 1), bounds=(6, 11), transformation=\"log10\"))\n", | ||
"plot(Prior(PARAMETER_SCALE_NORMAL, (10, 1), bounds=(10**9, 10**14), transformation=\"log10\"))\n", | ||
"plot(Prior(LAPLACE, (10, 2), bounds=(6, 14)))\n", | ||
"plot(Prior(LOG_LAPLACE, (1, 0.5), bounds=(0.5, 8)))\n", | ||
"plot(Prior(LOG_NORMAL, (2, 1), bounds=(0.5, 8)))" | ||
], | ||
"id": "581e1ac431860419", | ||
"outputs": [], | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -28,15 +28,65 @@ class Distribution(abc.ABC): | |||||
If a float, the distribution is transformed to its corresponding | ||||||
log distribution with the given base (e.g., Normal -> Log10Normal). | ||||||
If ``False``, no transformation is applied. | ||||||
:param trunc: The truncation points (lower, upper) of the distribution | ||||||
or ``None`` if the distribution is not truncated. | ||||||
""" | ||||||
|
||||||
def __init__(self, log: bool | float = False): | ||||||
def __init__( | ||||||
self, *, log: bool | float = False, trunc: tuple[float, float] = None | ||||||
): | ||||||
if log is True: | ||||||
log = np.exp(1) | ||||||
|
||||||
if trunc == (-np.inf, np.inf): | ||||||
trunc = None | ||||||
|
||||||
if trunc is not None and trunc[0] > trunc[1]: | ||||||
dweindl marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
raise ValueError( | ||||||
"The lower truncation limit must be smaller " | ||||||
"than the upper truncation limit." | ||||||
) | ||||||
|
||||||
self._logbase = log | ||||||
self._trunc = trunc | ||||||
|
||||||
self._cd_low = None | ||||||
self._cd_high = None | ||||||
self._truncation_normalizer = 1 | ||||||
|
||||||
if self._trunc is not None: | ||||||
try: | ||||||
# the cumulative density of the transformed distribution at the | ||||||
# truncation limits | ||||||
self._cd_low = self._cdf_transformed_untruncated( | ||||||
self.trunc_low | ||||||
) | ||||||
self._cd_high = self._cdf_transformed_untruncated( | ||||||
self.trunc_high | ||||||
) | ||||||
# normalization factor for the PDF of the transformed | ||||||
# distribution to account for truncation | ||||||
self._truncation_normalizer = 1 / ( | ||||||
self._cd_high - self._cd_low | ||||||
) | ||||||
except NotImplementedError: | ||||||
pass | ||||||
|
||||||
@property | ||||||
def trunc_low(self) -> float: | ||||||
"""The lower truncation limit of the transformed distribution.""" | ||||||
return self._trunc[0] if self._trunc else -np.inf | ||||||
|
||||||
@property | ||||||
def trunc_high(self) -> float: | ||||||
"""The upper truncation limit of the transformed distribution.""" | ||||||
return self._trunc[1] if self._trunc else np.inf | ||||||
|
||||||
def _undo_log(self, x: np.ndarray | float) -> np.ndarray | float: | ||||||
"""Undo the log transformation. | ||||||
def _exp(self, x: np.ndarray | float) -> np.ndarray | float: | ||||||
"""Exponentiate / undo the log transformation according. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I found it too complicated, as |
||||||
|
||||||
Exponentiate if a log transformation is applied to the distribution. | ||||||
Otherwise, return the input. | ||||||
|
||||||
:param x: The sample to transform. | ||||||
:return: The transformed sample | ||||||
|
@@ -45,9 +95,12 @@ def _undo_log(self, x: np.ndarray | float) -> np.ndarray | float: | |||||
return x | ||||||
return self._logbase**x | ||||||
|
||||||
def _apply_log(self, x: np.ndarray | float) -> np.ndarray | float: | ||||||
def _log(self, x: np.ndarray | float) -> np.ndarray | float: | ||||||
"""Apply the log transformation. | ||||||
dweindl marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
Compute the log of x with the specified base if a log transformation | ||||||
is applied to the distribution. Otherwise, return the input. | ||||||
|
||||||
:param x: The value to transform. | ||||||
:return: The transformed value. | ||||||
""" | ||||||
|
@@ -61,12 +114,17 @@ def sample(self, shape=None) -> np.ndarray: | |||||
:param shape: The shape of the sample. | ||||||
:return: A sample from the distribution. | ||||||
""" | ||||||
sample = self._sample(shape) | ||||||
return self._undo_log(sample) | ||||||
sample = ( | ||||||
self._exp(self._sample(shape)) | ||||||
if self._trunc is None | ||||||
else self._inverse_transform_sample(shape) | ||||||
) | ||||||
|
||||||
return sample | ||||||
|
||||||
@abc.abstractmethod | ||||||
def _sample(self, shape=None) -> np.ndarray: | ||||||
"""Sample from the underlying distribution. | ||||||
"""Sample from the underlying distribution, accounting for truncation. | ||||||
dweindl marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
:param shape: The shape of the sample. | ||||||
:return: A sample from the underlying distribution, | ||||||
|
@@ -85,7 +143,11 @@ def pdf(self, x): | |||||
chain_rule_factor = ( | ||||||
(1 / (x * np.log(self._logbase))) if self._logbase else 1 | ||||||
) | ||||||
return self._pdf(self._apply_log(x)) * chain_rule_factor | ||||||
return ( | ||||||
self._pdf(self._log(x)) | ||||||
* chain_rule_factor | ||||||
* self._truncation_normalizer | ||||||
) | ||||||
|
||||||
@abc.abstractmethod | ||||||
def _pdf(self, x): | ||||||
|
@@ -104,13 +166,71 @@ def logbase(self) -> bool | float: | |||||
""" | ||||||
return self._logbase | ||||||
|
||||||
def cdf(self, x): | ||||||
"""Cumulative distribution function at x. | ||||||
|
||||||
:param x: The value at which to evaluate the CDF. | ||||||
:return: The value of the CDF at ``x``. | ||||||
""" | ||||||
return self._cdf_transformed_untruncated(x) - self._cd_low | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hm, shouldn't the CDF "grow" faster when the PDF is truncated? e.g. for a normal distribution, the CDF reaches 1 at +infty. For a truncated normal distribution, the CDF reaches 1 in a finite interval... so is it enough to just subtract the lower bound CDF value? Could you add a test/sanity check that the CDF is 0 at the lower bound (trivially correct here), and 1 at the upper bound? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You are right, I missed the normalization. Thanks, fixed. |
||||||
|
||||||
def _cdf_transformed_untruncated(self, x): | ||||||
"""Cumulative distribution function of the transformed, but untruncated | ||||||
distribution at x. | ||||||
|
||||||
:param x: The value at which to evaluate the CDF. | ||||||
:return: The value of the CDF at ``x``. | ||||||
""" | ||||||
return self._cdf_untransformed_untruncated(self._log(x)) | ||||||
|
||||||
def _cdf_untransformed_untruncated(self, x): | ||||||
"""Cumulative distribution function of the underlying | ||||||
(untransformed, untruncated) distribution at x. | ||||||
|
||||||
:param x: The value at which to evaluate the CDF. | ||||||
:return: The value of the CDF at ``x``. | ||||||
""" | ||||||
raise NotImplementedError | ||||||
|
||||||
def _ppf_untransformed_untruncated(self, q): | ||||||
"""Percent point function of the underlying | ||||||
(untransformed, untruncated) distribution at q. | ||||||
|
||||||
:param q: The quantile at which to evaluate the PPF. | ||||||
:return: The value of the PPF at ``q``. | ||||||
""" | ||||||
raise NotImplementedError | ||||||
|
||||||
def _ppf_transformed_untruncated(self, q): | ||||||
"""Percent point function of the transformed, but untruncated | ||||||
distribution at q. | ||||||
|
||||||
:param q: The quantile at which to evaluate the PPF. | ||||||
:return: The value of the PPF at ``q``. | ||||||
""" | ||||||
return self._exp(self._ppf_untransformed_untruncated(q)) | ||||||
|
||||||
def _inverse_transform_sample(self, shape): | ||||||
dweindl marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
"""Generate an inverse transform sample from the transformed and | ||||||
truncated distribution. | ||||||
|
||||||
:param shape: The shape of the sample. | ||||||
:return: The sample. | ||||||
""" | ||||||
uniform_sample = np.random.uniform( | ||||||
low=self._cd_low, high=self._cd_high, size=shape | ||||||
) | ||||||
return self._ppf_transformed_untruncated(uniform_sample) | ||||||
|
||||||
|
||||||
class Normal(Distribution): | ||||||
"""A (log-)normal distribution. | ||||||
|
||||||
:param loc: The location parameter of the distribution. | ||||||
:param scale: The scale parameter of the distribution. | ||||||
:param truncation: The truncation limits of the distribution. | ||||||
:param trunc: The truncation limits of the distribution. | ||||||
``None`` if the distribution is not truncated. The truncation limits | ||||||
are the truncation limits of the transformed distribution. | ||||||
Comment on lines
+277
to
+278
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Below, for
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same for the other distributions. |
||||||
:param log: If ``True``, the distribution is transformed to a log-normal | ||||||
distribution. If a float, the distribution is transformed to a | ||||||
log-normal distribution with the given base. | ||||||
|
@@ -124,19 +244,15 @@ def __init__( | |||||
self, | ||||||
loc: float, | ||||||
scale: float, | ||||||
truncation: tuple[float, float] | None = None, | ||||||
trunc: tuple[float, float] | None = None, | ||||||
log: bool | float = False, | ||||||
): | ||||||
super().__init__(log=log) | ||||||
self._loc = loc | ||||||
self._scale = scale | ||||||
self._truncation = truncation | ||||||
|
||||||
if truncation is not None: | ||||||
raise NotImplementedError("Truncation is not yet implemented.") | ||||||
super().__init__(log=log, trunc=trunc) | ||||||
|
||||||
def __repr__(self): | ||||||
trunc = f", truncation={self._truncation}" if self._truncation else "" | ||||||
trunc = f", trunc={self._trunc}" if self._trunc else "" | ||||||
log = f", log={self._logbase}" if self._logbase else "" | ||||||
return f"Normal(loc={self._loc}, scale={self._scale}{trunc}{log})" | ||||||
|
||||||
|
@@ -146,6 +262,12 @@ def _sample(self, shape=None): | |||||
def _pdf(self, x): | ||||||
return norm.pdf(x, loc=self._loc, scale=self._scale) | ||||||
|
||||||
def _cdf_untransformed_untruncated(self, x): | ||||||
return norm.cdf(x, loc=self._loc, scale=self._scale) | ||||||
|
||||||
def _ppf_untransformed_untruncated(self, q): | ||||||
return norm.ppf(q, loc=self._loc, scale=self._scale) | ||||||
|
||||||
@property | ||||||
def loc(self): | ||||||
"""The location parameter of the underlying distribution.""" | ||||||
|
@@ -177,9 +299,9 @@ def __init__( | |||||
*, | ||||||
log: bool | float = False, | ||||||
): | ||||||
super().__init__(log=log) | ||||||
self._low = low | ||||||
self._high = high | ||||||
super().__init__(log=log) | ||||||
|
||||||
def __repr__(self): | ||||||
log = f", log={self._logbase}" if self._logbase else "" | ||||||
|
@@ -191,13 +313,21 @@ def _sample(self, shape=None): | |||||
def _pdf(self, x): | ||||||
return uniform.pdf(x, loc=self._low, scale=self._high - self._low) | ||||||
|
||||||
def _cdf_untransformed_untruncated(self, x): | ||||||
return uniform.cdf(x, loc=self._low, scale=self._high - self._low) | ||||||
|
||||||
def _ppf_untransformed_untruncated(self, q): | ||||||
return uniform.ppf(q, loc=self._low, scale=self._high - self._low) | ||||||
|
||||||
|
||||||
class Laplace(Distribution): | ||||||
"""A (log-)Laplace distribution. | ||||||
|
||||||
:param loc: The location parameter of the distribution. | ||||||
:param scale: The scale parameter of the distribution. | ||||||
:param truncation: The truncation limits of the distribution. | ||||||
:param trunc: The truncation limits of the distribution. | ||||||
``None`` if the distribution is not truncated. The truncation limits | ||||||
are the truncation limits of the transformed distribution. | ||||||
:param log: If ``True``, the distribution is transformed to a log-Laplace | ||||||
distribution. If a float, the distribution is transformed to a | ||||||
log-Laplace distribution with the given base. | ||||||
|
@@ -211,18 +341,15 @@ def __init__( | |||||
self, | ||||||
loc: float, | ||||||
scale: float, | ||||||
truncation: tuple[float, float] | None = None, | ||||||
trunc: tuple[float, float] | None = None, | ||||||
log: bool | float = False, | ||||||
): | ||||||
super().__init__(log=log) | ||||||
self._loc = loc | ||||||
self._scale = scale | ||||||
self._truncation = truncation | ||||||
if truncation is not None: | ||||||
raise NotImplementedError("Truncation is not yet implemented.") | ||||||
super().__init__(log=log, trunc=trunc) | ||||||
|
||||||
def __repr__(self): | ||||||
trunc = f", truncation={self._truncation}" if self._truncation else "" | ||||||
trunc = f", trunc={self._trunc}" if self._trunc else "" | ||||||
log = f", log={self._logbase}" if self._logbase else "" | ||||||
return f"Laplace(loc={self._loc}, scale={self._scale}{trunc}{log})" | ||||||
|
||||||
|
@@ -232,6 +359,12 @@ def _sample(self, shape=None): | |||||
def _pdf(self, x): | ||||||
return laplace.pdf(x, loc=self._loc, scale=self._scale) | ||||||
|
||||||
def _cdf_untransformed_untruncated(self, x): | ||||||
return laplace.cdf(x, loc=self._loc, scale=self._scale) | ||||||
|
||||||
def _ppf_untransformed_untruncated(self, q): | ||||||
return laplace.ppf(q, loc=self._loc, scale=self._scale) | ||||||
|
||||||
@property | ||||||
def loc(self): | ||||||
"""The location parameter of the underlying distribution.""" | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Move this to some v1 subfolder? Now or later is fine. But I think priors will change a lot in v2
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was thinking about moving it to https://github.com/PEtab-dev/PEtab/ at some point. It might also be helpful for non-python petab users.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good!