From 6072dda338009e8715532110e8f9e5f2a3fd5750 Mon Sep 17 00:00:00 2001 From: Jarrid Rector-Brooks Date: Tue, 28 Jan 2025 12:22:37 -0500 Subject: [PATCH] Make a small change to readme, make sure full_batch W2 doesn't crash things with tons of samples, and fix some logging --- README.md | 5 ++++- dem/energies/multi_double_well_energy.py | 19 +++++++++++++++++-- dem/models/dem_module.py | 11 ++++++----- 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index d38135d..4944c23 100644 --- a/README.md +++ b/README.md @@ -83,7 +83,10 @@ python dem/eval.py experiment=lj55_idem ckpt_path= ``` This will take some time to run and will generate a file named `samples_.pt` in the hydra -runtime directory for the eval run. We can now use these samples to train a CFM model. We provide a config `lj55_idem_cfm` +runtime directory for the eval run. The eval run will also log keys `test/full_batch/*` and `test/*` to wandb. +For GMM, DW4 and LJ13 you can refer to the `test/2-Wasserstein` and `test/dist_total_var` keys to reproduce our paper +numbers while for LJ55 refer to the `test/full_batch/2-Wasserstein` and `test/full_batch/dist_total_var` keys. +We can now use these samples to train a CFM model. We provide a config `lj55_idem_cfm` which has the settings to enable the CFM pipeline to run by default for the LJ55 task, though doing so for other tasks is also simple. The main config changes required are to set `model.debug_use_train_data=true, model.nll_with_cfm=true` and `model.logz_with_cfm=true`. To point the CFM training run to the dataset generated from iDEM samples we can set the diff --git a/dem/energies/multi_double_well_energy.py b/dem/energies/multi_double_well_energy.py index d6bb7da..e9e287a 100644 --- a/dem/energies/multi_double_well_energy.py +++ b/dem/energies/multi_double_well_energy.py @@ -1,3 +1,4 @@ +from io import BytesIO from typing import Optional import matplotlib.pyplot as plt @@ -239,5 +240,19 @@ def get_dataset_fig(self, samples): axs[1].set_xlabel("Energy") axs[1].legend() - fig.canvas.draw() - return PIL.Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb()) + try: + buffer = BytesIO() + fig.savefig(buffer, format="png", bbox_inches="tight", pad_inches=0) + buffer.seek(0) + + return PIL.Image.open(buffer) + + except Exception as e: + fig.canvas.draw() + return PIL.Image.frombytes( + "RGB", fig.canvas.get_width_height(), fig.canvas.renderer.buffer_rgba() + ) + fig.canvas.draw() + return PIL.Image.frombytes( + "RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb() + ) diff --git a/dem/models/dem_module.py b/dem/models/dem_module.py index df21388..202e5d4 100644 --- a/dem/models/dem_module.py +++ b/dem/models/dem_module.py @@ -985,10 +985,10 @@ def on_test_epoch_end(self) -> None: wandb_logger = get_wandb_logger(self.loggers) self.eval_epoch_end("test") - # self._log_energy_w2(prefix="test") - # if self.energy_function.is_molecule: - # self._log_dist_w2(prefix="test") - # self._log_dist_total_var(prefix="test") + self._log_energy_w2(prefix="test") + if self.energy_function.is_molecule: + self._log_dist_w2(prefix="test") + self._log_dist_total_var(prefix="test") if self.nll_with_cfm: self._cfm_test_epoch_end() @@ -1020,8 +1020,9 @@ def on_test_epoch_end(self) -> None: final_samples = torch.cat(final_samples, dim=0) print("Computing large batch distribution distances") + idx = torch.randperm(len(final_samples))[:10000] names, dists = compute_full_dataset_distribution_distances( - self.energy_function.unnormalize(final_samples)[:, None], + self.energy_function.unnormalize(final_samples)[idx, None], test_set[:, None], self.energy_function, )