Skip to content

Commit

Permalink
Add CNF
Browse files Browse the repository at this point in the history
  • Loading branch information
atong01 committed Nov 17, 2024
1 parent 62c44af commit beeacb4
Showing 1 changed file with 197 additions and 0 deletions.
197 changes: 197 additions & 0 deletions examples/2D_tutorials/Maximum_likelihood_CNF_tutorial.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "fb2b2856",
"metadata": {},
"source": [
"# Continuous Normalizing Flow tutorial: training ODE generative models using maximum likelihood"
]
},
{
"cell_type": "markdown",
"id": "5e98247e-44ac-4ab9-ad1e-ff058dbb6215",
"metadata": {},
"source": [
"This implements a [continuous normalizing flow (CNF)](https://arxiv.org/abs/1806.07366) trained using maximum likelihood.\n",
"\n",
"To compute the likelihood of a sample $x_1$ we use the instantaneous change of variables formula integrated over time that is we have\n",
"\n",
"$$\n",
"\\begin{pmatrix}\n",
"\\partial x_t / \\partial t \\\\\n",
"\\partial \\log p(x_t) / \\partial t\n",
"\\end{pmatrix} = \n",
" \\begin{pmatrix}\n",
"f(t, x_t)\\\\\n",
"-\\text{tr}(\\partial f / \\partial x_t)\n",
"\\end{pmatrix} \n",
"$$\n",
"\n",
"which is implemented as a $d+1$ dimensional system. There are two common ways to calculate $\\partial \\log p(x_t) / \\partial t$.\n",
"* Exact calcuation of the trace of the Jacobian with essentially $D$ calls of $f$.\n",
"* Hutchinson trace estimator either with a normal distribution or Rademacher distribution. Which uses\n",
" $$\n",
" \\text{tr}(\\partial f / \\partial x_t) = \\mathbb{E}_{\\epsilon} \\left [ \\epsilon^T [\\partial f / \\partial x_t] \\epsilon \\right ]\n",
" $$\n",
" and can be used with a single call to $f$. $\\epsilon$ must be distributed such that $\\mathbb{E}(\\epsilon) = 0$ and $\\text{Cov}(\\epsilon) = I$. Most often Gaussian or Rademacher distributions are used, and are both implemented here.\n",
"\n",
"As compared to flow matching methods, this requires a calculation of the trace of the Hessian and backpropagation through time so is signficantly slower and more numerically unstable to train.\n",
"\n",
"Note: Requires a version of torch with `vmap` and `torch.func.jacrev`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2035a615",
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"import torch\n",
"from torchdyn.core import NeuralODE\n",
"from torchcfm.models import MLP\n",
"from torchcfm.utils import sample_moons, plot_trajectories\n",
"from torch.distributions import MultivariateNormal"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b40e404c-ef9a-4242-8c47-ca8e7a197271",
"metadata": {},
"outputs": [],
"source": [
"class torch_wrapper(torch.nn.Module):\n",
" \"\"\"Wraps model to torchdyn compatible format.\"\"\"\n",
"\n",
" def __init__(self, model):\n",
" super().__init__()\n",
" self.model = model\n",
"\n",
" def forward(self, t, x, *args, **kwargs):\n",
" return self.model(torch.cat([x, t.repeat(x.shape[0])[:, None]], 1))\n",
"\n",
"def exact_div_fn(u):\n",
" \"\"\"Accepts a function u:R^D -> R^D.\"\"\"\n",
" J = torch.func.jacrev(u)\n",
" return lambda x, *args: torch.trace(J(x))\n",
" \n",
"def div_fn_hutch_trace(u):\n",
" def div_fn(x, eps):\n",
" _, vjpfunc = torch.func.vjp(u, x)\n",
" return (vjpfunc(eps)[0] * eps).sum()\n",
" return div_fn\n",
" \n",
"class cnf_wrapper(torch.nn.Module):\n",
" \"\"\"Wraps model to a torchdyn compatible CNF format.\n",
" Appends an additional dimension representing the change in likelihood\n",
" over time.\n",
" \"\"\"\n",
"\n",
" def __init__(self, model, likelihood_estimator=\"exact\"):\n",
" super().__init__()\n",
" self.model = model\n",
" self.div_fn, self.eps_fn = self.get_div_and_eps(likelihood_estimator)\n",
"\n",
" def get_div_and_eps(self, likelihood_estimator):\n",
" if likelihood_estimator == \"exact\":\n",
" return exact_div_fn, None\n",
" if likelihood_estimator == 'hutch_gaussian':\n",
" return div_fn_hutch_trace, torch.randn_like\n",
" if likelihood_estimator == 'hutch_rademacher':\n",
" eps_fn = lambda x: torch.randint_like(x, low=0, high=2).float() * 2 - 1.\n",
" return div_fn_hutch_trace, eps_fn\n",
" raise NotImplementedError(f\"likelihood estimator {likelihood_estimator} is not implemented\")\n",
" \n",
" def forward(self, t, x, *args, **kwargs):\n",
" t = t.squeeze()\n",
" x = x[..., :-1]\n",
" def vecfield(y):\n",
" return self.model(torch.cat([y, t[None]]))\n",
" if self.eps_fn is None:\n",
" div = torch.vmap(self.div_fn(vecfield))(x)\n",
" else:\n",
" div = torch.vmap(self.div_fn(vecfield))(x, self.eps_fn(x))\n",
" dx = self.model(torch.cat([x, t.repeat(x.shape[0])[:, None]], 1))\n",
" return torch.cat([dx, div[:,None]], dim=-1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "faf18883",
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"dim = 2\n",
"batch_size = 256\n",
"model = MLP(dim=dim, time_varying=True).to(device)\n",
"prior = MultivariateNormal(torch.zeros(dim, device=device), torch.eye(dim, device=device))\n",
"optimizer = torch.optim.Adam(model.parameters())\n",
"steps = 100\n",
"cnf = NeuralODE(\n",
" cnf_wrapper(model, likelihood_estimator=\"exact\"), solver=\"euler\", sensitivity=\"adjoint\"\n",
")\n",
"node = NeuralODE(\n",
" torch_wrapper(model), solver=\"euler\", sensitivity=\"adjoint\"\n",
")\n",
"\n",
"start = time.time()\n",
"for k in range(2000):\n",
" optimizer.zero_grad()\n",
" x1 = sample_moons(batch_size).to(device)\n",
" x1_with_ll = torch.cat([x1, torch.zeros(batch_size, 1, device=device)], dim=-1)\n",
" x0_with_ll = cnf.trajectory(x1_with_ll, t_span=torch.linspace(1, 0, steps + 1, device=device))[-1]\n",
" logprob = prior.log_prob(x0_with_ll[..., :-1]) + x0_with_ll[...,-1]\n",
" loss = -torch.mean(logprob)\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" if (k + 1) % 200 == 0:\n",
" end = time.time()\n",
" print(f\"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}\")\n",
" start = end\n",
"\n",
" with torch.no_grad():\n",
" traj = node.trajectory(\n",
" torch.randn(1024, 2, device=device),\n",
" t_span=torch.linspace(0, 1, steps + 1, device=device),\n",
" )\n",
" plot_trajectories(traj.cpu().numpy())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e1cbe018-696c-4157-b4ac-29fa50180f91",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit beeacb4

Please sign in to comment.