diff --git a/src/graphics/utils.py b/src/graphics/utils.py index 65aaa8c..c126998 100644 --- a/src/graphics/utils.py +++ b/src/graphics/utils.py @@ -13,9 +13,13 @@ def setup_tueplots( rel_width: float = 1.0, hw_ratio: Optional[float] = None, inc_font_size: int = 0, + use_tex: bool = True, **kwargs ): - font_config = fonts.iclr2023_tex(family='serif') + if use_tex: + font_config = fonts.iclr2023_tex(family='serif') + else: + font_config = fonts.iclr2023(family='serif') if hw_ratio is not None: kwargs['height_to_width_ratio'] = hw_ratio size = figsizes.iclr2023(rel_width=rel_width, nrows=nrows, ncols=ncols, **kwargs) diff --git a/src/scripts/plots/ring/ellipses.py b/src/scripts/plots/ring/ellipses.py deleted file mode 100644 index afa58d6..0000000 --- a/src/scripts/plots/ring/ellipses.py +++ /dev/null @@ -1,216 +0,0 @@ -import argparse -import os.path -from typing import Optional - -import matplotlib as mpl -import numpy as np -from scipy import special -import torch -from matplotlib import pyplot as plt -from sklearn.preprocessing import StandardScaler - -from datasets.loaders import load_artificial_dataset -from graphics.distributions import kde_samples_hmap -from graphics.utils import setup_tueplots -from pcs.models import TensorizedPC, PC, MonotonicPC -from scripts.utils import setup_model, setup_data_loaders - -parser = argparse.ArgumentParser( - description="PDFs plotter" -) -parser.add_argument('--checkpoint-path', default='checkpoints', type=str, help="The checkpoints path") -parser.add_argument('--title', default=False, action='store_true', help="Whether to show a title") - - -def ring_kde() -> np.ndarray: - splits = load_artificial_dataset('ring', num_samples=500, dtype=np.dtype(np.float64)) - data = np.concatenate(splits, axis=0) - scaler = StandardScaler() - data = scaler.fit_transform(data) - data_min, data_max = np.min(data, axis=0), np.max(data, axis=0) - #drange = np.abs(data_max - data_min) - #data_min, data_max = (data_min - drange * 0.05), (data_max + drange * 0.05) - xlim, ylim = [(data_min[i], data_max[i]) for i in range(len(data_min))] - return kde_samples_hmap(data, xlim=xlim, ylim=ylim, bandwidth=0.16) - - -def format_model_name(m: str, num_components: int) -> str: - if m == 'MonotonicPC': - return f"GMM ($K \! = \! {num_components}$)" - elif m == 'BornPC': - return f"NGMM ($K \! = \! {num_components}$)" - return m - - -def load_mixture( - model_name: str, - exp_id_fmt: str, - num_components: int, - learning_rate: float = 5e-3, - batch_size: int = 64 -) -> TensorizedPC: - metadata, _ = setup_data_loaders('ring', 'datasets', 1, num_samples=10000) - model: TensorizedPC = setup_model(model_name, metadata, num_components=num_components) - exp_id = exp_id_fmt.format(num_components, learning_rate, batch_size) - filepath = os.path.join(args.checkpoint_path, 'ring', model_name, exp_id, 'model.pt') - state_dict = torch.load(filepath, map_location='cpu') - model.load_state_dict(state_dict['weights']) - return model - - -def load_pdf( - model: str, - exp_id_fmt: str, - num_components, - learning_rate: float = 5e-3, - batch_size: int = 64 -) -> np.ndarray: - exp_id = exp_id_fmt.format(num_components, learning_rate, batch_size) - filepath = os.path.join(args.checkpoint_path, 'ring', model, exp_id, 'distbest.npy') - return np.load(filepath) - - -def plot_mixture_ellipses(mixture: TensorizedPC, ax: plt.Axes): - mus = mixture.input_layer.mu[0, :, 0, :].detach().numpy() - covs = np.exp(2 * mixture.input_layer.log_sigma[0, :, 0, :].detach().numpy()) - num_components = mus.shape[-1] - mix_weights = mixture.layers[-1].weight[0, 0].detach().numpy() - if isinstance(mixture, MonotonicPC): - mix_weights = special.softmax(mix_weights) - mix_weights = mix_weights / np.max(mix_weights) - else: - # assert num_components == 2 - # mix_weights = np.array([mix_weights[0] ** 2, mix_weights[1] ** 2, 2.0 * mix_weights[0] * mix_weights[1]]) - # # Products of Gaussian pdfs - # new_covs = np.array([ - # covs[:, 0] / 2.0, - # covs[:, 1] / 2.0, - # (covs[:, 0] * covs[:, 1]) / (covs[:, 0] + covs[:, 1]) - # ]).T # New covariances - # new_mus = np.array([ - # mus[:, 0], - # mus[:, 1], - # new_covs[:, 2] * (mus[:, 0] / covs[:, 0] + mus[:, 1] / covs[:, 1]) - # ]).T # New means - # num_components = 3 - # mus = new_mus - # covs = new_covs - mix_weights = -mix_weights / np.max(np.abs(mix_weights)) - for i in range(num_components): - mu = mus[:, i] - cov = np.diag(covs[:, i]) - v, w = np.linalg.eigh(cov) - v = 2.0 * np.sqrt(2.0) * np.sqrt(v) - ell = mpl.patches.Ellipse(mu, v[0], v[1], linewidth=0.8, fill=False) - ell_dot = mpl.patches.Circle(mu, radius=0.03, fill=True) - if isinstance(mixture, MonotonicPC): - ell.set_alpha(mix_weights[i]) - ell.set_color('red') - ell_dot.set_alpha(0.5 * mix_weights[i]) - ell_dot.set_color('red') - else: - if mix_weights[i] <= 0.0: - ell.set_alpha(min(1.0, 3 * np.abs(mix_weights[i]))) - #ell.set_color('#E53935') - ell.set_linestyle('dotted') - ell.set_color('red') - ell_dot.set_alpha(0.5 * np.abs(mix_weights[i])) - #ell_dot.set_color('#E53935') - ell_dot.set_color('red') - else: - ell.set_alpha(mix_weights[i]) - ell.set_color('red') - ell_dot.set_alpha(0.5 * mix_weights[i]) - ell_dot.set_color('red') - ax.add_artist(ell) - ax.add_artist(ell_dot) - - -def plot_pdf( - pdf: np.ndarray, - metadata: dict, - ax: plt.Axes, vmin: - Optional[float] = None, - vmax: Optional[float] = None -): - x_lim = metadata['domains'][0] - y_lim = metadata['domains'][1] - x_lim = (x_lim[0] * np.sqrt(2.0), x_lim[1] * np.sqrt(2.0)) - y_lim = (y_lim[0] * np.sqrt(2.0), y_lim[1] * np.sqrt(2.0)) - - x_lim = (min(x_lim[0], y_lim[0]), max(x_lim[1], y_lim[1])) - y_lim = (min(x_lim[0], y_lim[0]), max(x_lim[1], y_lim[1])) - - xi, yi = np.mgrid[range(pdf.shape[0]), range(pdf.shape[1])] - xi = (xi + 0.5) / pdf.shape[0] - yi = (yi + 0.5) / pdf.shape[1] - xi = xi * (x_lim[1] - x_lim[0]) + x_lim[0] - yi = yi * (y_lim[1] - y_lim[0]) + y_lim[0] - ax.pcolormesh(xi, yi, pdf, vmin=vmin, vmax=vmax) - - -if __name__ == '__main__': - args = parser.parse_args() - - models = [ - 'MonotonicPC', - 'MonotonicPC', - 'BornPC' - ] - - num_components = [2, 16, 2] - learning_rates = [5e-3, 5e-3, 1e-3] - - exp_id_formats = [ - 'RGran_R1_K{}_D1_Lcp_OAdam_LR{}_BS{}_IU', - 'RGran_R1_K{}_D1_Lcp_OAdam_LR{}_BS{}_IU', - 'RGran_R1_K{}_D1_Lcp_OAdam_LR{}_BS{}_IN' - ] - - truth_pdf = ring_kde() - - mixtures = [ - load_mixture(m, eif, nc, lr) - for m, eif, nc, lr in zip(models, exp_id_formats, num_components, learning_rates) - ] - - pdfs = [ - load_pdf(m, eif, nc, lr) - for m, eif, nc, lr in zip(models, exp_id_formats, num_components, learning_rates) - ] - vmax = np.max([truth_pdf] + pdfs) - vmin = 0.0 - - metadata, _ = setup_data_loaders('ring', 'datasets', 1, num_samples=10000) - - os.makedirs(os.path.join('figures', 'gaussian-ring'), exist_ok=True) - data_pdfs = [(None, truth_pdf, 'Ground Truth', -1)] + list(zip(mixtures, pdfs, models, num_components)) - for idx, (p, pdf, m, nc) in enumerate(data_pdfs): - setup_tueplots(1, 1, rel_width=0.2, hw_ratio=1.0) - fig, ax = plt.subplots(1, 1) - if args.title: - title = f"{format_model_name(m, nc)}" if p is not None else m - else: - title = None - - plot_pdf(pdf, metadata, ax=ax, vmin=vmin, vmax=vmax) - if p is not None: - plot_mixture_ellipses(p, ax=ax) - - x_lim = metadata['domains'][0] - y_lim = metadata['domains'][1] - x_lim = (x_lim[0] * np.sqrt(2.0), x_lim[1] * np.sqrt(2.0)) - y_lim = (y_lim[0] * np.sqrt(2.0), y_lim[1] * np.sqrt(2.0)) - x_lim = (min(x_lim[0], y_lim[0]), max(x_lim[1], y_lim[1])) - y_lim = (min(x_lim[0], y_lim[0]), max(x_lim[1], y_lim[1])) - - ax.set_xlim(*x_lim) - ax.set_ylim(*y_lim) - ax.set_xticks([]) - ax.set_yticks([]) - ax.set_aspect(1.0) - - if args.title: - ax.set_title(title, rotation='vertical', x=-0.1, y=0.41, va='center') - - plt.savefig(os.path.join('figures', 'gaussian-ring', f'pdfs-ellipses-{idx}.png'), dpi=1200) diff --git a/src/scripts/plots/ring/pdfs.py b/src/scripts/plots/ring/pdfs.py index 218d900..1de829d 100644 --- a/src/scripts/plots/ring/pdfs.py +++ b/src/scripts/plots/ring/pdfs.py @@ -77,27 +77,28 @@ def plot_mixture_ellipses(mixture: TensorizedPC, ax: plt.Axes): cov = np.diag(covs[:, i]) v, w = np.linalg.eigh(cov) v = 2.0 * np.sqrt(2.0) * np.sqrt(v) - ell = mpl.patches.Ellipse(mu, v[0], v[1], linewidth=0.7, fill=False) - ell_dot = mpl.patches.Circle(mu, radius=0.03, fill=True) - ell.set_color('red') + ell = mpl.patches.Ellipse(mu, v[0], v[1], linewidth=0.8, fill=False) + ell_dot = mpl.patches.Circle(mu, radius=0.02, fill=True) + ell.set_color('#E53935') if isinstance(mixture, MonotonicPC): #ell.set_alpha(mix_weights[i]) #ell_dot.set_alpha(0.5 * mix_weights[i]) - ell_dot.set_color('red') - ell.set_alpha(0.775) - ell_dot.set_alpha(0.775) + ell_dot.set_color('#E53935') + # ell.set_alpha(0.775) + # ell_dot.set_alpha(0.775) else: if mix_weights[i] <= 0.0: #ell.set_alpha(min(1.0, 3 * np.abs(mix_weights[i]))) ell.set_linestyle('dotted') + ell.set_linewidth(1.5) #ell_dot.set_alpha(0.5 * np.abs(mix_weights[i])) - #%ell_dot.set_color('red') + ell_dot.set_color('#E53935') else: #ell.set_alpha(mix_weights[i]) #ell_dot.set_alpha(0.5 * mix_weights[i]) - ell_dot.set_color('red') - ell.set_alpha(0.85) - ell_dot.set_alpha(0.85) + ell_dot.set_color('#E53935') + # ell.set_alpha(0.85) + # ell_dot.set_alpha(0.85) ax.add_artist(ell) ax.add_artist(ell_dot) diff --git a/src/scripts/plots/wdist.py b/src/scripts/plots/wdist.py index 7201f55..b1616ff 100644 --- a/src/scripts/plots/wdist.py +++ b/src/scripts/plots/wdist.py @@ -49,22 +49,30 @@ path = setup_experiment_path( args.path, args.dataset, args.model, args.exp_alias, trial_id=build_run_id(args)) sd = torch.load(os.path.join(path, 'model.pt'), map_location='cpu')['weights'] + print(sd.keys()) # Concatenate weights in a large vector ws = list() for k in sd.keys(): - if 'layer' in k and 'weight' in k: - ws.append(sd[k].flatten().numpy()) + # Select the parameters of CP layers only + if 'layer' in k and 'weight' in k and 'input' not in k and 'mixture' not in k: + w = sd[k] + if 'Born' in args.model: # Perform squaring + if len(w.shape) == 3: # CP layer + w = torch.einsum('fki,fkj->fkij', w, w) + else: + assert False, "This should not happen :(" + ws.append(w.flatten().numpy()) ws = np.concatenate(ws, axis=0) # Preprocess the weights, and set some flags if 'Mono' in args.model: - mb = np.quantile(ws, q=[0.9999]) + mb = np.quantile(ws, q=[0.99], method='lower') ws = ws[ws <= mb] ws = np.exp(ws) hcol = 'C0' elif 'Born' in args.model: - ma, mb = np.quantile(ws, q=[0.0005, 0.9995]) + ma, mb = np.quantile(ws, q=[0.005, 0.995], method='lower') ws = ws[(ws >= ma) & (ws <= mb)] hcol = 'C1' print(ws.shape) @@ -72,7 +80,7 @@ # Compute and plot the instogram setup_tueplots(1, 1, rel_width=0.25, hw_ratio=1.0) hlabel = f'{format_model(args.model)}' - plt.hist(ws, density=True, bins=64, color=hcol, label=hlabel) + plt.hist(ws, bins=64, color=hcol, label=hlabel) plt.yscale('log') if args.legend: plt.legend()