corrected looph documentation (#189)
bwpriest authored Sep 13, 2023
1 parent 635f18c commit 1e6cd0e
Showing 2 changed files with 76 additions and 56 deletions.
21 changes: 12 additions & 9 deletions MuyGPyS/optimize/loss.py
@@ -93,7 +93,7 @@ def mse_fn(
 function computes
 .. math::
-l(f(x), y \\mid \\sigma) = \\frac{1}{b} \\sum_{i=1}^b (f(x_i) - y)^2
+\\ell_\\textrm{MSE}(f(x), y) = \\frac{1}{b} \\sum_{i=1}^b (f(x_i) - y)^2
 Args:
 predictions:
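
As a quick check on the corrected formula, the MSE reduces to a numpy one-liner. The sketch below is illustrative only; the names `mse_sketch`, `predictions`, and `targets` are ours, not the library's:

import numpy as np

def mse_sketch(predictions: np.ndarray, targets: np.ndarray) -> float:
    # (1 / b) * sum_i (f(x_i) - y_i)^2, as in the corrected docstring above
    return float(np.mean((predictions - targets) ** 2))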
@@ -121,8 +121,9 @@ def lool_fn(
 penalty. The function computes
 .. math::
-l(f(x), y \\mid \\sigma) = \\sum_{i=1}^b \\sum_{j=1}^s
-\\frac{(f(x_i) - y)^2}{\\sigma_j} + \\log \\sigma_j
+\\ell_\\textrm{lool}(f(x), y \\mid \\sigma^2) =
+\\sum_{i=1}^b \\sum_{j=1}^s
+\\left ( \\frac{(f(x_i) - y)}{\\sigma_j} \\right )^2 + \\log \\sigma_j^2
 Args:
 predictions:
@@ -154,8 +155,8 @@ def lool_fn_unscaled(
 function computes
 .. math::
-l(f(x), y \\mid \\sigma) = \\sum_{i=1}^b
-\\frac{(f(x_i) - y)^2}{\\sigma} + \\log \\sigma
+\\ell_\\textrm{lool}(f(x), y \\mid \\sigma^2) = \\sum_{i=1}^b
+\\left ( \\frac{(f(x_i) - y)}{\\sigma_i} \\right )^2 + \\log \\sigma_i^2
 Args:
 predictions:
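
Both lool variants above share one structure: a residual standardized by the variance, squared, plus a log-variance penalty. A minimal numpy sketch of the unscaled form, assuming `variances[i]` holds the posterior variance sigma_i^2 (names are illustrative, not the library implementation):

import numpy as np

def lool_sketch(predictions, targets, variances):
    # sum_i ((f(x_i) - y_i) / sigma_i)^2 + log(sigma_i^2),
    # where variances[i] = sigma_i^2 per the corrected docstring
    residuals = predictions - targets
    return float(np.sum(residuals ** 2 / variances + np.log(variances)))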
@@ -186,7 +187,8 @@ def pseudo_huber_fn(
 The function computes
 .. math::
-l(f(x), y \\mid \\delta) = \\delta^2 \\sum_{i=1}^b \\left ( \\sqrt{
+\\ell_\\textrm{Pseudo-Huber}(f(x), y \\mid \\delta) =
+\\sum_{i=1}^b \\delta^2 \\left ( \\sqrt{
 1 + \\left ( \\frac{y_i - f(x_i)}{\\delta} \\right )^2
 } - 1 \\right )
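
The corrected form only moves the delta^2 factor inside the sum, which leaves the value unchanged. A hedged numpy sketch of the expression (illustrative names; the boundary_scale default here is our choice, not the library's):

import numpy as np

def pseudo_huber_sketch(predictions, targets, boundary_scale=1.5):
    # delta^2 * (sqrt(1 + (r / delta)^2) - 1) summed over residuals r;
    # ~ r^2 / 2 when |r| << delta and ~ delta * |r| when |r| >> delta
    scaled = (targets - predictions) / boundary_scale
    return float(boundary_scale ** 2 * np.sum(np.sqrt(1.0 + scaled ** 2) - 1.0))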
@@ -222,9 +224,10 @@ def looph_fn(
 variance. The function computes
 .. math::
-l(f(x), y \\mid \\delta) = \\delta^2 \\sum_{i=1}^b \\left ( \\sqrt{
-1 + \\left ( \\frac{y_i - f(x_i)}{\\sigma_i \\delta} \\right )^2
-} - 1 \\right ) + \\log \\sigma_i
+\\ell_\\textrm{looph}(f(x), y \\mid \\delta, \\sigma^2) =
+\\sum_{i=1}^b \\delta^2 \\left ( \\sqrt{
+1 + \\left ( \\frac{y_i - f(x_i)}{\\delta \\sigma_i^2} \\right )^2
+} - 1 \\right ) + \\log \\sigma_i^2
 Args:
 predictions:
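
The substance of the fix is visible here: the residual is now standardized by delta * sigma_i^2 and the penalty is log(sigma_i^2) rather than log(sigma_i). A numpy sketch of the corrected expression, assuming `variances[i]` holds sigma_i^2 (names illustrative, not the library implementation):

import numpy as np

def looph_sketch(predictions, targets, variances, boundary_scale=3.0):
    # sum_i delta^2 * (sqrt(1 + (r_i / (delta * sigma_i^2))^2) - 1) + log(sigma_i^2)
    residuals = targets - predictions
    scaled = residuals / (boundary_scale * variances)
    huber = boundary_scale ** 2 * (np.sqrt(1.0 + scaled ** 2) - 1.0)
    return float(np.sum(huber + np.log(variances)))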
111 changes: 64 additions & 47 deletions docs/examples/loss_tutorial.ipynb
@@ -53,6 +53,15 @@
 "from MuyGPyS._src.optimize.loss.numpy import _looph_fn_unscaled"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"plt.style.use('tableau-colorblind10')"
+]
+},
 {
 "cell_type": "code",
 "execution_count": null,
@@ -62,7 +71,7 @@
 "mmax = 3.0\n",
 "point_count = 50\n",
 "ys = np.zeros(point_count)\n",
-"mus = np.linspace(-mmax, mmax, point_count)\n",
+"residuals = np.linspace(-mmax, mmax, point_count)\n",
 "smax = 3.0\n",
 "smin = 1e-1\n",
 "sigma_count = 50\n",
@@ -75,9 +84,9 @@
 "source": [
 "## Variance-free Loss Functions\n",
 "\n",
-"`MuyGPyS` features several loss functions that depend only upon $y$ and $\\mu$ of your training batch.\n",
+"`MuyGPyS` features several loss functions that depend only upon the targets $y$ and posterior mean predictions $\\mu$ of your training batch.\n",
 "These loss functions are situationally useful, although they leave the fitting of variance parameters entirely up to the separate, analytic `sigma_sq` optimization function and might not be sensitive to certain variance parameters.\n",
-"As they do not require evaluating $\\sigma$ or optimizing the variance scaling parameter, these loss functions are generally more efficient to use in practice."
+"As they do not require evaluating the posterior variance $\\sigma^2$ or optimizing the variance scaling parameter, these loss functions are generally more efficient to use in practice."
 ]
 },
 {
@@ -89,7 +98,7 @@
 "The mean squared error (MSE) is a classic loss that computes\n",
 "\n",
 "\\begin{equation*}\n",
-"\\frac{1}{b} \\sum_{i \\in B} (\\mu_i - y)^2.\n",
+"\\ell_\\textrm{MSE}(\\mu, y) = \\frac{1}{b} \\sum_{i \\in B} (\\mu_i - y_i)^2.\n",
 "\\end{equation*}\n",
 "\n",
 "The string used to indicate MSE loss in optimization functions is `\"mse\"`.\n",
@@ -106,8 +115,8 @@
 "ax.set_title(\"MSE as a function of the residual\", fontsize=20)\n",
 "ax.set_ylabel(\"loss\", fontsize=15)\n",
 "ax.set_xlabel(\"$\\mu_i - y_i$\", fontsize=15)\n",
-"mses = [mse_fn(ys[i].reshape(1, 1), mus[i].reshape(1, 1)) for i in range(point_count)]\n",
-"ax.plot(mus, mses)\n",
+"mses = [mse_fn(ys[i].reshape(1, 1), residuals[i].reshape(1, 1)) for i in range(point_count)]\n",
+"ax.plot(residuals, mses)\n",
 "plt.show()"
 ]
 },
@@ -134,14 +143,15 @@
 "The pseudo-Huber loss computes\n",
 "\n",
 "\\begin{equation*}\n",
-"\\delta^2 \\sum_{i=1}^b \\left ( \n",
-" \\sqrt{1 + \\left ( \\frac{\\mu_i - y_i}{\\delta} \\right )^2} - 1\n",
+"\\ell_\\textrm{Pseudo-Huber}(\\mu, y \\mid \\delta) =\n",
+"\\sum_{i=1}^b \\delta^2 \\left ( \n",
+"\\sqrt{1 + \\left ( \\frac{\\mu_i - y_i}{\\delta} \\right )^2} - 1\n",
 "\\right ),\n",
 "\\end{equation*}\n",
 "\n",
 "where $\\delta$ is a parameter that indicates the scale of the boundary between the quadratic and linear parts of the function.\n",
 "The `pseudo_huber_fn` accepts this parameter as the `boundary_scale` keyword argument.\n",
-"Note that the scale of $\\delta$ depends on the units of $y$ and $mu$.\n",
+"Note that the scale of $\\delta$ depends on the units of $y$ and $\\mu$.\n",
 "The following plots show the behavior of the pseudo-Huber loss for a few values of $\\delta$."
 ]
 },
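
The quadratic-to-linear transition described above can also be checked numerically. A small sketch (ours, independent of the notebook's plotting code): for residuals well inside $\delta$ the loss is close to $r^2 / 2$, and well outside it grows nearly linearly.

import numpy as np

def ph(r, delta):
    # pseudo-Huber contribution of a single residual r
    return delta ** 2 * (np.sqrt(1.0 + (r / delta) ** 2) - 1.0)

print(ph(0.1, 1.0), 0.1 ** 2 / 2)  # ~0.00499 vs 0.005: quadratic regime
print(ph(10.0, 1.0))               # ~9.05, close to delta * |r| = 10: near-linear regime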
@@ -151,17 +161,20 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"boundary_scales = [0.5, 1.5, 2.5]\n",
+"boundary_scales = [0.5, 1.0, 2.5]\n",
 "phs = np.array([\n",
-" [pseudo_huber_fn(ys[i].reshape(1, 1), mus[i].reshape(1, 1), boundary_scale=bs) for i in range(point_count)]\n",
+" [pseudo_huber_fn(ys[i].reshape(1, 1), residuals[i].reshape(1, 1), boundary_scale=bs) for i in range(point_count)]\n",
 " for bs in boundary_scales\n",
 "])\n",
-"fig, axes = plt.subplots(1, 3, figsize=(14, 3))\n",
-"for i, ax in enumerate(axes):\n",
-" ax.set_title(f\"Pseudo-Huber with $\\delta$={boundary_scales[i]}\", fontsize=20)\n",
-" ax.set_ylabel(\"loss\", fontsize=15)\n",
-" ax.set_xlabel(\"$\\mu_i - y_i$\", fontsize=15)\n",
-" ax.plot(mus, phs[i, :])\n",
+"fig, ax = plt.subplots(1, 1, figsize=(4, 3))\n",
+"ax.set_title(\"Pseudo-Huber\", fontsize=20)\n",
+"ax.set_ylabel(\"loss\", fontsize=15)\n",
+"ax.set_xlabel(\"$\\mu_i - y_i$\", fontsize=15)\n",
+"ax.plot(residuals, phs[0, :], linestyle=\"solid\", label=f\"$\\delta = {boundary_scales[0]}$\")\n",
+"ax.plot(residuals, phs[1, :], linestyle=\"dotted\", label=f\"$\\delta = {boundary_scales[1]}$\")\n",
+"ax.plot(residuals, phs[2, :], linestyle=\"dashed\", label=f\"$\\delta = {boundary_scales[2]}$\")\n",
+"ax.legend()\n",
 "plt.show()"
 ]
 },
@@ -171,7 +184,7 @@
 "source": [
 "## Variance-Sensitive Loss Functions\n",
 "\n",
-"`MuyGPyS` also includes loss functions that explicitly depend upon the posterior variances ($\\sigma$).\n",
+"`MuyGPyS` also includes loss functions that explicitly depend upon the posterior variances $\\sigma^2$.\n",
 "These loss functions penalize large variances, and so tend to be more sensitive to variance parameters.\n",
 "This comes at the cost of increasing the linear algebra involved in each evaluation of the objective function by a constant factor.\n",
 "This causes an overall increase in compute time per optimization loop, but that is often worth the trade for sensitivity in practice.\n",
@@ -189,7 +202,8 @@
 "lool computes \n",
 "\n",
 "\\begin{equation*}\n",
-"\\sum_{i \\in B} \\frac{(\\mu_i - y_i)^2}{\\sigma_i} + \\log \\sigma_i.\n",
+"\\ell_\\textrm{lool}(\\mu, y \\mid \\sigma^2) = \n",
+"\\sum_{i \\in B} \\left ( \\frac{\\mu_i - y_i}{\\sigma_i} \\right )^2 + \\log \\sigma_i^2.\n",
 "\\end{equation*}\n",
 "\n",
 "The next plot illustrates the loss as a function of both the residual and of $\\sigma$."
@@ -205,7 +219,7 @@
 " [\n",
 " lool_fn_unscaled(\n",
 " ys[i].reshape(1, 1),\n",
-" mus[i].reshape(1, 1),\n",
+" residuals[i].reshape(1, 1),\n",
 " sigmas[sigma_count - 1 - j]\n",
 " )\n",
 " for i in range(point_count)\n",
@@ -220,29 +234,30 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"fig, axes = plt.subplots(1, 4, figsize=(19,4))\n",
+"fig, axes = plt.subplots(1, 3, figsize=(14, 4))\n",
 "axes[0].set_title(\"lool\", fontsize=20)\n",
 "axes[0].set_ylabel(\"$\\sigma_i$\", fontsize=15)\n",
 "axes[0].set_xlabel(\"$\\mu_i - y_i$\", fontsize=15)\n",
 "im = axes[0].imshow(\n",
-" lools, extent=[-mmax, mmax, smin, smax], norm=SymLogNorm(0.5), cmap=\"coolwarm\"\n",
+" lools, extent=[-mmax, mmax, smin, smax], norm=SymLogNorm(1e-1), cmap=\"coolwarm\", aspect=2.0\n",
 ")\n",
 "fig.colorbar(im, ax=axes[0])\n",
 "\n",
-"axes[1].set_title(\"lool, $\\sigma_i=0.5$\", fontsize=20)\n",
+"axes[1].set_title(\"lool residual cross-section\", fontsize=14)\n",
 "axes[1].set_ylabel(\"lool\", fontsize=15)\n",
 "axes[1].set_xlabel(\"$\\mu_i - y_i$\", fontsize=15)\n",
-"axes[1].plot(mus, lools[7, :])\n",
+"axes[1].plot(residuals, lools[15, :], linestyle=\"solid\", label=\"$\\sigma_i = 1.0$\")\n",
+"axes[1].plot(residuals, lools[7, :], linestyle=\"dotted\", label=\"$\\sigma_i = 0.5$\")\n",
+"axes[1].plot(residuals, lools[0, :], linestyle=\"dashed\", label=\"$\\sigma_i = 0.1$\")\n",
+"axes[1].legend()\n",
 "\n",
-"axes[2].set_title(\"lool, $\\mid \\mu_i - y_i \\mid=1.0$\", fontsize=20)\n",
+"axes[2].set_title(\"lool variance cross-section\", fontsize=14)\n",
 "axes[2].set_ylabel(\"lool\", fontsize=15)\n",
 "axes[2].set_xlabel(\"$\\sigma_i$\", fontsize=15)\n",
-"axes[2].plot(sigmas, np.flip(lools[:, 33]))\n",
-"\n",
-"axes[3].set_title(\"lool, $\\mid \\mu_i - y_i \\mid=0.0$\", fontsize=20)\n",
-"axes[3].set_ylabel(\"lool\", fontsize=15)\n",
-"axes[3].set_xlabel(\"$\\sigma_i$\", fontsize=15)\n",
-"axes[3].plot(sigmas, np.flip(lools[:, 24]))\n",
+"axes[2].plot(sigmas, np.flip(lools[:, 33]), linestyle=\"solid\", label=\"$\\mid \\mu_i - y_i \\mid = 1.0$\")\n",
+"axes[2].plot(sigmas, np.flip(lools[:, 29]), linestyle=\"dotted\", label=\"$\\mid \\mu_i - y_i \\mid = 0.5$\")\n",
+"axes[2].plot(sigmas, np.flip(lools[:, 24]), linestyle=\"dashed\", label=\"$\\mid \\mu_i - y_i \\mid = 0.0$\")\n",
+"axes[2].legend()\n",
 "\n",
 "plt.tight_layout()\n",
 "plt.show()"
@@ -266,9 +281,10 @@
 "looph computes\n",
 "\n",
 "\\begin{equation*}\n",
-"\\delta^2 \\sum_{i=1}^b \\left ( \n",
-" \\sqrt{1 + \\left ( \\frac{\\mu_i - y_i}{\\sigma_i \\delta} \\right )^2} - 1\n",
-"\\right ) + \\log \\sigma_i,\n",
+"\\ell_\\textrm{looph}(\\mu, y \\mid \\delta, \\sigma^2) =\n",
+"\\sum_{i=1}^b \\delta^2 \\left ( \n",
+"\\sqrt{1 + \\left ( \\frac{\\mu_i - y_i}{\\delta \\sigma_i^2} \\right )^2} - 1\n",
+"\\right ) + \\log \\sigma_i^2,\n",
 "\\end{equation*}\n",
 "\n",
 "where again $\\delta$ is the boundary scale.\n",
@@ -286,7 +302,7 @@
 " [\n",
 " _looph_fn_unscaled(\n",
 " ys[i].reshape(1, 1),\n",
-" mus[i].reshape(1, 1),\n",
+" residuals[i].reshape(1, 1),\n",
 " sigmas[sigma_count - 1 - j],\n",
 " boundary_scale=bs\n",
 " )\n",
@@ -304,30 +320,31 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"fig, axes = plt.subplots(3, 4, figsize=(19,14))\n",
+"fig, axes = plt.subplots(3, 3, figsize=(14,12))\n",
 "for i, bs in enumerate(boundary_scales):\n",
-" axes[i, 0].set_title(f\"looph, $\\delta={bs}$\", fontsize=20)\n",
+" axes[i, 0].set_title(f\"looph ($\\delta={bs}$)\", fontsize=20)\n",
 " axes[i, 0].set_ylabel(\"$\\sigma_i$\", fontsize=15)\n",
 " axes[i, 0].set_xlabel(\"$\\mu_i - y_i$\", fontsize=15)\n",
 " im = axes[i, 0].imshow(\n",
-" lools, extent=[-mmax, mmax, smin, smax], norm=SymLogNorm(0.5), cmap=\"coolwarm\"\n",
+" loophs[i, :, :], extent=[-mmax, mmax, smin, smax], norm=SymLogNorm(1e-1), cmap=\"coolwarm\", aspect=2.0\n",
 " )\n",
 " fig.colorbar(im, ax=axes[i, 0])\n",
 "\n",
-" axes[i, 1].set_title(f\"looph, $\\delta={bs}$, $\\sigma_i=0.5$\", fontsize=20)\n",
+" axes[i, 1].set_title(f\"looph residual cross-section ($\\delta={bs}$)\", fontsize=14)\n",
 " axes[i, 1].set_ylabel(\"looph\", fontsize=15)\n",
 " axes[i, 1].set_xlabel(\"$\\mu_i - y_i$\", fontsize=15)\n",
-" axes[i, 1].plot(mus, loophs[i, 7, :])\n",
+" axes[i, 1].plot(residuals, loophs[i, 15, :], linestyle=\"solid\", label=\"$\\sigma_i = 1.0$\")\n",
+" axes[i, 1].plot(residuals, loophs[i, 7, :], linestyle=\"dotted\", label=\"$\\sigma_i = 0.5$\")\n",
+" axes[i, 1].plot(residuals, loophs[i, 0, :], linestyle=\"dashed\", label=\"$\\sigma_i = 0.1$\")\n",
+" axes[i, 1].legend()\n",
 "\n",
-" axes[i, 2].set_title(f\"looph, $\\delta={bs}$, $\\mid \\mu_i - y_i \\mid=1.0$\", fontsize=20)\n",
+" axes[i, 2].set_title(f\"looph variance cross-section ($\\delta={bs}$)\", fontsize=14)\n",
 " axes[i, 2].set_ylabel(\"looph\", fontsize=15)\n",
 " axes[i, 2].set_xlabel(\"$\\sigma_i$\", fontsize=15)\n",
-" axes[i, 2].plot(sigmas, np.flip(loophs[i, :, 33]))\n",
-"\n",
-" axes[i, 3].set_title(f\"looph, $\\delta={bs}$, $\\mid \\mu_i - y_i \\mid=0.0$\", fontsize=20)\n",
-" axes[i, 3].set_ylabel(\"looph\", fontsize=15)\n",
-" axes[i, 3].set_xlabel(\"$\\sigma_i$\", fontsize=15)\n",
-" axes[i, 3].plot(sigmas, np.flip(loophs[i, :, 24]))\n",
+" axes[i, 2].plot(sigmas, np.flip(loophs[i, :, 33]), linestyle=\"solid\", label=\"$\\mid \\mu_i - y_i \\mid = 1.0$\")\n",
+" axes[i, 2].plot(sigmas, np.flip(loophs[i, :, 29]), linestyle=\"dotted\", label=\"$\\mid \\mu_i - y_i \\mid = 0.5$\")\n",
+" axes[i, 2].plot(sigmas, np.flip(loophs[i, :, 24]), linestyle=\"dashed\", label=\"$\\mid \\mu_i - y_i \\mid = 0.0$\")\n",
+" axes[i, 2].legend()\n",
 "\n",
 "plt.tight_layout()\n",
 "plt.show()"