Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement proper truncation for prior distributions #335

Open
wants to merge 10 commits into
base: develop
Choose a base branch
from
1 change: 1 addition & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
nb_execution_mode = "force"
nb_execution_raise_on_error = True
nb_execution_show_tb = True
nb_execution_timeout = 90 # max. seconds/cell

source_suffix = {
".rst": "restructuredtext",
Expand Down
93 changes: 67 additions & 26 deletions doc/example/distributions.ipynb
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this to some v1 subfolder? Now or later is fine, but I think priors will change a lot in v2.

Copy link
Member Author

@dweindl dweindl Dec 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking about moving it to https://github.com/PEtab-dev/PEtab/ at some point. It might also be helpful for non-Python PEtab users.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good!

Original file line number Diff line number Diff line change
Expand Up @@ -33,39 +33,73 @@
"\n",
"from petab.v1.C import *\n",
"from petab.v1.priors import Prior\n",
"from petab.v1.parameters import scale, unscale\n",
"\n",
"\n",
"sns.set_style(None)\n",
"\n",
"\n",
"def plot(prior: Prior, ax=None):\n",
"def plot(prior: Prior):\n",
" \"\"\"Visualize a distribution.\"\"\"\n",
" fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\n",
" sample = prior.sample(20_000, x_scaled=True)\n",
"\n",
" fig.suptitle(str(prior))\n",
"\n",
" plot_single(prior, ax=ax1, sample=sample, scaled=False)\n",
" plot_single(prior, ax=ax2, sample=sample, scaled=True)\n",
" plt.tight_layout()\n",
" plt.show()\n",
"\n",
"def plot_single(prior: Prior, scaled: bool = False, ax=None, sample: np.array = None):\n",
" fig = None\n",
" if ax is None:\n",
" fig, ax = plt.subplots()\n",
"\n",
" sample = prior.sample(10000)\n",
"\n",
" # pdf\n",
" xmin = min(sample.min(), prior.lb_scaled if prior.bounds is not None else sample.min())\n",
" xmax = max(sample.max(), prior.ub_scaled if prior.bounds is not None else sample.max())\n",
" if sample is None:\n",
" sample = prior.sample(20_000)\n",
"\n",
" # assuming scaled sample\n",
" if not scaled:\n",
" sample = unscale(sample, prior.transformation)\n",
" bounds = prior.bounds\n",
" else:\n",
" bounds = (prior.lb_scaled, prior.ub_scaled) if prior.bounds is not None else None\n",
"\n",
" # plot pdf\n",
" xmin = min(sample.min(), bounds[0] if prior.bounds is not None else sample.min())\n",
" xmax = max(sample.max(), bounds[1] if prior.bounds is not None else sample.max())\n",
" padding = 0.1 * (xmax - xmin)\n",
" xmin -= padding\n",
" xmax += padding\n",
" x = np.linspace(xmin, xmax, 500)\n",
" y = prior.pdf(x)\n",
" y = prior.pdf(x, x_scaled=scaled, rescale=scaled)\n",
" ax.plot(x, y, color='red', label='pdf')\n",
"\n",
" sns.histplot(sample, stat='density', ax=ax, label=\"sample\")\n",
"\n",
" # bounds\n",
" # plot bounds\n",
" if prior.bounds is not None:\n",
" for bound in (prior.lb_scaled, prior.ub_scaled):\n",
" for bound in bounds:\n",
" if bound is not None and np.isfinite(bound):\n",
" ax.axvline(bound, color='black', linestyle='--', label='bound')\n",
"\n",
" ax.set_title(str(prior))\n",
" ax.set_xlabel('Parameter value on the parameter scale')\n",
" if fig is not None:\n",
" ax.set_title(str(prior))\n",
"\n",
" if scaled:\n",
" ax.set_xlabel(f'Parameter value on parameter scale ({prior.transformation})')\n",
" ax.set_ylabel(\"Rescaled density\")\n",
" else:\n",
" ax.set_xlabel('Parameter value')\n",
"\n",
" ax.grid(False)\n",
" handles, labels = ax.get_legend_handles_labels()\n",
" unique_labels = dict(zip(labels, handles))\n",
" ax.legend(unique_labels.values(), unique_labels.keys())\n",
" plt.show()"
"\n",
" if ax is None:\n",
" plt.show()\n"
],
"id": "initial_id",
"outputs": [],
Expand All @@ -81,11 +115,11 @@
"metadata": {},
"cell_type": "code",
"source": [
"plot(Prior(UNIFORM, (0, 1)))\n",
"plot(Prior(NORMAL, (0, 1)))\n",
"plot(Prior(LAPLACE, (0, 1)))\n",
"plot(Prior(LOG_NORMAL, (0, 1)))\n",
"plot(Prior(LOG_LAPLACE, (1, 0.5)))"
"plot_single(Prior(UNIFORM, (0, 1)))\n",
"plot_single(Prior(NORMAL, (0, 1)))\n",
"plot_single(Prior(LAPLACE, (0, 1)))\n",
"plot_single(Prior(LOG_NORMAL, (0, 1)))\n",
"plot_single(Prior(LOG_LAPLACE, (1, 0.5)))"
],
"id": "4f09e50a3db06d9f",
"outputs": [],
Expand All @@ -94,7 +128,7 @@
{
"metadata": {},
"cell_type": "markdown",
"source": "If a parameter scale is specified (`parameterScale=lin|log|log10` not a `parameterScale*`-type distribution), the sample is transformed accordingly (but not the distribution parameters):\n",
"source": "If a parameter scale is specified (`parameterScale=lin|log|log10`) and the chosen distribution is not a `parameterScale*`-type distribution, then the distribution parameters are taken as is, i.e., the `parameterScale` is not applied to the distribution parameters. In the context of PEtab prior distributions, `parameterScale` will only be used for the start point sampling for optimization, where the sample will be transformed accordingly. This is demonstrated below. The left plot always shows the prior distribution for unscaled parameter values, and the right plot shows the prior distribution for scaled parameter values. Note that in the objective function, the prior is always on the unscaled parameters.\n",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"source": "If a parameter scale is specified (`parameterScale=lin|log|log10`) and the chosen distribution is not a `parameterScale*`-type distribution, then the distribution parameters are taken as is, i.e., the `parameterScale` is not applied to the distribution parameters. In the context of PEtab prior distributions, `parameterScale` will only be used for the start point sampling for optimization, where the sample will be transformed accordingly. This is demonstrated below. The left plot always shows the prior distribution for unscaled parameter values, and the right plot shows the prior distribution for scaled parameter values. Note that in the objective function, the prior is always on the unscaled parameters.\n",
"source": "If a parameter scale is specified (`parameterScale=lin|log|log10`), the distribution parameters are used as is without applying the `parameterScale` to them. The exception are the `parameterScale*`-type distributions, as explained below. In the context of PEtab prior distributions, `parameterScale` will only be used for the start point sampling for optimization, where the sample will be transformed accordingly. This is demonstrated below. The left plot always shows the prior distribution for unscaled parameter values, and the right plot shows the prior distribution for scaled parameter values. Note that in the objective function, the prior is always on the unscaled parameters.\n",

Just a suggestion to make it easier to follow. I also found this easier to understand once I realized that `parameterScale*` is explained afterwards.

"id": "dab4b2d1e0f312d8"
},
{
Expand Down Expand Up @@ -131,18 +165,20 @@
{
"metadata": {},
"cell_type": "markdown",
"source": "Prior distributions can also be defined on the parameter scale by using the types `parameterScaleUniform`, `parameterScaleNormal` or `parameterScaleLaplace`. In these cases, 1) the distribution parameter are interpreted on the transformed parameter scale, and 2) a sample from the given distribution is used directly, without applying any transformation according to `parameterScale` (this implies, that for `parameterScale=lin`, there is no difference between `parameterScaleUniform` and `uniform`):",
"source": "Prior distributions can also be defined on the scaled parameters (i.e., transformed according to `parameterScale`) by using the types `parameterScaleUniform`, `parameterScaleNormal` or `parameterScaleLaplace`. In these cases, the distribution parameter are interpreted on the transformed parameter scale (but not the parameter bounds, see below). This implies, that for `parameterScale=lin`, there is no difference between `parameterScaleUniform` and `uniform`.",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"source": "Prior distributions can also be defined on the scaled parameters (i.e., transformed according to `parameterScale`) by using the types `parameterScaleUniform`, `parameterScaleNormal` or `parameterScaleLaplace`. In these cases, the distribution parameter are interpreted on the transformed parameter scale (but not the parameter bounds, see below). This implies, that for `parameterScale=lin`, there is no difference between `parameterScaleUniform` and `uniform`.",
"source": "Prior distributions can also be defined on the scaled parameters (i.e., transformed according to `parameterScale`) by using the types `parameterScaleUniform`, `parameterScaleNormal` or `parameterScaleLaplace`. In these cases, the distribution parameters are interpreted on the transformed parameter scale (but not the parameter bounds, see below). This implies, that for `parameterScale=lin`, there is no difference between `parameterScaleUniform` and `uniform`.",

"id": "263c9fd31156a4d5"
},
{
"metadata": {},
"cell_type": "code",
"source": [
"# different, because transformation!=LIN\n",
"plot(Prior(UNIFORM, (0.01, 2), transformation=LOG10))\n",
"plot(Prior(PARAMETER_SCALE_UNIFORM, (0.01, 2), transformation=LOG10))\n",
"\n",
"# same, because transformation=LIN\n",
"plot(Prior(UNIFORM, (0.01, 2), transformation=LIN))\n",
"plot(Prior(PARAMETER_SCALE_UNIFORM, (0.01, 2), transformation=LIN))\n"
"plot(Prior(PARAMETER_SCALE_UNIFORM, (0.01, 2), transformation=LIN))"
],
"id": "5ca940bc24312fc6",
"outputs": [],
Expand All @@ -151,15 +187,18 @@
{
"metadata": {},
"cell_type": "markdown",
"source": "To prevent the sampled parameters from exceeding the bounds, the sampled parameters are clipped to the bounds. The bounds are defined in the parameter table. Note that the current implementation does not support sampling from a truncated distribution. Instead, the samples are clipped to the bounds. This may introduce unwanted bias, and thus, should only be used with caution (i.e., the bounds should be chosen wide enough):",
"source": "The given distributions are truncated at the bounds defined in the parameter table:",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add something like "Compared to the non-truncated version, truncation rescales the probability density by a constant factor (i.e., a constant shift in the log-density; see https://en.wikipedia.org/wiki/Truncated_distribution), such that the probability density still integrates to 1 over the bounded interval."

"id": "b1a8b17d765db826"
},
{
"metadata": {},
"cell_type": "code",
"source": [
"plot(Prior(NORMAL, (0, 1), bounds=(-4, 4))) # negligible clipping-bias at 4 sigma\n",
"plot(Prior(UNIFORM, (0, 1), bounds=(0.1, 0.9))) # significant clipping-bias"
"plot(Prior(NORMAL, (0, 1), bounds=(-2, 2)))\n",
"plot(Prior(UNIFORM, (0, 1), bounds=(0.1, 0.9)))\n",
"plot(Prior(UNIFORM, (1e-8, 1), bounds=(0.1, 0.9), transformation=LOG10))\n",
"plot(Prior(LAPLACE, (0, 1), bounds=(-0.5, 0.5)))\n",
"plot(Prior(PARAMETER_SCALE_UNIFORM, (-3, 1), bounds=(1e-2, 1), transformation=LOG10))"
],
"id": "4ac42b1eed759bdd",
"outputs": [],
Expand All @@ -175,9 +214,11 @@
"metadata": {},
"cell_type": "code",
"source": [
"plot(Prior(NORMAL, (10, 1), bounds=(6, 14), transformation=\"log10\"))\n",
"plot(Prior(PARAMETER_SCALE_NORMAL, (10, 1), bounds=(10**6, 10**14), transformation=\"log10\"))\n",
"plot(Prior(LAPLACE, (10, 2), bounds=(6, 14)))"
"plot(Prior(NORMAL, (10, 1), bounds=(6, 11), transformation=\"log10\"))\n",
"plot(Prior(PARAMETER_SCALE_NORMAL, (2, 1), bounds=(10**0, 10**3), transformation=\"log10\"))\n",
"plot(Prior(LAPLACE, (10, 2), bounds=(6, 14)))\n",
"plot(Prior(LOG_LAPLACE, (1, 0.5), bounds=(0.5, 8)))\n",
"plot(Prior(LOG_NORMAL, (2, 1), bounds=(0.5, 8)))"
],
"id": "581e1ac431860419",
"outputs": [],
Expand Down
7 changes: 7 additions & 0 deletions petab/v1/C.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,13 @@
PARAMETER_SCALE_LAPLACE,
]

#: parameterScale*-type prior distributions
PARAMETER_SCALE_PRIOR_TYPES = [
PARAMETER_SCALE_UNIFORM,
PARAMETER_SCALE_NORMAL,
PARAMETER_SCALE_LAPLACE,
]

#: Supported noise distributions
NOISE_MODELS = [NORMAL, LAPLACE]

Expand Down
Loading
Loading