-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfigure_eigvals_activities.py
114 lines (94 loc) · 3.14 KB
/
figure_eigvals_activities.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from tqdm import tqdm
from rnn import RNN2L
from visualizations import plot_eigenvalues
# Enable 64-bit dtypes in JAX; this is done at startup, before any arrays are built.
jax.config.update("jax_enable_x64", True)

# Output directory for the generated figure.
results_dir = 'results'
os.makedirs(results_dir, exist_ok=True)

# PRNG seed; every figure column re-seeds from this same value.
seed = 2

# (N0, P) sizes passed to `net.generate_weights` for, respectively, the
# weight-matrix image, the eigenvalue spectrum and the simulated activities.
# Presumably N0 neurons per population and P populations -- confirm in RNN2L.
N0_matrix, P_matrix = 8, 8
N0_eigvals, P_eigvals = 5, 200
N0, P = 100, 200

# Time steps: `steps_init` to equilibrate (trajectory discarded), then
# `steps_sim` recorded steps that are plotted.
steps_init = 200
steps_sim = 20

# Parameter sweep: one figure column per (sigma, sigma_mu, label) triple.
sigma_all = [0.5, 4.0, 4.0, 1.0]
sigma_mu_all = [0.5, 0.5, 6.0, 5.0]
labels = ['quiescent', '$\\mu$', '$\\mu + M$', '$M$']

# Plot styling.
lw = 5
ylim = [-1.05, 1.05]
ylim_diff = [-2.05, 2.05]
N_plot = 5  # number of traces (neurons / populations) drawn per panel
colors = plt.cm.Dark2.colors
plt.rcParams.update({'font.size': 40})
plt.rcParams['axes.prop_cycle'] = plt.cycler(color=colors)

net = RNN2L()
def decompose(X):
    """Split activities ``X`` into a population-level part and a residual.

    The population activities of ``X`` (computed by
    ``net.population_activities`` with the module-level ``N0`` and ``P``) are
    expanded back to neuron resolution by repeating every entry ``N0`` times
    along the last axis (``jnp.kron`` with a length-``N0`` vector of ones).

    Returns:
        A pair ``(Xm, residual)``: ``Xm`` is the expanded population
        component and ``residual`` is ``X - Xm``.

    NOTE(review): the expansion only lines up with ``X`` if neurons are stored
    population-major (``N0`` consecutive neurons per population) -- confirm
    against ``RNN2L``.
    """
    pop_activity = net.population_activities(X, N0, P)
    expanded = jnp.kron(pop_activity, jnp.ones((N0,)))
    return expanded, X - expanded
# Build a 4-row figure with one column per parameter setting:
#   row 0: image of a small weight matrix J
#   row 1: eigenvalue spectrum of a larger J
#   row 2: activities of the first `N_plot` individual neurons
#   row 3: the first `N_plot` population activities
if not (len(sigma_all) == len(sigma_mu_all) == len(labels)):
    # zip() would silently drop columns and leave empty axes otherwise.
    raise ValueError("sigma_all, sigma_mu_all and labels must have the same length")

# squeeze=False keeps `axs` 2-D even when there is a single column.
fig, axs = plt.subplots(4, len(sigma_all), figsize=(len(sigma_all)*5, 17),
                        constrained_layout=True, squeeze=False)
columns = list(zip(sigma_all, sigma_mu_all, labels))
for i, (sigma, sigma_mu, label) in enumerate(tqdm(columns)):
    axs[0, i].set_title(label)
    # Re-seed for every column so all columns start from the same random
    # stream (original intent: "consistent phase transition" -- unverified).
    key = jax.random.PRNGKey(seed)

    ### Matrix visualization
    key, subkey = jax.random.split(key)
    J = net.generate_weights(subkey,
                             N0=N0_matrix,
                             P=P_matrix,
                             sigma=sigma,
                             sigma_mu=sigma_mu)
    axs[0, i].imshow(J, cmap='plasma', vmin=-1.5, vmax=1.5)
    axs[0, i].axis('off')

    ### Eigenvalues
    # NOTE(review): `subkey` is reused from the matrix panel above. JAX
    # recommends never reusing a key; it is kept so the figure stays identical
    # to previously generated results. Split a fresh key if that is not needed.
    J = net.generate_weights(subkey,
                             N0=N0_eigvals,
                             P=P_eigvals,
                             sigma=sigma,
                             sigma_mu=sigma_mu)
    plot_eigenvalues(J, color='black', color_circle='red', title=None, ax=axs[1, i])
    axs[1, i].axis('off')

    ### Activities
    key, subkey = jax.random.split(key)
    J = net.generate_weights(subkey,
                             N0=N0,
                             P=P,
                             sigma=sigma,
                             sigma_mu=sigma_mu)
    N = J.shape[0]
    key, subkey = jax.random.split(key)
    # float64 requires jax_enable_x64; otherwise JAX downcasts to float32.
    x0 = jax.random.normal(subkey, shape=(N,), dtype=jnp.float64)
    # Equilibrate: run `steps_init` steps and keep only the final state.
    x = net.evolve(J, x0, steps_init, save_trajectory=False)
    # Record `steps_sim` steps starting from the equilibrated state.
    _, Xall = net.evolve(J, x, steps_sim, save_trajectory=True)

    # Individual neurons (first N_plot columns of the trajectory).
    axs[2, i].plot(Xall[:, :N_plot], lw=lw, alpha=0.7)
    axs[2, i].set_ylim(ylim)
    # Population activities (first N_plot populations).
    Mall = net.population_activities(Xall, N0, P)
    axs[3, i].plot(Mall[:, :N_plot], lw=lw, alpha=0.7)
    axs[3, i].set_ylim(ylim)

axs[2, 0].set_ylabel('$x^{1}_i$')
axs[3, 0].set_ylabel('$m^{\\alpha}$')
# Label the time axis on every bottom-row panel, however many columns there are.
for ax in axs[3]:
    ax.set_xlabel('$t$')
fig.savefig(os.path.join(results_dir, 'eig_and_activities.pdf'), bbox_inches='tight')