Fix Elu gradient NaN on large input (#2328)
* Fix Elu gradient NaN on large input

* Reuse previously computed exp in Elu
agerasev authored Jul 16, 2024
1 parent 30cdd76 commit 6a4741b
Showing 1 changed file with 2 additions and 1 deletion: candle-core/src/backprop.rs

@@ -634,7 +634,8 @@ impl Tensor {
     let zeros = arg.zeros_like()?;
     let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
     let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
-    let negative_exp_mask = ((negative_mask * arg.exp())? * *alpha)?;
+    // node == alpha * (e^x - 1) for x <= 0, reuse it
+    let negative_exp_mask = (negative_mask * (*node + *alpha))?;
     let combined_mask = (positive_mask + negative_exp_mask)?;
     *sum_grad = sum_grad.add(&(grad * combined_mask)?)?
 }
