Skip to content

Commit

Permalink
step-by-step gaussian component plotter
Browse files Browse the repository at this point in the history
  • Loading branch information
loreloc committed Apr 10, 2024
1 parent 38191a2 commit 4c62014
Showing 1 changed file with 66 additions and 34 deletions.
100 changes: 66 additions & 34 deletions src/scripts/plots/ring/pdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,21 @@
parser.add_argument('path', default='checkpoints', type=str, help="The checkpoints path")
parser.add_argument('--show-ellipses', default=False, action='store_true',
help="Whether to show the Gaussian components as ellipses")
parser.add_argument('--prog-ellipses', default=False, action='store_true',
help="Whether to plot ellipses progressively")
parser.add_argument('--title', default=False, action='store_true', help="Whether to show a title")
parser.add_argument('--vertical-title', default=False, action='store_true',
help="Whether to show the title vertically")
parser.add_argument('--dpi', type=int, default=192, help="The DPI for PNG rasterization")
parser.add_argument('--prune', default=False, action='store_true',
help="Whether to prune components having weight close to zero")


def format_model_name(m: str, num_components: int) -> str:
if m == 'MonotonicPC':
return f"GMM ($K \!\! = \!\! {num_components}$)"
elif m == 'BornPC':
return f"NGMM ($K \!\! = \!\! {num_components}$)"
return f"SGMM ($K \!\! = \!\! {num_components}$)"
return m


