From 633ec1a33659f37b328135a6c4065d0ebb823fab Mon Sep 17 00:00:00 2001 From: Ben Sully Date: Mon, 23 Dec 2024 10:23:46 +0000 Subject: [PATCH] fix: fix filtering of NaNs in Prophet preprocessing (#219) --- crates/augurs-prophet/src/data.rs | 62 +++++++++++++++++++++++ crates/augurs-prophet/src/forecaster.rs | 2 +- crates/augurs-prophet/src/prophet.rs | 12 +++++ crates/augurs-prophet/src/prophet/prep.rs | 9 +--- 4 files changed, 77 insertions(+), 8 deletions(-) diff --git a/crates/augurs-prophet/src/data.rs b/crates/augurs-prophet/src/data.rs index b789275e..34933d82 100644 --- a/crates/augurs-prophet/src/data.rs +++ b/crates/augurs-prophet/src/data.rs @@ -132,6 +132,52 @@ impl TrainingData { Ok(self) } + /// Remove any NaN values from the `y` column, and the corresponding values + /// in the other columns. + /// + /// This handles updating all columns and `n` appropriately. + /// + /// NaN values in other columns are retained. + pub(crate) fn filter_nans(mut self) -> Self { + let mut n = self.n; + let mut keep = vec![true; self.n]; + self.y = self + .y + .into_iter() + .zip(keep.iter_mut()) + .filter_map(|(y, keep)| { + if y.is_nan() { + *keep = false; + n -= 1; + None + } else { + Some(y) + } + }) + .collect(); + + fn retain(v: &mut Vec, keep: &[bool]) { + let mut iter = keep.iter(); + v.retain(|_| *iter.next().unwrap()); + } + + self.n = n; + retain(&mut self.ds, &keep); + if let Some(cap) = self.cap.as_mut() { + retain(cap, &keep); + } + if let Some(floor) = self.floor.as_mut() { + retain(floor, &keep); + } + for v in self.x.values_mut() { + retain(v, &keep); + } + for v in self.seasonality_conditions.values_mut() { + retain(v, &keep); + } + self + } + #[cfg(test)] pub(crate) fn head(mut self, n: usize) -> Self { self.n = n; @@ -298,3 +344,19 @@ impl PredictionData { Ok(self) } } + +#[cfg(test)] +mod test { + use crate::testdata::daily_univariate_ts; + + #[test] + fn filter_nans() { + let mut data = daily_univariate_ts(); + let expected_len = data.n - 1; + data.y[10] = f64::NAN; + let data = data.filter_nans(); + assert_eq!(data.n, expected_len); + assert_eq!(data.y.len(), expected_len); + assert_eq!(data.ds.len(), expected_len); + } +} diff --git a/crates/augurs-prophet/src/forecaster.rs b/crates/augurs-prophet/src/forecaster.rs index bd176bae..1f940045 100644 --- a/crates/augurs-prophet/src/forecaster.rs +++ b/crates/augurs-prophet/src/forecaster.rs @@ -151,7 +151,7 @@ impl Predict for FittedProphetForecaster { } } -#[cfg(test)] +#[cfg(all(test, feature = "wasmstan"))] mod test { use augurs_core::{Fit, Predict}; diff --git a/crates/augurs-prophet/src/prophet.rs b/crates/augurs-prophet/src/prophet.rs index 62f928ed..aeea06e3 100644 --- a/crates/augurs-prophet/src/prophet.rs +++ b/crates/augurs-prophet/src/prophet.rs @@ -863,4 +863,16 @@ mod test_fit { &[0.781831, 0.623490, 0.974928, -0.222521, 0.433884, -0.900969], ); } + + // Regression test for https://github.com/grafana/augurs/issues/209. + #[test] + fn fit_with_nans() { + let test_days = 30; + let (mut train, _) = train_test_splitn(daily_univariate_ts(), test_days); + train.y[10] = f64::NAN; + let opt = MockOptimizer::new(); + let mut prophet = Prophet::new(Default::default(), opt); + // Should not panic. + prophet.fit(train.clone(), Default::default()).unwrap(); + } } diff --git a/crates/augurs-prophet/src/prophet/prep.rs b/crates/augurs-prophet/src/prophet/prep.rs index 8c1565fa..5b0da263 100644 --- a/crates/augurs-prophet/src/prophet/prep.rs +++ b/crates/augurs-prophet/src/prophet/prep.rs @@ -194,7 +194,7 @@ pub(super) struct Features { } impl Prophet { - pub(super) fn preprocess(&mut self, mut data: TrainingData) -> Result { + pub(super) fn preprocess(&mut self, data: TrainingData) -> Result { let n = data.ds.len(); if n != data.y.len() { return Err(Error::MismatchedLengths { @@ -207,12 +207,7 @@ impl Prophet { if n < 2 { return Err(Error::NotEnoughData); } - (data.ds, data.y) = data - .ds - .into_iter() - .zip(data.y) - .filter(|(_, y)| !y.is_nan()) - .unzip(); + let data = data.filter_nans(); let mut history_dates = data.ds.clone(); history_dates.sort_unstable();