losses.py
import jax
import jax.random as random
import optax
from jax import numpy as jnp
from jax.tree_util import tree_map

from solver import get_twoway_sampler


def get_mix_loss_fn(
    mix, modelf, modelb,
    num_steps, reduce_mean=False,
    eps=1e-3, weight_type='importance'
):
    """Build the joint bridge-matching loss for the forward and backward models.

    The returned `loss_fn` samples a time, simulates the two-way bridge between
    a prior sample and a data sample, and penalizes the squared distance between
    each model's predicted drift and the corresponding bridge drift.
    """
    reduce_op = jnp.mean if reduce_mean else jnp.sum
    sampler = get_twoway_sampler(mix, num_steps)
    # Total cumulative importance weight over the sampled time interval.
    Z = mix.importance_cum_weight(mix.tf - eps, eps)
def weight_fn(t):
if weight_type=='importance':
weight = jnp.ones_like(t) * Z
elif weight_type=='default':
weight = 1./mix.beta_schedule.beta_t(t)
else:
raise NotImplementedError(f'{weight_type} not implemented.')
return weight
    def loss_fn(rng, paramsf, statesf, paramsb, statesb, x):
        """Compute the forward and backward bridge-matching losses for one batch `x`."""
# Forward (prior->data) drift
predf_fn = mix.get_drift_fn(modelf, paramsf, statesf, return_state=True)
# Backward (data->prior) drift
predb_fn = mix.get_drift_fn(modelb, paramsb, statesb, return_state=True)
rng, step_rng = random.split(rng)
if 'importance' in weight_type:
t = mix.sample_importance_weighted_time(
step_rng,
(x.shape[0],),
eps
)
else:
t = random.uniform(
step_rng,
(x.shape[0],),
minval=mix.t0 + eps,
maxval=mix.tf - eps
)
rng, step_rng = random.split(rng)
x0 = mix.prior.sample(step_rng, x.shape)
rng, step_rng = random.split(rng)
xt = sampler(step_rng, x0, x, t)
        # Time-dependent loss weight: constant Z under importance sampling, 1/beta_t otherwise.
        weight = weight_fn(t)
# Forward model loss
predf, new_model_statef = predf_fn(xt, t, step_rng)
lossesf = predf - mix.bridge(x).drift(xt, t)
lossesf = weight * 0.5 * mix.manifold.metric.squared_norm(lossesf, xt)
lossesf = reduce_op(lossesf.reshape(lossesf.shape[0], -1), axis=-1)
# Backward model loss
predb, new_model_stateb = predb_fn(xt, mix.tf-t, step_rng)
lossesb = predb - mix.rev().bridge(x0).drift(xt, mix.tf-t)
lossesb = weight * 0.5 * mix.manifold.metric.squared_norm(lossesb, xt)
lossesb = reduce_op(lossesb.reshape(lossesb.shape[0], -1), axis=-1)
lossf, lossb = jnp.mean(lossesf), jnp.mean(lossesb)
loss = lossf + lossb
return loss, (lossf, lossb, new_model_statef, new_model_stateb)
return loss_fn
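

# --- Usage sketch (illustrative, not part of the original training code) ---
# A minimal example of how the loss returned by `get_mix_loss_fn` might be
# evaluated on a single batch. The `mix`, `modelf`, `modelb` objects, the
# parameter/state pytrees, and the batch `x` are assumed to be provided by the
# caller; none of them are defined in this module.
def evaluate_mix_loss_once(rng, mix, modelf, modelb,
                           paramsf, statesf, paramsb, statesb,
                           x, num_steps=100):
    """Build the two-way bridge loss and evaluate it once on a single batch."""
    loss_fn = get_mix_loss_fn(mix, modelf, modelb, num_steps)
    loss, (lossf, lossb, _, _) = loss_fn(rng, paramsf, statesf, paramsb, statesb, x)
    return loss, lossf, lossb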


def get_ema_loss_step_fn(
    loss_fn,
    optimizerf,
    optimizerb
):
    """Create a one-step training function that also maintains EMA parameters."""
    def step_fn(carry_state, batch):
        """Run one step of training.

        This function is meant to be used with `jax.lax.scan` so that multiple
        steps can be pmapped and jit-compiled together for faster execution.

        Args:
            carry_state: A tuple (JAX random state, NamedTuple containing the training state).
            batch: A mini-batch of training data.

        Returns:
            new_carry_state: The updated `carry_state` tuple.
            (lossf, lossb): The average forward and backward loss values for this batch.
        """
(rng, train_state) = carry_state
rng, step_rng = jax.random.split(rng)
grad_fn = jax.value_and_grad(loss_fn, argnums=(1,3), has_aux=True)
paramsf = train_state.paramsf
model_statef = train_state.model_statef
paramsb = train_state.paramsb
model_stateb = train_state.model_stateb
(loss, (lossf, lossb, new_model_statef, new_model_stateb)), grad = grad_fn(
step_rng, paramsf, model_statef, paramsb, model_stateb, batch
)
updatesf, new_opt_statef = optimizerf.update(
grad[0],
train_state.opt_statef,
paramsf
)
updatesb, new_opt_stateb = optimizerb.update(
grad[1],
train_state.opt_stateb,
paramsb
)
        new_paramsf = optax.apply_updates(paramsf, updatesf)
        new_paramsb = optax.apply_updates(paramsb, updatesb)
        # Exponential moving average of the parameters of both models.
        new_params_emaf = tree_map(
            lambda p_ema, p: p_ema * train_state.ema_rate
            + p * (1.0 - train_state.ema_rate),
            train_state.params_emaf,
            new_paramsf,
        )
        new_params_emab = tree_map(
            lambda p_ema, p: p_ema * train_state.ema_rate
            + p * (1.0 - train_state.ema_rate),
            train_state.params_emab,
            new_paramsb,
        )
        step = train_state.step + 1
        new_train_state = train_state._replace(
            step=step,
            opt_statef=new_opt_statef,
            model_statef=new_model_statef,
            paramsf=new_paramsf,
            params_emaf=new_params_emaf,
            opt_stateb=new_opt_stateb,
            model_stateb=new_model_stateb,
            paramsb=new_paramsb,
            params_emab=new_params_emab,
        )
        new_carry_state = (rng, new_train_state)
        # Return (carry, output) so the function composes with `jax.lax.scan`.
        return new_carry_state, (lossf, lossb)
return step_fn
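

# --- Usage sketch (illustrative, not part of the original training loop) ---
# With `step_fn` returning `(new_carry_state, (lossf, lossb))`, it has the
# `(carry, x) -> (carry, y)` signature expected by `jax.lax.scan`, so several
# training steps can be folded into one compiled call. The batch stack below
# (leading axis = number of steps) and the `train_state` NamedTuple are
# assumptions about the surrounding training code, not defined in this module.
def run_steps(rng, train_state, batches, step_fn):
    """Run `step_fn` over the leading axis of `batches` with `jax.lax.scan`."""
    (rng, train_state), (lossesf, lossesb) = jax.lax.scan(
        step_fn, (rng, train_state), batches
    )
    return rng, train_state, lossesf.mean(), lossesb.mean()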