Expand Down Expand Up @@ -57,7 +64,7 @@ def load_pdf(
return np.load(filepath)


def plot_mixture_ellipses(mixture: TensorizedPC, ax: plt.Axes):
def plot_mixture_ellipses(mixture: TensorizedPC, ax: plt.Axes, max_num_components: Optional[int] = None):
mus = mixture.input_layer.mu[0, :, 0, :].detach().numpy()
covs = np.exp(2 * mixture.input_layer.log_sigma[0, :, 0, :].detach().numpy())
num_components = mus.shape[-1]
Expand All @@ -67,16 +74,26 @@ def plot_mixture_ellipses(mixture: TensorizedPC, ax: plt.Axes):
mix_weights = mix_weights / np.max(mix_weights)
else:
mix_weights = -mix_weights / np.max(np.abs(mix_weights))
for i in range(num_components):
if np.abs(mix_weights[i]) < 0.1:
if max_num_components is None:
ncomps = list(range(num_components))
else:
assert max_num_components <= num_components
sort_ord = np.argsort(np.arctan2(mus[0], mus[1]))
ncomps = np.arange(num_components)
ncomps = ncomps[sort_ord][:max_num_components].tolist()
for i in ncomps:
if args.prune and np.abs(mix_weights[i]) < 0.1:
continue
mu = mus[:, i]
cov = np.diag(covs[:, i])
v, w = np.linalg.eigh(cov)
v = 2.0 * np.sqrt(2.0) * np.sqrt(v)
#alpha = 1.0 if i < max_num_components else 0.0
ell = mpl.patches.Ellipse(mu, v[0], v[1], linewidth=0.8, fill=False)
ell_dot = mpl.patches.Circle(mu, radius=0.02, fill=True)
ell.set_color('#E53935')
#ell.set_alpha(alpha)
#ell_dot.set_alpha(alpha)
if isinstance(mixture, MonotonicPC):
#ell.set_alpha(mix_weights[i])
#ell_dot.set_alpha(0.5 * mix_weights[i])
Expand Down Expand Up @@ -127,17 +144,20 @@ def plot_pdf(

if __name__ == '__main__':
args = parser.parse_args()
assert not args.prog_ellipses or (args.show_ellipses and args.prog_ellipses)

models = [
'MonotonicPC',
'MonotonicPC',
'MonotonicPC',
'BornPC'
]

num_components = [2, 16, 2]
learning_rates = [5e-3, 5e-3, 1e-3]
num_components = [1, 2, 16, 2]
learning_rates = [5e-3, 5e-3, 5e-3, 1e-3]

exp_id_formats = [
'RGran_R1_K{}_D1_Lcp_OAdam_LR{}_BS{}_IU',
'RGran_R1_K{}_D1_Lcp_OAdam_LR{}_BS{}_IU',
'RGran_R1_K{}_D1_Lcp_OAdam_LR{}_BS{}_IU',
'RGran_R1_K{}_D1_Lcp_OAdam_LR{}_BS{}_IN'
Expand All @@ -163,34 +183,46 @@ def plot_pdf(

metadata, _ = setup_data_loaders('ring', 'datasets', 1, num_samples=10000)

os.makedirs(os.path.join('figures', 'gaussian-ring'), exist_ok=True)
data_pdfs = [(None, truth_pdf, 'Ground Truth', -1)] + list(zip(mixtures, pdfs, models, num_components))
for idx, (p, pdf, m, nc) in enumerate(data_pdfs):
setup_tueplots(1, 1, rel_width=0.2, hw_ratio=1.0)
fig, ax = plt.subplots(1, 1)
if args.title:
title = f"{format_model_name(m, nc)}" if p is not None else m
if args.prog_ellipses:
plot_settings = [{'max_num_components': i + 1} for i in range(nc)]
else:
title = None

plot_pdf(pdf, metadata, ax=ax, vmin=vmin, vmax=vmax)
if p is not None and args.show_ellipses:
plot_mixture_ellipses(p, ax=ax)

x_lim = metadata['domains'][0]
y_lim = metadata['domains'][1]
x_lim = (x_lim[0] * np.sqrt(2.0), x_lim[1] * np.sqrt(2.0))
y_lim = (y_lim[0] * np.sqrt(2.0), y_lim[1] * np.sqrt(2.0))
#lims = (min(x_lim[0], y_lim[0]), max(x_lim[1], y_lim[1]))

ax.set_xlim(*x_lim)
ax.set_ylim(*y_lim)
ax.set_xticks([])
ax.set_yticks([])
ax.set_aspect(1.0)
if args.title:
#ax.set_title(title, rotation='vertical', x=-0.1, y=0.41, va='center')
ax.set_title(title, y=-0.275)

filename = f'pdfs-ellipses-{idx}.png' if args.show_ellipses else f'pdfs-{idx}.png'
plt.savefig(os.path.join('figures', 'gaussian-ring', filename), dpi=1200)
plot_settings = [{'max_num_components': None}]
for j, ps in enumerate(plot_settings):
setup_tueplots(1, 1, rel_width=0.2, hw_ratio=1.0)
fig, ax = plt.subplots(1, 1)
if args.title:
title = f"{format_model_name(m, nc)}" if p is not None else m
else:
title = None

plot_pdf(pdf, metadata, ax=ax, vmin=vmin, vmax=vmax)
if p is not None and args.show_ellipses:
plot_mixture_ellipses(p, ax=ax, **ps)

x_lim = metadata['domains'][0]
y_lim = metadata['domains'][1]
x_lim = (x_lim[0] * np.sqrt(2.0), x_lim[1] * np.sqrt(2.0))
y_lim = (y_lim[0] * np.sqrt(2.0), y_lim[1] * np.sqrt(2.0))
#lims = (min(x_lim[0], y_lim[0]), max(x_lim[1], y_lim[1]))

ax.set_xlim(*x_lim)
ax.set_ylim(*y_lim)
ax.set_xticks([])
ax.set_yticks([])
ax.set_aspect(1.0)
if args.title:
if args.vertical_title:
ax.set_title(title, rotation='vertical', x=-0.1, y=0.41, va='center')
else:
ax.set_title(title, y=-0.275)

if args.prog_ellipses:
filename = f'pdfs-ellipses-{idx}-{j}.png' if args.show_ellipses else f'pdfs-{idx}-{j}.png'
subdir = 'progressive'
else:
filename = f'pdfs-ellipses-{idx}.png' if args.show_ellipses else f'pdfs-{idx}.png'
subdir = 'plain'
os.makedirs(os.path.join('figures', 'gaussian-ring', subdir), exist_ok=True)
plt.savefig(os.path.join('figures', 'gaussian-ring', subdir, filename), dpi=args.dpi)

0 comments on commit 4c62014

Please sign in to comment.