forked from kaiu85/CRNs
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path create_graphs_for_figure_1b.py
76 lines (46 loc) · 1.87 KB
/
create_graphs_for_figure_1b.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import pickle
# FOR CUDA DEBUGGING, c.f. https://lernapparat.de/debug-device-assert/
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "0"# Set to "1" for more verbose debugging messages
import torch
from network_constructors import create_wta_model
from utils import save_trajectories
import pickle
import numpy as np
import matplotlib.pyplot as plt
# Select GPU: make CUDA device 0 the current device, so the argument-less
# .cuda() calls further down place their tensors on it.
torch.cuda.set_device(torch.device('cuda', 0))
# --- Simulation settings -------------------------------------------------
print_debug: bool = False  # forwarded to CRN.init_global_reaction_variables and CRN.run
steps: int = 300_000       # number of steps passed to CRN.run
log_every: int = 100       # logging interval passed to CRN.run
cont: bool = False         # True: resume from previously saved As/ts trajectories
forcings: list = [100, 500]  # mean_forcing values; one network is built per value
n: int = 1000              # NOTE(review): not referenced anywhere else in this file

# --- Plotting settings ---------------------------------------------------
max_plots: int = 10        # at most this many trajectories are plotted per forcing
cm: float = 1 / 2.54       # one centimetre expressed in inches (matplotlib figsize unit)
### Simulation
# For each mean forcing value: build a WTA network, randomise its initial
# state, optionally resume from a previous run, pickle the network, simulate
# it and save the logged trajectories under ./results/.
os.makedirs('./results', exist_ok=True)
for forcing in forcings:
    suffix = 'wta_demo_network_' + str(forcing)
    CRN = create_wta_model(mean_forcing = forcing)
    # Overwrite every entry of CRN.A with a random integer in [0, 500).
    CRN.A[:,:] = torch.randint(0,500,CRN.A.shape)
    CRN.init_global_reaction_variables(print_debug = print_debug)
    # NOTE(review): CRN.M presumably counts forward and reverse reactions
    # separately, hence the literal "/2" in this message -- confirm.
    print('Simulating network consisting of %d species and %d/2 reactions!' % (CRN.N, CRN.M))
    if cont:
        # Resume from the last logged time point (last index of the final
        # axis) of the previous run saved under this suffix.
        As = np.load('./results/As' + suffix + '.npy')
        CRN.A = torch.tensor(As[:,:,-1].squeeze()).cuda()
        ts = np.load('./results/ts' + suffix + '.npy')
        CRN.t = torch.tensor(ts[:,-1].squeeze()).cuda()
        suffix = suffix + '_cont'
    # Pickle only once `suffix` is final. A continuation run then writes its
    # own network..._cont.obj, holding the resumed state, instead of
    # overwriting the network file of the run it resumes.
    with open( './results/network' + suffix + '.obj', 'wb') as network_file:
        pickle.dump(CRN, network_file)
    results = CRN.run(steps, log_every = log_every, print_debug = print_debug)
    # Presumably writes ./results/ts<suffix>.npy and ./results/As<suffix>.npy,
    # the files read back by the resume branch above -- confirm in utils.
    save_trajectories(results, './results/', suffix)
### Plotting
# For each forcing, load the saved trajectories and write one SVG per
# trajectory (at most `max_plots` of them) to ./figures/.
os.makedirs('./figures', exist_ok=True)
for forcing in forcings:
    suffix = 'wta_demo_network_' + str(forcing)
    # The simulation loop saves continued runs with an extra '_cont' suffix.
    # Read that same run here, and tag its figures so they do not overwrite
    # the figures of the original run.
    run_tag = '_cont' if cont else ''
    ts = np.load('./results/ts' + suffix + run_tag + '.npy')
    As = np.load('./results/As' + suffix + run_tag + '.npy')
    # ts is indexed ts[j, :] and As is indexed As[:, j, :], so j selects one
    # trajectory in both arrays.
    for j in range(min(ts.shape[0], max_plots)):
        fig, ax = plt.subplots(figsize = (30*cm, 20*cm))
        ax.plot(ts[j,:], As[:,j,:].squeeze().transpose())
        fig.savefig("./figures/Fig_1b_WTA_forcing_%d_trajectory_id_%d%s.svg" % (forcing, j, run_tag))
        # Release the figure; otherwise every figure stays open until exit.
        plt.close(fig)