Skip to content

Commit

Permalink
Fix log_sum_exp to handle large positive/negative inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
yunjhongwu committed Jul 30, 2024
1 parent 24d54d0 commit 79e5cb5
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 6 deletions.
16 changes: 13 additions & 3 deletions candle-core/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2440,9 +2440,19 @@ impl Tensor {

/// Returns log(sum(exp(tensor), dim)).
pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
let exp = self.exp()?;
let sum = exp.sum(sum_dims)?;
sum.log()
let sum_dims = sum_dims.to_indexes(self.shape(), "log-sum-exp")?;
if sum_dims.is_empty() {
return Ok(self.clone());
}
let max = sum_dims[1..]
.iter()
.try_fold(self.max_keepdim(sum_dims[0])?, |max, &dim| {
max.max_keepdim(dim)
})?;
let exp = self.broadcast_sub(&max)?.exp()?;
let sum = exp.sum(sum_dims.clone())?;

sum.log()? + max.squeeze_dims(&sum_dims)
}

/// Pointwise pow operation.
Expand Down
24 changes: 21 additions & 3 deletions candle-core/tests/tensor_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1326,11 +1326,29 @@ fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {

#[test]
fn log_sum_exp() -> Result<()> {
let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
let input = Tensor::new(
&[
[[1f64, 2., 3.], [4., 5., 6.]],
[[-1000.0, -999.0, -1001.0], [1000.0, 999.0, 1001.0]],
],
&Device::Cpu,
)?;

let output = input.log_sum_exp(D::Minus1)?;
// The expectations obtained from pytorch.
let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
assert_close(&output, &expected, 0.00001)?;
let expected = Tensor::new(&[[3.4076, 6.4076], [-998.5924, 1001.4076]], &Device::Cpu)?;
assert_eq!(output.dims(), expected.dims());
assert_close(&output.flatten_all()?, &expected.flatten_all()?, 0.00001)?;

assert_eq!(
input.log_sum_exp((0, 1))?.to_vec1::<f64>()?,
[1000.0, 999.0, 1001.0]
);
assert_eq!(
input.log_sum_exp(())?.to_vec3::<f64>()?,
input.to_vec3::<f64>()?
);

Ok(())
}

Expand Down

0 comments on commit 79e5cb5

Please sign in to comment.