-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathvisual_utils.py
125 lines (110 loc) · 4.82 KB
/
visual_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import time
import numpy as np
from os.path import join
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.neighbors import NearestNeighbors
from scipy.spatial.distance import pdist
# Shared qualitative palette: ten RGBA tuples sampled from matplotlib's
# 'tab10' colormap. The plotting helpers below index into it (or zip over it)
# to give each plotted group its own colour.
cmap = plt.get_cmap('tab10')
COLORS = list(map(cmap, range(10)))
def visualize(feat, labels, title, dir_path, filename):
    """Scatter 2-D features coloured by label 0-9, save the figure as JPEG and refresh it.

    Args:
        feat: array-like indexed as ``feat[mask, 0]`` / ``feat[mask, 1]``; only the
            first two columns are drawn.
        labels: array-like compared element-wise against the integers 0..9 to
            build a boolean row mask for each class.
        title: figure title.
        dir_path: directory the image is written to.
        filename: image base name; ``.jpg`` is appended.
    """
    plt.ion()
    plt.clf()
    for digit in range(10):
        mask = labels == digit
        plt.plot(feat[mask, 0], feat[mask, 1], '.', c=COLORS[digit])
    # One legend entry per class, in the same order the classes were plotted.
    plt.legend([str(digit) for digit in range(10)], loc='upper right')
    plt.title(title)
    plt.savefig(join(dir_path, f"{filename}.jpg"))
    # Give the interactive backend a moment to repaint the window.
    plt.draw()
    plt.pause(0.001)
def visualize_tsne_neighbors(feat, phrases, distance, title, dir_path, filename, centers=None):
    """Embed de-duplicated phrase features with t-SNE and report a few neighbourhoods.

    Each phrase keeps only its first occurrence. The remaining feature rows are
    embedded in 2-D with t-SNE. For every "center" phrase, its nearest neighbours
    in the embedding are written to ``{filename}-reference.txt`` and drawn in
    their own colour; the plot is saved as ``{filename}.jpg``.

    Args:
        feat: sequence of feature rows aligned with ``phrases``; kept rows are
            stacked with ``np.vstack``.
        phrases: sequence of hashable phrases.
        distance: object exposing ``to_sklearn_metric()``; its result is used as
            the ``metric`` for both ``TSNE`` and ``NearestNeighbors``.
        title: plot title.
        dir_path: directory for both output files.
        filename: base name of both output files.
        centers: indices into the de-duplicated phrases whose neighbourhoods are
            reported. Defaults to the historical ``[1061, 999, 782, 2518, 94]``.

    Raises:
        ValueError: if there are no phrases, or a center index is out of range
            for the de-duplicated phrases (checked before running t-SNE).
    """
    import inspect

    # Keep the first feature row seen for each distinct phrase.
    feat_unique, phrases_unique, seen = [], [], set()
    for row, phrase in zip(feat, phrases):
        if phrase in seen:
            continue
        seen.add(phrase)
        feat_unique.append(row)
        phrases_unique.append(phrase)
    if not feat_unique:
        raise ValueError("visualize_tsne_neighbors needs at least one phrase")
    n_unique = len(phrases_unique)

    if centers is None:
        centers = [1061, 999, 782, 2518, 94]
    centers = list(centers)
    # Validate up front so bad indices fail fast instead of after the slow t-SNE run.
    out_of_range = [c for c in centers if not 0 <= c < n_unique]
    if out_of_range:
        raise ValueError(
            f"center indices {out_of_range} out of range for {n_unique} unique phrases"
        )

    feat = np.vstack(feat_unique)
    metric = distance.to_sklearn_metric()
    # scikit-learn 1.5 renamed TSNE's `n_iter` to `max_iter` (the old name is
    # deprecated); pick whichever keyword the installed version accepts.
    iter_kw = 'max_iter' if 'max_iter' in inspect.signature(TSNE).parameters else 'n_iter'

    time_start = time.time()
    tsne = TSNE(n_components=2, verbose=1, perplexity=30, metric=metric, **{iter_kw: 300})
    tsne_results = tsne.fit_transform(feat)
    print(f"t-SNE done! Time elapsed: {time.time() - time_start} seconds")

    plt.ion()
    plt.clf()
    x = tsne_results[:, 0]
    y = tsne_results[:, 1]
    with open(join(dir_path, f"{filename}-reference.txt"), 'w', encoding='utf-8') as reffile:
        # NOTE(review): neighbours are searched in the 2-D t-SNE layout but with the
        # original feature-space metric; t-SNE layouts are Euclidean, so confirm a
        # non-Euclidean metric is intended here.
        nn = NearestNeighbors(n_neighbors=min(5, n_unique), metric=metric)
        nn.fit(tsne_results)
        distances, indices = nn.kneighbors(tsne_results[centers, :])
        for i, (c, nbr_dists, nbr_ids) in enumerate(zip(centers, distances, indices)):
            reffile.write(f"Phrase {c}: {phrases_unique[c]}\nNeighbors:\n")
            for dist, nid in zip(nbr_dists, nbr_ids):
                # The center is normally its own nearest neighbour; skip it.
                if nid != c:
                    reffile.write(f"\tPhrase {nid} at distance {dist}: {phrases_unique[nid]}\n")
            # Cycle the palette so more than len(COLORS) centers do not raise.
            plt.plot(x[nbr_ids], y[nbr_ids], '.', c=COLORS[i % len(COLORS)])
        plt.legend([str(center) for center in centers], loc='upper right')
    plt.axhline(y=0, color='grey', ls=':')
    plt.axvline(x=0, color='grey', ls=':')
    plt.title(title)
    plt.savefig(join(dir_path, f"{filename}.jpg"))
    plt.draw()
    plt.pause(0.001)
