It would be great if entmax worked with torch.float16 and torch.bfloat16. Unfortunately, it currently does not: there are bugs for both bisection and the exact algorithm. Here I'll document a numerical stability problem that exists in the bisection-based algorithm for both torch.float16 and torch.bfloat16 (don't believe the propaganda that says bf16 is a drop-in replacement for float32).
Let's say you have a 32-bit vector of logits whose largest element is sufficiently negative.
import torch
import entmax

a = torch.zeros(128, device="cuda").fill_(-5)  # torch.float32
a[0] = 0
a -= 1000  # largest element is now -1000, the rest are -1005
With alpha=1.5, the correct output for this vector is a one-hot distribution peaked on index 0. We get this behavior with both entmax.entmax15 and entmax.entmax_bisect. Ok, great. But what happens if we use torch.float16? And what about torch.bfloat16?
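The original issue showed the broken half-precision outputs at this point; they aren't preserved here, so the following is just a minimal sketch of the calls that reproduce the problem, reusing the same tensor a from above:

b = a.to(torch.float16)
entmax.entmax_bisect(b, alpha=1.5)  # not the expected one-hot distribution

c = a.to(torch.bfloat16)
entmax.entmax_bisect(c, alpha=1.5)  # also not one-hot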
Well, that's not good! (Solution after this commercial break.) As it turns out, we can use the same solution for both torch.float16 and torch.bfloat16. This doesn't seem obvious, because they offer different tradeoffs compared to full-precision floats: fp16 keeps the mantissa at the cost of reduced range, while bf16 keeps the range at the cost of a reduced mantissa. But strangely enough, the same fix works for both: the classic softmax stability trick of subtracting the largest logit from the vector. This makes intuitive sense for torch.float16 because of its reduced range, but what about for torch.bfloat16?
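As a quick aside of my own (not in the original write-up), torch.finfo makes that tradeoff concrete:

torch.finfo(torch.float16).max   # 65504.0 -- limited range
torch.finfo(torch.bfloat16).max  # ~3.39e38 -- roughly the same range as float32
torch.finfo(torch.float16).eps   # ~0.00098 -- 10 stored mantissa bits
torch.finfo(torch.bfloat16).eps  # ~0.0078  -- only 7 stored mantissa bits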
It turns out that bfloat16 has weird precision problems far away from zero:
x = torch.tensor(0, dtype=torch.bfloat16, device="cuda")
x == (x - 1) # False, of course
y = torch.tensor(-500, dtype=torch.bfloat16, device="cuda")
y == (y - 1) # True?!?
So that's why the softmax stability trick works: subtracting the max shifts every logit close to zero, where both formats have plenty of precision (near ±500, adjacent bf16 values are 2 apart, so subtracting 1 literally does nothing). Bringing it back to our earlier example:
b = a.to(torch.float16)
entmax.entmax_bisect(b - b.max(), alpha=1.5) # one-hot!
c = a.to(torch.bfloat16)
entmax.entmax_bisect(c - c.max(), alpha=1.5) # one-hot!
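Since entmax, like softmax, is invariant to adding a constant to every logit, the subtraction doesn't change the exact answer; it just moves the computation into a numerically friendly range. Here's a minimal sketch of how one might wrap this. The helper name stable_entmax_bisect is mine, and I'm assuming entmax_bisect's dim keyword behaves as in the released package:

def stable_entmax_bisect(logits, alpha=1.5, dim=-1):
    # Shift so the largest logit along `dim` is exactly 0 before calling entmax_bisect.
    # entmax is shift-invariant, so the result is mathematically unchanged,
    # but fp16/bf16 inputs now live in a well-behaved range near zero.
    shifted = logits - logits.max(dim=dim, keepdim=True).values
    return entmax.entmax_bisect(shifted, alpha=alpha, dim=dim)

stable_entmax_bisect(a.to(torch.bfloat16), alpha=1.5)  # one-hot, matching float32

A proper fix inside the library would presumably apply this shift internally, but the wrapper is enough to get the right answer for the example above.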