197 changes: 197 additions & 0 deletions
examples/2D_tutorials/Maximum_likelihood_CNF_tutorial.ipynb
@@ -0,0 +1,197 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fb2b2856",
   "metadata": {},
   "source": [
    "# Continuous Normalizing Flow tutorial: training ODE generative models using maximum likelihood"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e98247e-44ac-4ab9-ad1e-ff058dbb6215",
   "metadata": {},
   "source": [
"This implements a [continuous normalizing flow (CNF)](https://arxiv.org/abs/1806.07366) trained using maximum likelihood.\n", | ||
"\n", | ||
"To compute the likelihood of a sample $x_1$ we use the instantaneous change of variables formula integrated over time that is we have\n", | ||
"\n", | ||
"$$\n", | ||
"\\begin{pmatrix}\n", | ||
"\\partial x_t / \\partial t \\\\\n", | ||
"\\partial \\log p(x_t) / \\partial t\n", | ||
"\\end{pmatrix} = \n", | ||
" \\begin{pmatrix}\n", | ||
"f(t, x_t)\\\\\n", | ||
"-\\text{tr}(\\partial f / \\partial x_t)\n", | ||
"\\end{pmatrix} \n", | ||
"$$\n", | ||
"\n", | ||
"which is implemented as a $d+1$ dimensional system. There are two common ways to calculate $\\partial \\log p(x_t) / \\partial t$.\n", | ||
"* Exact calcuation of the trace of the Jacobian with essentially $D$ calls of $f$.\n", | ||
"* Hutchinson trace estimator either with a normal distribution or Rademacher distribution. Which uses\n", | ||
" $$\n", | ||
" \\text{tr}(\\partial f / \\partial x_t) = \\mathbb{E}_{\\epsilon} \\left [ \\epsilon^T [\\partial f / \\partial x_t] \\epsilon \\right ]\n", | ||
" $$\n", | ||
" and can be used with a single call to $f$. $\\epsilon$ must be distributed such that $\\mathbb{E}(\\epsilon) = 0$ and $\\text{Cov}(\\epsilon) = I$. Most often Gaussian or Rademacher distributions are used, and are both implemented here.\n", | ||
"\n", | ||
"As compared to flow matching methods, this requires a calculation of the trace of the Hessian and backpropagation through time so is signficantly slower and more numerically unstable to train.\n", | ||
"\n", | ||
"Note: Requires a version of torch with `vmap` and `torch.func.jacrev`." | ||
   ]
  },
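  {
   "cell_type": "markdown",
   "id": "9a7c31d0",
   "metadata": {},
   "source": [
    "Concretely (a short added note): integrating the second component of this system from $t = 1$ down to $t = 0$, starting from $0$, yields the log-likelihood of a data point $x_1$ as\n",
    "\n",
    "$$\n",
    "\\log p_1(x_1) = \\log p_0(x_0) - \\int_0^1 \\text{tr}(\\partial f / \\partial x_t) \\, dt,\n",
    "$$\n",
    "\n",
    "where $p_0$ is the prior. This is exactly what the training loop below does: it appends a zero-initialized log-likelihood channel to each sample, integrates backward in time to the prior, and adds the prior log-probability of the resulting $x_0$."
   ]
  },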
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2035a615",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import torch\n",
    "from torch.distributions import MultivariateNormal\n",
    "from torchdyn.core import NeuralODE\n",
    "\n",
    "from torchcfm.models import MLP\n",
    "from torchcfm.utils import plot_trajectories, sample_moons"
   ]
  },
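  {
   "cell_type": "markdown",
   "id": "0f1e2d3c",
   "metadata": {},
   "source": [
    "A quick sanity check (an added convenience, not part of the original tutorial): confirm this torch build exposes the `torch.func` APIs used below, which ship with PyTorch >= 2.0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a2b3c4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check that the torch.func APIs relied on below are available (PyTorch >= 2.0).\n",
    "for fn in (\"jacrev\", \"vjp\"):\n",
    "    assert hasattr(torch.func, fn), f\"torch.func.{fn} is required (PyTorch >= 2.0)\"\n",
    "assert hasattr(torch, \"vmap\"), \"torch.vmap is required (PyTorch >= 2.0)\"\n",
    "print(torch.__version__)"
   ]
  },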
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b40e404c-ef9a-4242-8c47-ca8e7a197271",
   "metadata": {},
   "outputs": [],
   "source": [
    "class torch_wrapper(torch.nn.Module):\n",
    "    \"\"\"Wraps model to torchdyn compatible format.\"\"\"\n",
    "\n",
    "    def __init__(self, model):\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "\n",
    "    def forward(self, t, x, *args, **kwargs):\n",
    "        return self.model(torch.cat([x, t.repeat(x.shape[0])[:, None]], 1))\n",
    "\n",
    "\n",
    "def exact_div_fn(u):\n",
    "    \"\"\"Accepts a function u: R^D -> R^D; returns its exact divergence (trace of the Jacobian).\"\"\"\n",
    "    J = torch.func.jacrev(u)\n",
    "    return lambda x, *args: torch.trace(J(x))\n",
    "\n",
    "\n",
    "def div_fn_hutch_trace(u):\n",
    "    \"\"\"Hutchinson estimator of the divergence: eps^T [du/dx] eps via a single vjp call.\"\"\"\n",
    "\n",
    "    def div_fn(x, eps):\n",
    "        _, vjpfunc = torch.func.vjp(u, x)\n",
    "        return (vjpfunc(eps)[0] * eps).sum()\n",
    "\n",
    "    return div_fn\n",
    "\n",
    "\n",
    "class cnf_wrapper(torch.nn.Module):\n",
    "    \"\"\"Wraps model to a torchdyn compatible CNF format.\n",
    "\n",
    "    Appends an additional dimension representing the change in log-likelihood over time.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model, likelihood_estimator=\"exact\"):\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "        self.div_fn, self.eps_fn = self.get_div_and_eps(likelihood_estimator)\n",
    "\n",
    "    def get_div_and_eps(self, likelihood_estimator):\n",
    "        if likelihood_estimator == \"exact\":\n",
    "            return exact_div_fn, None\n",
    "        if likelihood_estimator == \"hutch_gaussian\":\n",
    "            return div_fn_hutch_trace, torch.randn_like\n",
    "        if likelihood_estimator == \"hutch_rademacher\":\n",
    "            eps_fn = lambda x: torch.randint_like(x, low=0, high=2).float() * 2 - 1.0\n",
    "            return div_fn_hutch_trace, eps_fn\n",
    "        raise NotImplementedError(f\"likelihood estimator {likelihood_estimator} is not implemented\")\n",
    "\n",
    "    def forward(self, t, x, *args, **kwargs):\n",
    "        t = t.squeeze()\n",
    "        x = x[..., :-1]\n",
    "\n",
    "        def vecfield(y):\n",
    "            # Vector field for a single (unbatched) sample; batching is handled by vmap.\n",
    "            return self.model(torch.cat([y, t[None]]))\n",
    "\n",
    "        if self.eps_fn is None:\n",
    "            div = torch.vmap(self.div_fn(vecfield))(x)\n",
    "        else:\n",
    "            div = torch.vmap(self.div_fn(vecfield))(x, self.eps_fn(x))\n",
    "        dx = self.model(torch.cat([x, t.repeat(x.shape[0])[:, None]], 1))\n",
    "        return torch.cat([dx, div[:, None]], dim=-1)"
   ]
  },
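  {
   "cell_type": "markdown",
   "id": "2b3c4d5e",
   "metadata": {},
   "source": [
    "A quick check of the two divergence implementations (a hedged sketch added for illustration; the linear test map and the number of probe vectors are arbitrary choices, not part of the original tutorial). For a linear map $f(x) = Ax$ the divergence is exactly $\\text{tr}(A)$, so the exact estimator should match it and the averaged Hutchinson estimate should approach it. To train with the stochastic estimator instead, pass `likelihood_estimator=\"hutch_gaussian\"` or `\"hutch_rademacher\"` to `cnf_wrapper` in the training cell below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c4d5e6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check on a linear map u(y) = A y, whose divergence is tr(A) everywhere.\n",
    "torch.manual_seed(0)\n",
    "A = torch.randn(2, 2)\n",
    "u = lambda y: y @ A.T\n",
    "\n",
    "x = torch.randn(2)\n",
    "exact = exact_div_fn(u)(x)\n",
    "# Average single-probe Hutchinson estimates over many Gaussian eps vectors.\n",
    "hutch = torch.stack([div_fn_hutch_trace(u)(x, torch.randn(2)) for _ in range(10000)]).mean()\n",
    "print(f\"tr(A) = {torch.trace(A):.4f}, exact = {exact:.4f}, hutch = {hutch:.4f}\")"
   ]
  },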
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faf18883",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "dim = 2\n",
    "batch_size = 256\n",
    "model = MLP(dim=dim, time_varying=True).to(device)\n",
    "prior = MultivariateNormal(torch.zeros(dim, device=device), torch.eye(dim, device=device))\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "steps = 100\n",
    "cnf = NeuralODE(\n",
    "    cnf_wrapper(model, likelihood_estimator=\"exact\"), solver=\"euler\", sensitivity=\"adjoint\"\n",
    ")\n",
    "node = NeuralODE(torch_wrapper(model), solver=\"euler\", sensitivity=\"adjoint\")\n",
    "\n",
    "start = time.time()\n",
    "for k in range(2000):\n",
    "    optimizer.zero_grad()\n",
    "    x1 = sample_moons(batch_size).to(device)\n",
    "    # Append a zero-initialized log-likelihood channel, then integrate\n",
    "    # backward in time from the data (t=1) to the prior (t=0).\n",
    "    x1_with_ll = torch.cat([x1, torch.zeros(batch_size, 1, device=device)], dim=-1)\n",
    "    x0_with_ll = cnf.trajectory(x1_with_ll, t_span=torch.linspace(1, 0, steps + 1, device=device))[-1]\n",
    "    logprob = prior.log_prob(x0_with_ll[..., :-1]) + x0_with_ll[..., -1]\n",
    "    loss = -torch.mean(logprob)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if (k + 1) % 200 == 0:\n",
    "        end = time.time()\n",
    "        print(f\"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}\")\n",
    "        start = end\n",
    "\n",
    "        # Plot generative trajectories from the prior using the plain NeuralODE.\n",
    "        with torch.no_grad():\n",
    "            traj = node.trajectory(\n",
    "                torch.randn(1024, 2, device=device),\n",
    "                t_span=torch.linspace(0, 1, steps + 1, device=device),\n",
    "            )\n",
    "            plot_trajectories(traj.cpu().numpy())"
   ]
  },
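  {
   "cell_type": "markdown",
   "id": "4d5e6f70",
   "metadata": {},
   "source": [
    "As a follow-up (a sketch added for illustration, assuming the training cell above has run; the evaluation batch size of 1024 is an arbitrary choice), the trained CNF can score fresh samples by integrating the same augmented system, this time without gradients."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e6f7081",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the mean model log-likelihood on fresh moons samples,\n",
    "# reusing the same backward-in-time integration as the training loop.\n",
    "with torch.no_grad():\n",
    "    x1 = sample_moons(1024).to(device)\n",
    "    x1_with_ll = torch.cat([x1, torch.zeros(1024, 1, device=device)], dim=-1)\n",
    "    x0_with_ll = cnf.trajectory(x1_with_ll, t_span=torch.linspace(1, 0, steps + 1, device=device))[-1]\n",
    "    logprob = prior.log_prob(x0_with_ll[..., :-1]) + x0_with_ll[..., -1]\n",
    "    print(f\"mean test log-likelihood: {logprob.mean().item():0.3f}\")"
   ]
  }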
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}