def visualize_tsne_speaker(feat, y, unique_labels, distance, title, dir_path, filename):
    """t-SNE scatter of embeddings grouped by label, with distance stats in the legend.

    Each legend entry shows the mean and standard deviation of the pairwise
    distances between that label's embeddings in the ORIGINAL feature space.
    The points are drawn at their 2-D t-SNE positions.

    Args:
        feat: 2-D array of embeddings; rows are selected with ``y == label``.
        y: array of per-row labels aligned with ``feat``.
        unique_labels: labels to draw, in legend order. Colours cycle through
            ``COLORS`` when there are more labels than colours.
        distance: object exposing ``to_sklearn_metric()``; its result is used as
            the t-SNE metric and as the ``scipy.spatial.distance.pdist`` metric,
            so it must be a name both libraries accept.
        title: plot title.
        dir_path: directory the image is written to.
        filename: image base name; ``.jpg`` is appended.
    """
    import inspect

    metric = distance.to_sklearn_metric()
    # scikit-learn 1.5 renamed TSNE's `n_iter` to `max_iter` (the old name is
    # deprecated); pick whichever keyword the installed version accepts.
    iter_kw = 'max_iter' if 'max_iter' in inspect.signature(TSNE).parameters else 'n_iter'

    time_start = time.time()
    tsne = TSNE(n_components=2, verbose=1, perplexity=30, metric=metric, **{iter_kw: 300})
    tsne_results = tsne.fit_transform(feat)
    print(f"t-SNE done! Time elapsed: {time.time() - time_start} seconds")

    plt.ion()
    # Resize and reuse the current figure. The previous clf() + figure() pair
    # opened (and leaked) a brand-new figure on every call.
    plt.gcf().set_size_inches(14, 10)
    plt.clf()
    legends = []
    for idx, label in enumerate(unique_labels):
        mask = y == label
        # Pairwise distances between this label's original embeddings.
        dists = pdist(feat[mask, :], metric=metric)
        if dists.size:
            legends.append(f"{label}: μ={np.mean(dists):.2f} σ={np.std(dists):.2f}")
        else:
            # Fewer than two rows: same 'nan' text as before, without numpy's
            # empty-slice RuntimeWarning.
            legends.append(f"{label}: μ=nan σ=nan")
        curfeat_tsne = tsne_results[mask, :]
        # Cycle colours instead of silently dropping labels past len(COLORS).
        plt.plot(curfeat_tsne[:, 0], curfeat_tsne[:, 1], '.', c=COLORS[idx % len(COLORS)])
    plt.legend(legends, loc='best', fontsize='medium')
    plt.axhline(y=0, color='grey', ls=':')
    plt.axvline(x=0, color='grey', ls=':')
    plt.title(title)
    plt.savefig(join(dir_path, f"{filename}.jpg"))
    plt.draw()
    plt.pause(0.001)
def plot_pred_hists(dists, y_true, title, dir_path, filename, bins=None):
    """Overlay histograms of predicted distances for same vs. different pairs.

    Args:
        dists: per-pair predicted distances.
        y_true: per-pair ground truth aligned with ``dists``. ``1`` means
            "same" (green), ``0`` means "different" (red); other values are ignored.
        title: plot title.
        dir_path: directory the image is written to.
        filename: image base name; ``.jpg`` is appended.
        bins: optional ``bins`` argument forwarded to ``plt.hist``. Defaults to
            200 equal-width bins whose edges span [0, 1] inclusive. Values
            outside that range are not counted.
    """
    if bins is None:
        # np.arange(0, 1, step=0.005) ended at the 0.995 edge, silently dropping
        # every distance in (0.995, 1]. linspace includes the closing 1.0 edge,
        # and the last histogram bin is closed on the right.
        bins = np.linspace(0.0, 1.0, 201)
    same = [dist for dist, label in zip(dists, y_true) if label == 1]
    different = [dist for dist, label in zip(dists, y_true) if label == 0]
    plt.ion()
    plt.clf()
    plt.hist(same, bins, alpha=0.5, label='Same', color='green')
    plt.hist(different, bins, alpha=0.5, label='Different', color='red')
    plt.legend(loc='upper right')
    plt.title(title)
    plt.savefig(join(dir_path, f"{filename}.jpg"))
    plt.draw()
    plt.pause(0.001)
def visualize_logs(exp_path: str, log_file_name: str, metric_name: str, color: str, title: str, plot_file_name: str):
    """Plot a per-epoch metric log and save it next to the log.

    The log file ``exp_path/log_file_name`` is read as one float per line; blank
    lines are ignored. Values are plotted against epochs numbered from 1, and
    the figure is saved as ``exp_path/{plot_file_name}.jpg``.

    Raises:
        ValueError: if a non-blank line is not a valid float.
    """
    with open(join(exp_path, log_file_name), 'r') as log_file:
        # Skip blank lines (e.g. a trailing empty line) instead of crashing on float('').
        data = [float(line.strip()) for line in log_file if line.strip()]
    plt.ion()
    plt.clf()
    plt.plot(range(1, len(data) + 1), data, c=color)
    plt.xlabel('Epoch')
    plt.ylabel(metric_name)
    plt.title(title)
    plt.savefig(join(exp_path, f"{plot_file_name}.jpg"))
    plt.draw()
    plt.pause(0.001)