Skip to content

Commit

Permalink
fix: fix filtering of NaNs in Prophet preprocessing (#219)
Browse files Browse the repository at this point in the history
  • Loading branch information
sd2k authored Dec 23, 2024
1 parent 8ed69b7 commit 633ec1a
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 8 deletions.
62 changes: 62 additions & 0 deletions crates/augurs-prophet/src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,52 @@ impl TrainingData {
Ok(self)
}

/// Drop every row whose `y` value is NaN.
///
/// The same rows are removed from `ds`, from `cap` and `floor` when they are
/// present, and from every column of `x` and `seasonality_conditions`.
/// `n` is reduced by the number of rows dropped.
///
/// Only `y` is inspected: NaN values in any other column are left in place.
pub(crate) fn filter_nans(mut self) -> Self {
    /// Keep only the elements of `column` whose positional flag in `mask` is `true`.
    fn apply_mask<T>(column: &mut Vec<T>, mask: &[bool]) {
        let mut flags = mask.iter().copied();
        // `Vec::retain` visits elements in order, so pulling one flag per
        // call lines each element up with its entry in the mask.
        column.retain(|_| flags.next().unwrap());
    }

    // One flag per row; `true` means the row survives.
    let mut mask = vec![true; self.n];
    let mut dropped = 0;
    let mut kept_y = Vec::new();
    for (value, flag) in self.y.into_iter().zip(mask.iter_mut()) {
        if value.is_nan() {
            *flag = false;
            dropped += 1;
        } else {
            kept_y.push(value);
        }
    }
    self.y = kept_y;
    self.n -= dropped;

    // Apply the same row selection to every other column.
    apply_mask(&mut self.ds, &mask);
    if let Some(cap) = &mut self.cap {
        apply_mask(cap, &mask);
    }
    if let Some(floor) = &mut self.floor {
        apply_mask(floor, &mask);
    }
    self.x
        .values_mut()
        .for_each(|column| apply_mask(column, &mask));
    self.seasonality_conditions
        .values_mut()
        .for_each(|column| apply_mask(column, &mask));
    self
}

#[cfg(test)]
pub(crate) fn head(mut self, n: usize) -> Self {
self.n = n;
Expand Down Expand Up @@ -298,3 +344,19 @@ impl PredictionData {
Ok(self)
}
}

#[cfg(test)]
mod test {
    use crate::testdata::daily_univariate_ts;

    /// Injecting a single NaN into `y` must shrink `n`, `y` and `ds` by
    /// exactly one row each.
    #[test]
    fn filter_nans() {
        let mut series = daily_univariate_ts();
        series.y[10] = f64::NAN;
        let expected_len = series.n - 1;

        let filtered = series.filter_nans();

        assert_eq!(filtered.n, expected_len);
        assert_eq!(filtered.y.len(), expected_len);
        assert_eq!(filtered.ds.len(), expected_len);
    }
}
2 changes: 1 addition & 1 deletion crates/augurs-prophet/src/forecaster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ impl Predict for FittedProphetForecaster {
}
}

#[cfg(test)]
#[cfg(all(test, feature = "wasmstan"))]
mod test {

use augurs_core::{Fit, Predict};
Expand Down
12 changes: 12 additions & 0 deletions crates/augurs-prophet/src/prophet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -863,4 +863,16 @@ mod test_fit {
&[0.781831, 0.623490, 0.974928, -0.222521, 0.433884, -0.900969],
);
}

// Regression test for https://github.com/grafana/augurs/issues/209.
//
// Fitting on training data whose `y` column contains a NaN must succeed
// rather than panic.
#[test]
fn fit_with_nans() {
    let test_days = 30;
    let (mut train, _) = train_test_splitn(daily_univariate_ts(), test_days);
    train.y[10] = f64::NAN;
    let opt = MockOptimizer::new();
    let mut prophet = Prophet::new(Default::default(), opt);
    // `train` is not used again, so it is moved into `fit` instead of cloned.
    // Should not panic.
    prophet.fit(train, Default::default()).unwrap();
}
}
9 changes: 2 additions & 7 deletions crates/augurs-prophet/src/prophet/prep.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ pub(super) struct Features {
}

impl<O> Prophet<O> {
pub(super) fn preprocess(&mut self, mut data: TrainingData) -> Result<Preprocessed, Error> {
pub(super) fn preprocess(&mut self, data: TrainingData) -> Result<Preprocessed, Error> {
let n = data.ds.len();
if n != data.y.len() {
return Err(Error::MismatchedLengths {
Expand All @@ -207,12 +207,7 @@ impl<O> Prophet<O> {
if n < 2 {
return Err(Error::NotEnoughData);
}
(data.ds, data.y) = data
.ds
.into_iter()
.zip(data.y)
.filter(|(_, y)| !y.is_nan())
.unzip();
let data = data.filter_nans();

let mut history_dates = data.ds.clone();
history_dates.sort_unstable();
Expand Down

0 comments on commit 633ec1a

Please sign in to comment.