From 79e5cb52575808ce6cb3b635bfea9ab514206b6b Mon Sep 17 00:00:00 2001
From: Yun-Jhong Wu
Date: Mon, 29 Jul 2024 20:24:29 -0500
Subject: [PATCH] Fix log_sum_exp to handle large positive/negative inputs

---
 candle-core/src/tensor.rs         | 16 +++++++++++++---
 candle-core/tests/tensor_tests.rs | 24 +++++++++++++++++++++---
 2 files changed, 34 insertions(+), 6 deletions(-)

diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 82532f204f..e8b026057e 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -2440,9 +2440,19 @@ impl Tensor {
 
     /// Returns log(sum(exp(tensor), dim)).
     pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
-        let exp = self.exp()?;
-        let sum = exp.sum(sum_dims)?;
-        sum.log()
+        let sum_dims = sum_dims.to_indexes(self.shape(), "log-sum-exp")?;
+        if sum_dims.is_empty() {
+            return Ok(self.clone());
+        }
+        let max = sum_dims[1..]
+            .iter()
+            .try_fold(self.max_keepdim(sum_dims[0])?, |max, &dim| {
+                max.max_keepdim(dim)
+            })?;
+        let exp = self.broadcast_sub(&max)?.exp()?;
+        let sum = exp.sum(sum_dims.clone())?;
+
+        sum.log()? + max.squeeze_dims(&sum_dims)
     }
 
     /// Pointwise pow operation.
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index cd5f4ca148..567b49f1db 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1326,11 +1326,29 @@ fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
 
 #[test]
 fn log_sum_exp() -> Result<()> {
-    let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
+    let input = Tensor::new(
+        &[
+            [[1f64, 2., 3.], [4., 5., 6.]],
+            [[-1000.0, -999.0, -1001.0], [1000.0, 999.0, 1001.0]],
+        ],
+        &Device::Cpu,
+    )?;
+
     let output = input.log_sum_exp(D::Minus1)?;
     // The expectations obtained from pytorch.
-    let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
-    assert_close(&output, &expected, 0.00001)?;
+    let expected = Tensor::new(&[[3.4076, 6.4076], [-998.5924, 1001.4076]], &Device::Cpu)?;
+    assert_eq!(output.dims(), expected.dims());
+    assert_close(&output.flatten_all()?, &expected.flatten_all()?, 0.00001)?;
+
+    assert_eq!(
+        input.log_sum_exp((0, 1))?.to_vec1::<f64>()?,
+        [1000.0, 999.0, 1001.0]
+    );
+    assert_eq!(
+        input.log_sum_exp(())?.to_vec3::<f64>()?,
+        input.to_vec3::<f64>()?
+    );
+
     Ok(())
 }
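
Editor's note, not part of the patch: the fix applies the standard log-sum-exp identity log(sum_i exp(x_i)) = m + log(sum_i exp(x_i - m)) with m = max_i x_i, so every exponent is at most zero and neither overflow (inputs near +1000, where exp(1000.0) is inf in f64) nor a sum underflowing to zero (inputs near -1000, where log(0) is -inf) can occur. Below is a minimal standalone Rust sketch of the same trick on a plain f64 slice; the function names are hypothetical and chosen here only for illustration, they are not part of the candle API.

    /// Naive log-sum-exp (hypothetical illustration): overflows to +inf for
    /// large positive inputs and underflows to -inf for large negative ones,
    /// which is the failure mode this patch fixes.
    fn log_sum_exp_naive(xs: &[f64]) -> f64 {
        xs.iter().map(|x| x.exp()).sum::<f64>().ln()
    }

    /// Stable log-sum-exp (hypothetical illustration): subtract the maximum
    /// before exponentiating so every exponent is <= 0, then add it back.
    fn log_sum_exp_stable(xs: &[f64]) -> f64 {
        let m = xs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        if m.is_infinite() {
            // Empty input or all -inf gives -inf; a +inf input dominates.
            return m;
        }
        m + xs.iter().map(|x| (x - m).exp()).sum::<f64>().ln()
    }

    fn main() {
        // The same values the new test uses: exp(1000.0) overflows f64.
        let xs = [1000.0, 999.0, 1001.0];
        println!("naive:  {}", log_sum_exp_naive(&xs)); // inf
        println!("stable: {}", log_sum_exp_stable(&xs)); // ~1001.4076
    }

For those test values the naive version returns inf while the stable one returns roughly 1001.4076, matching the expected tensor in the new test; the tensor-level implementation in the patch does the same thing with max_keepdim, broadcast_sub, and a final squeeze_dims over the reduced axes.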