diff --git a/examples/2D_tutorials/Maximum_likelihood_CNF_tutorial.ipynb b/examples/2D_tutorials/Maximum_likelihood_CNF_tutorial.ipynb
new file mode 100644
index 0000000..c861497
--- /dev/null
+++ b/examples/2D_tutorials/Maximum_likelihood_CNF_tutorial.ipynb
@@ -0,0 +1,197 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "fb2b2856",
+   "metadata": {},
+   "source": [
+    "# Continuous Normalizing Flow tutorial: training ODE generative models using maximum likelihood"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5e98247e-44ac-4ab9-ad1e-ff058dbb6215",
+   "metadata": {},
+   "source": [
+    "This notebook implements a [continuous normalizing flow (CNF)](https://arxiv.org/abs/1806.07366) trained by maximum likelihood.\n",
+    "\n",
+    "To compute the likelihood of a sample $x_1$, we integrate the instantaneous change of variables formula over time; that is, we have\n",
+    "\n",
+    "$$\n",
+    "\\begin{pmatrix}\n",
+    "\\partial x_t / \\partial t \\\\\n",
+    "\\partial \\log p(x_t) / \\partial t\n",
+    "\\end{pmatrix} =\n",
+    "\\begin{pmatrix}\n",
+    "f(t, x_t)\\\\\n",
+    "-\\text{tr}(\\partial f / \\partial x_t)\n",
+    "\\end{pmatrix}\n",
+    "$$\n",
+    "\n",
+    "which is implemented as a $(d+1)$-dimensional system. There are two common ways to calculate $\\partial \\log p(x_t) / \\partial t$:\n",
+    "* Exact calculation of the trace of the Jacobian, which requires essentially $d$ calls of $f$.\n",
+    "* The Hutchinson trace estimator, which uses\n",
+    "  $$\n",
+    "  \\text{tr}(\\partial f / \\partial x_t) = \\mathbb{E}_{\\epsilon} \\left [ \\epsilon^T [\\partial f / \\partial x_t] \\epsilon \\right ]\n",
+    "  $$\n",
+    "  and requires only a single call to $f$ per probe. Here $\\epsilon$ must be distributed such that $\\mathbb{E}(\\epsilon) = 0$ and $\\text{Cov}(\\epsilon) = I$; Gaussian and Rademacher distributions are the most common choices, and both are implemented here.\n",
+    "\n",
+    "Compared to flow matching methods, this requires computing the trace of the Jacobian and backpropagating through the ODE solve, so it is significantly slower and less numerically stable to train.\n",
+    "\n",
+    "Note: requires a version of PyTorch with `torch.vmap` and `torch.func.jacrev` (i.e. PyTorch >= 2.0).\n",
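+    "\n",
+    "As a quick sanity check (a minimal sketch, separate from the training code below; the toy field `f` is hypothetical), both estimators can be compared against a known Jacobian trace:\n",
+    "\n",
+    "```python\n",
+    "import torch\n",
+    "\n",
+    "def f(x):  # hypothetical toy field on R^3 with Jacobian diag(1, 2, 3), so the true trace is 6\n",
+    "    return x * torch.tensor([1.0, 2.0, 3.0])\n",
+    "\n",
+    "x = torch.randn(3)\n",
+    "exact = torch.trace(torch.func.jacrev(f)(x))  # materializes the full Jacobian: ~d calls of f\n",
+    "eps = torch.randn_like(x)  # Gaussian probe with E[eps] = 0 and Cov(eps) = I\n",
+    "_, vjp = torch.func.vjp(f, x)  # a single call of f\n",
+    "hutch = (vjp(eps)[0] * eps).sum()  # unbiased; averages to the exact trace over many probes\n",
+    "print(exact.item(), hutch.item())\n",
+    "```"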
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2035a615",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import time\n",
+    "\n",
+    "import torch\n",
+    "from torch.distributions import MultivariateNormal\n",
+    "from torchdyn.core import NeuralODE\n",
+    "\n",
+    "from torchcfm.models import MLP\n",
+    "from torchcfm.utils import plot_trajectories, sample_moons"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b40e404c-ef9a-4242-8c47-ca8e7a197271",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class torch_wrapper(torch.nn.Module):\n",
+    "    \"\"\"Wraps model to torchdyn compatible format.\"\"\"\n",
+    "\n",
+    "    def __init__(self, model):\n",
+    "        super().__init__()\n",
+    "        self.model = model\n",
+    "\n",
+    "    def forward(self, t, x, *args, **kwargs):\n",
+    "        return self.model(torch.cat([x, t.repeat(x.shape[0])[:, None]], 1))\n",
+    "\n",
+    "\n",
+    "def exact_div_fn(u):\n",
+    "    \"\"\"Exact divergence of u: R^d -> R^d as the trace of its Jacobian (~d calls of u).\"\"\"\n",
+    "    J = torch.func.jacrev(u)\n",
+    "    return lambda x, *args: torch.trace(J(x))\n",
+    "\n",
+    "\n",
+    "def div_fn_hutch_trace(u):\n",
+    "    \"\"\"Hutchinson divergence estimate eps^T (du/dx) eps via a single vector-Jacobian product.\"\"\"\n",
+    "\n",
+    "    def div_fn(x, eps):\n",
+    "        _, vjpfunc = torch.func.vjp(u, x)\n",
+    "        return (vjpfunc(eps)[0] * eps).sum()\n",
+    "\n",
+    "    return div_fn\n",
+    "\n",
+    "\n",
+    "class cnf_wrapper(torch.nn.Module):\n",
+    "    \"\"\"Wraps model to a torchdyn compatible CNF format.\n",
+    "\n",
+    "    Appends an additional dimension representing the change in log-likelihood\n",
+    "    over time.\n",
+    "    \"\"\"\n",
+    "\n",
+    "    def __init__(self, model, likelihood_estimator=\"exact\"):\n",
+    "        super().__init__()\n",
+    "        self.model = model\n",
+    "        self.div_fn, self.eps_fn = self.get_div_and_eps(likelihood_estimator)\n",
+    "\n",
+    "    def get_div_and_eps(self, likelihood_estimator):\n",
+    "        if likelihood_estimator == \"exact\":\n",
+    "            return exact_div_fn, None\n",
+    "        if likelihood_estimator == \"hutch_gaussian\":\n",
+    "            return div_fn_hutch_trace, torch.randn_like\n",
+    "        if likelihood_estimator == \"hutch_rademacher\":\n",
+    "            eps_fn = lambda x: torch.randint_like(x, low=0, high=2).float() * 2 - 1.0\n",
+    "            return div_fn_hutch_trace, eps_fn\n",
+    "        raise NotImplementedError(f\"likelihood estimator {likelihood_estimator} is not implemented\")\n",
+    "\n",
+    "    def forward(self, t, x, *args, **kwargs):\n",
+    "        t = t.squeeze()\n",
+    "        # Drop the appended log-likelihood channel; only the state enters the vector field.\n",
+    "        x = x[..., :-1]\n",
+    "\n",
+    "        # Per-sample vector field f(t, y), evaluated under vmap.\n",
+    "        def vecfield(y):\n",
+    "            return self.model(torch.cat([y, t[None]]))\n",
+    "\n",
+    "        if self.eps_fn is None:\n",
+    "            div = torch.vmap(self.div_fn(vecfield))(x)\n",
+    "        else:\n",
+    "            div = torch.vmap(self.div_fn(vecfield))(x, self.eps_fn(x))\n",
+    "        dx = self.model(torch.cat([x, t.repeat(x.shape[0])[:, None]], 1))\n",
+    "        # The extra channel integrates +tr(df/dx); running the solver from t=1 to t=0\n",
+    "        # in the training loop turns this into log p(x_1) = log p_prior(x_0) + accumulated channel.\n",
+    "        return torch.cat([dx, div[:, None]], dim=-1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "faf18883",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%%time\n",
+    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+    "dim = 2\n",
+    "batch_size = 256\n",
+    "model = MLP(dim=dim, time_varying=True).to(device)\n",
+    "prior = MultivariateNormal(torch.zeros(dim, device=device), torch.eye(dim, device=device))\n",
+    "optimizer = torch.optim.Adam(model.parameters())\n",
+    "steps = 100\n",
+    "# The CNF carries the extra log-likelihood channel for training; a plain NeuralODE\n",
+    "# over the same model is used for sampling and plotting.\n",
+    "cnf = NeuralODE(\n",
+    "    cnf_wrapper(model, likelihood_estimator=\"exact\"), solver=\"euler\", sensitivity=\"adjoint\"\n",
+    ")\n",
+    "node = NeuralODE(\n",
+    "    torch_wrapper(model), solver=\"euler\", sensitivity=\"adjoint\"\n",
+    ")\n",
+    "\n",
+    "start = time.time()\n",
+    "for k in range(2000):\n",
+    "    optimizer.zero_grad()\n",
+    "    x1 = sample_moons(batch_size).to(device)\n",
+    "    # Pad each sample with an initial log-likelihood change of zero.\n",
+    "    x1_with_ll = 
torch.cat([x1, torch.zeros(batch_size, 1, device=device)], dim=-1)\n",
+    "    # Integrate backward in time, from the data at t=1 to the prior at t=0.\n",
+    "    x0_with_ll = cnf.trajectory(x1_with_ll, t_span=torch.linspace(1, 0, steps + 1, device=device))[-1]\n",
+    "    # Change of variables: log p(x_1) = log p_prior(x_0) + accumulated divergence channel.\n",
+    "    logprob = prior.log_prob(x0_with_ll[..., :-1]) + x0_with_ll[..., -1]\n",
+    "    loss = -torch.mean(logprob)\n",
+    "    loss.backward()\n",
+    "    optimizer.step()\n",
+    "\n",
+    "    if (k + 1) % 200 == 0:\n",
+    "        end = time.time()\n",
+    "        print(f\"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}\")\n",
+    "        start = end\n",
+    "\n",
+    "        with torch.no_grad():\n",
+    "            traj = node.trajectory(\n",
+    "                torch.randn(1024, 2, device=device),\n",
+    "                t_span=torch.linspace(0, 1, steps + 1, device=device),\n",
+    "            )\n",
+    "            plot_trajectories(traj.cpu().numpy())"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.10"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}