Skip to content

Commit

Permalink
Implement proper truncation for prior distributions
Browse files Browse the repository at this point in the history
Currently, when sampled startpoints are outside the bounds, their value is set to the upper/lower bounds. This may put too much probability mass on the bounds.

With these changes, we properly sample from the respective truncated distributions.
  • Loading branch information
dweindl committed Dec 9, 2024
1 parent a272255 commit 90946f3
Show file tree
Hide file tree
Showing 2 changed files with 238 additions and 78 deletions.
49 changes: 33 additions & 16 deletions doc/example/distributions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
" if ax is None:\n",
" fig, ax = plt.subplots()\n",
"\n",
" sample = distr.sample(10000)\n",
" sample = distr.sample(20_000)\n",
"\n",
" # pdf\n",
" xmin = min(sample.min(), distr.lb_scaled if distr.bounds is not None else sample.min())\n",
Expand Down Expand Up @@ -102,9 +102,8 @@
"cell_type": "code",
"source": [
"plot(Normal(10, 2, transformation=LIN))\n",
"plot(Normal(10, 2, transformation=LOG))\n",
"# Note that the log-normal is different from the log-transformed normal distribution:\n",
"plot(LogNormal(10, 2, transformation=LIN))"
"# Note that log-transformed normal distribution is different from the log-normal distribution\n",
"plot(Normal(10, 2, transformation=LOG))"
],
"id": "f6192c226f179ef9",
"outputs": [],
Expand All @@ -120,11 +119,13 @@
"metadata": {},
"cell_type": "code",
"source": [
"plot(Uniform(0, 1, transformation=LOG10))\n",
"plot(ParameterScaleUniform(0, 1, transformation=LOG10))\n",
"# different, because transformation!=LIN\n",
"plot(Uniform(1e-16, 1, transformation=LOG10))\n",
"plot(ParameterScaleUniform(1e-16, 1, transformation=LOG10))\n",
"\n",
"plot(Uniform(0, 1, transformation=LIN))\n",
"plot(ParameterScaleUniform(0, 1, transformation=LIN))\n"
"# same, because transformation=LIN\n",
"plot(Uniform(1e-16, 1, transformation=LIN))\n",
"plot(ParameterScaleUniform(1e-16, 1, transformation=LIN))\n"
],
"id": "5ca940bc24312fc6",
"outputs": [],
Expand All @@ -133,15 +134,18 @@
{
"metadata": {},
"cell_type": "markdown",
"source": "To prevent the sampled parameters from exceeding the bounds, the sampled parameters are clipped to the bounds. The bounds are defined in the parameter table. Note that the current implementation does not support sampling from a truncated distribution. Instead, the samples are clipped to the bounds. This may introduce unwanted bias, and thus, should only be used with caution (i.e., the bounds should be chosen wide enough):",
"source": "The given distributions are truncated at the bounds defined in the parameter table:",
"id": "b1a8b17d765db826"
},
{
"metadata": {},
"cell_type": "code",
"source": [
"plot(Normal(0, 1, bounds=(-4, 4))) # negligible clipping-bias at 4 sigma\n",
"plot(Uniform(0, 1, bounds=(0.1, 0.9))) # significant clipping-bias"
"plot(Normal(0, 1, bounds=(-2, 2)))\n",
"plot(Uniform(0, 1, bounds=(0.1, 0.9)))\n",
"plot(Uniform(1e-8, 1, bounds=(0.1, 0.9), transformation=LOG10))\n",
"plot(Laplace(0, 1, bounds=(-0.5, 0.5)))\n",
"plot(ParameterScaleUniform(-3, 1, bounds=(1e-2, 1), transformation=LOG10))\n"
],
"id": "4ac42b1eed759bdd",
"outputs": [],
Expand All @@ -156,21 +160,34 @@
{
"metadata": {},
"cell_type": "code",
"source": [
"plot(Normal(10, 1, bounds=(6, 14), transformation=\"log10\"))\n",
"plot(ParameterScaleNormal(10, 1, bounds=(10**6, 10**14), transformation=\"log10\"))\n"
],
"source": "plot(Normal(10, 1, bounds=(6, 14), transformation=\"log10\"))",
"id": "581e1ac431860419",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": "",
"source": "plot(ParameterScaleNormal(10, 1, bounds=(10**6, 10**14), transformation=\"log10\"))",
"id": "99202ecb47706a68",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": "plot(LogLaplace(1, 0.5, bounds=(0.5, 8)))",
"id": "802a64be56a6c94f",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": "plot(LogNormal(2, 1, bounds=(0.5, 8)))",
"id": "7820e93ab9b2fb47",
"outputs": [],
"execution_count": null
}
],
"metadata": {
Expand Down
Loading

0 comments on commit 90946f3

Please sign in to comment.