I am a little confused about how the parameters experts.w1 and experts.w2 are updated. The top1 operation is non-differentiable, and therefore the gradients of these two parameters would be None. To confirm, I even ran the following:
import torch
from torch import nn
from mixture_of_experts import MoE

moe = MoE(
    dim = 512,
    num_experts = 16,               # increase the experts (# parameters) of your model without increasing computation
    hidden_dim = 512 * 4,           # size of hidden dimension in each expert, defaults to 4 * dimension
    activation = nn.LeakyReLU,      # use your preferred activation, will default to GELU
    second_policy_train = 'random', # in top_2 gating, policy for whether to use a second-place expert
    second_policy_eval = 'random',  # all (always) | none (never) | threshold (if gate value > the given threshold) | random (if gate value > threshold * random_uniform(0, 1))
    second_threshold_train = 0.2,
    second_threshold_eval = 0.2,
    capacity_factor_train = 1.25,   # experts have fixed capacity per batch. we need some extra capacity in case gating is not perfectly balanced.
    capacity_factor_eval = 2.,      # capacity_factor_* should be set to a value >= 1
    loss_coef = 1e-2                # multiplier on the auxiliary expert balancing loss
).cuda()

inputs = torch.randn(4, 1024, 512).cuda()
out, aux_loss = moe(inputs) # (4, 1024, 512), (1,)
aux_loss.backward()

# print every parameter that did not receive a gradient
for name, param in moe.named_parameters():
    if param.grad is None:
        print(name)
which gave the following output:
experts.w1
experts.w2
It would be really helpful if you could clarify my understanding. Thanks
Try it on CPU? It works for me.
The top1 operation is non-differentiable, but the balance loss is computed from the logits of the gating distribution and the count of tokens routed to each expert, so the gradient of the gating weights should not be None.
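For what it's worth, here is a minimal sketch (not from the thread) of how the two losses reach different parameters, assuming the MoE module from this repo; the sum() task loss below is only a stand-in for a real training objective:

import torch
from torch import nn
from mixture_of_experts import MoE

moe = MoE(dim = 512, num_experts = 16)
inputs = torch.randn(4, 1024, 512)

out, aux_loss = moe(inputs)

# the auxiliary balancing loss is computed from the gating logits and per-expert
# token counts only, so by itself it reaches the gating weights but never
# experts.w1 / experts.w2 -- those sit in the graph of `out`, not of aux_loss
task_loss = out.sum()               # stand-in for the real objective (assumption)
(task_loss + aux_loss).backward()

for name, param in moe.named_parameters():
    if param.grad is None:          # with both losses combined, nothing should print
        print(name)

So seeing experts.w1 and experts.w2 with None gradients after calling only aux_loss.backward() is expected: the expert weights are trained by the task loss that flows back through out, while the auxiliary loss only shapes the gating distribution.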