diff --git a/src/scripts/plots/ring/pdfs.py b/src/scripts/plots/ring/pdfs.py index 1f7b5c3..b9a2bd2 100644 --- a/src/scripts/plots/ring/pdfs.py +++ b/src/scripts/plots/ring/pdfs.py @@ -18,14 +18,21 @@ parser.add_argument('path', default='checkpoints', type=str, help="The checkpoints path") parser.add_argument('--show-ellipses', default=False, action='store_true', help="Whether to show the Gaussian components as ellipses") +parser.add_argument('--prog-ellipses', default=False, action='store_true', + help="Whether to plot ellipses progressively") parser.add_argument('--title', default=False, action='store_true', help="Whether to show a title") +parser.add_argument('--vertical-title', default=False, action='store_true', + help="Whether to show the title vertically") +parser.add_argument('--dpi', type=int, default=192, help="The DPI for PNG rasterization") +parser.add_argument('--prune', default=False, action='store_true', + help="Whether to prune components having weight close to zero") def format_model_name(m: str, num_components: int) -> str: if m == 'MonotonicPC': return f"GMM ($K \!\! = \!\! {num_components}$)" elif m == 'BornPC': - return f"NGMM ($K \!\! = \!\! {num_components}$)" + return f"SGMM ($K \!\! = \!\! {num_components}$)" return m @@ -57,7 +64,7 @@ def load_pdf( return np.load(filepath) -def plot_mixture_ellipses(mixture: TensorizedPC, ax: plt.Axes): +def plot_mixture_ellipses(mixture: TensorizedPC, ax: plt.Axes, max_num_components: Optional[int] = None): mus = mixture.input_layer.mu[0, :, 0, :].detach().numpy() covs = np.exp(2 * mixture.input_layer.log_sigma[0, :, 0, :].detach().numpy()) num_components = mus.shape[-1] @@ -67,16 +74,26 @@ def plot_mixture_ellipses(mixture: TensorizedPC, ax: plt.Axes): mix_weights = mix_weights / np.max(mix_weights) else: mix_weights = -mix_weights / np.max(np.abs(mix_weights)) - for i in range(num_components): - if np.abs(mix_weights[i]) < 0.1: + if max_num_components is None: + ncomps = list(range(num_components)) + else: + assert max_num_components <= num_components + sort_ord = np.argsort(np.arctan2(mus[0], mus[1])) + ncomps = np.arange(num_components) + ncomps = ncomps[sort_ord][:max_num_components].tolist() + for i in ncomps: + if args.prune and np.abs(mix_weights[i]) < 0.1: continue mu = mus[:, i] cov = np.diag(covs[:, i]) v, w = np.linalg.eigh(cov) v = 2.0 * np.sqrt(2.0) * np.sqrt(v) + #alpha = 1.0 if i < max_num_components else 0.0 ell = mpl.patches.Ellipse(mu, v[0], v[1], linewidth=0.8, fill=False) ell_dot = mpl.patches.Circle(mu, radius=0.02, fill=True) ell.set_color('#E53935') + #ell.set_alpha(alpha) + #ell_dot.set_alpha(alpha) if isinstance(mixture, MonotonicPC): #ell.set_alpha(mix_weights[i]) #ell_dot.set_alpha(0.5 * mix_weights[i]) @@ -127,17 +144,20 @@ def plot_pdf( if __name__ == '__main__': args = parser.parse_args() + assert not args.prog_ellipses or (args.show_ellipses and args.prog_ellipses) models = [ + 'MonotonicPC', 'MonotonicPC', 'MonotonicPC', 'BornPC' ] - num_components = [2, 16, 2] - learning_rates = [5e-3, 5e-3, 1e-3] + num_components = [1, 2, 16, 2] + learning_rates = [5e-3, 5e-3, 5e-3, 1e-3] exp_id_formats = [ + 'RGran_R1_K{}_D1_Lcp_OAdam_LR{}_BS{}_IU', 'RGran_R1_K{}_D1_Lcp_OAdam_LR{}_BS{}_IU', 'RGran_R1_K{}_D1_Lcp_OAdam_LR{}_BS{}_IU', 'RGran_R1_K{}_D1_Lcp_OAdam_LR{}_BS{}_IN' @@ -163,34 +183,46 @@ def plot_pdf( metadata, _ = setup_data_loaders('ring', 'datasets', 1, num_samples=10000) - os.makedirs(os.path.join('figures', 'gaussian-ring'), exist_ok=True) data_pdfs = [(None, truth_pdf, 'Ground Truth', -1)] + list(zip(mixtures, pdfs, models, num_components)) for idx, (p, pdf, m, nc) in enumerate(data_pdfs): - setup_tueplots(1, 1, rel_width=0.2, hw_ratio=1.0) - fig, ax = plt.subplots(1, 1) - if args.title: - title = f"{format_model_name(m, nc)}" if p is not None else m + if args.prog_ellipses: + plot_settings = [{'max_num_components': i + 1} for i in range(nc)] else: - title = None - - plot_pdf(pdf, metadata, ax=ax, vmin=vmin, vmax=vmax) - if p is not None and args.show_ellipses: - plot_mixture_ellipses(p, ax=ax) - - x_lim = metadata['domains'][0] - y_lim = metadata['domains'][1] - x_lim = (x_lim[0] * np.sqrt(2.0), x_lim[1] * np.sqrt(2.0)) - y_lim = (y_lim[0] * np.sqrt(2.0), y_lim[1] * np.sqrt(2.0)) - #lims = (min(x_lim[0], y_lim[0]), max(x_lim[1], y_lim[1])) - - ax.set_xlim(*x_lim) - ax.set_ylim(*y_lim) - ax.set_xticks([]) - ax.set_yticks([]) - ax.set_aspect(1.0) - if args.title: - #ax.set_title(title, rotation='vertical', x=-0.1, y=0.41, va='center') - ax.set_title(title, y=-0.275) - - filename = f'pdfs-ellipses-{idx}.png' if args.show_ellipses else f'pdfs-{idx}.png' - plt.savefig(os.path.join('figures', 'gaussian-ring', filename), dpi=1200) + plot_settings = [{'max_num_components': None}] + for j, ps in enumerate(plot_settings): + setup_tueplots(1, 1, rel_width=0.2, hw_ratio=1.0) + fig, ax = plt.subplots(1, 1) + if args.title: + title = f"{format_model_name(m, nc)}" if p is not None else m + else: + title = None + + plot_pdf(pdf, metadata, ax=ax, vmin=vmin, vmax=vmax) + if p is not None and args.show_ellipses: + plot_mixture_ellipses(p, ax=ax, **ps) + + x_lim = metadata['domains'][0] + y_lim = metadata['domains'][1] + x_lim = (x_lim[0] * np.sqrt(2.0), x_lim[1] * np.sqrt(2.0)) + y_lim = (y_lim[0] * np.sqrt(2.0), y_lim[1] * np.sqrt(2.0)) + #lims = (min(x_lim[0], y_lim[0]), max(x_lim[1], y_lim[1])) + + ax.set_xlim(*x_lim) + ax.set_ylim(*y_lim) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_aspect(1.0) + if args.title: + if args.vertical_title: + ax.set_title(title, rotation='vertical', x=-0.1, y=0.41, va='center') + else: + ax.set_title(title, y=-0.275) + + if args.prog_ellipses: + filename = f'pdfs-ellipses-{idx}-{j}.png' if args.show_ellipses else f'pdfs-{idx}-{j}.png' + subdir = 'progressive' + else: + filename = f'pdfs-ellipses-{idx}.png' if args.show_ellipses else f'pdfs-{idx}.png' + subdir = 'plain' + os.makedirs(os.path.join('figures', 'gaussian-ring', subdir), exist_ok=True) + plt.savefig(os.path.join('figures', 'gaussian-ring', subdir, filename), dpi=args.dpi)