diff --git a/h2o-algos/src/test/java/hex/tree/isoforextended/ExtendedIsolationForestTest.java b/h2o-algos/src/test/java/hex/tree/isoforextended/ExtendedIsolationForestTest.java index c737b669afe2..49798423903a 100644 --- a/h2o-algos/src/test/java/hex/tree/isoforextended/ExtendedIsolationForestTest.java +++ b/h2o-algos/src/test/java/hex/tree/isoforextended/ExtendedIsolationForestTest.java @@ -382,4 +382,114 @@ public void reduce(MeanScoreTask other) { totalMeanLength += other.totalMeanLength; } } + + @Test + public void testScoreEachIteration() { + try { + Scope.enter(); + Frame train = Scope.track(parseTestFile("smalldata/anomaly/single_blob.csv")); + ExtendedIsolationForestModel.ExtendedIsolationForestParameters p = + new ExtendedIsolationForestModel.ExtendedIsolationForestParameters(); + p._train = train._key; + p._seed = 0xDECAF; + p._ntrees = 5; + p._score_each_iteration = true; + p._extension_level = train.numCols() - 1; + + ExtendedIsolationForest eif = new ExtendedIsolationForest(p); + ExtendedIsolationForestModel model = eif.trainModel().get(); + assertNotNull(model); + Scope.track_generic(model); + + LOG.info(model._output._scoring_history); + + assertEquals("Number of rows is not correct", 6, model._output._scoring_history.getRowDim()); + for (int treeNum = 0; treeNum <= p._ntrees; treeNum++) { + assertEquals("Tree number is not correct", treeNum, model._output._scoring_history.get(treeNum, 2)); + } + } finally { + Scope.exit(); + } + } + + @Test + public void testScoreTreeIntervalSmoke() { + try { + Scope.enter(); + Frame train = Scope.track(parseTestFile("smalldata/anomaly/single_blob.csv")); + ExtendedIsolationForestModel.ExtendedIsolationForestParameters p = + new ExtendedIsolationForestModel.ExtendedIsolationForestParameters(); + p._train = train._key; + p._seed = 0xDECAF; + p._ntrees = 5; + p._score_tree_interval = 2; + p._extension_level = train.numCols() - 1; + + ExtendedIsolationForest eif = new ExtendedIsolationForest(p); + ExtendedIsolationForestModel model = eif.trainModel().get(); + assertNotNull(model); + Scope.track_generic(model); + + LOG.info(model._output._scoring_history); + + assertEquals("Number of rows is not correct", 4, model._output._scoring_history.getRowDim()); + assertEquals("Tree number is not correct", 0, model._output._scoring_history.get(0, 2)); + assertEquals("Tree number is not correct", 2, model._output._scoring_history.get(1, 2)); + assertEquals("Tree number is not correct", 4, model._output._scoring_history.get(2, 2)); + assertEquals("Tree number is not correct", 5, model._output._scoring_history.get(3, 2)); + } finally { + Scope.exit(); + } + } + + @Test + public void testScoreTreeInterval() { + try { + Scope.enter(); + Frame train = Scope.track(parseTestFile("smalldata/anomaly/single_blob.csv")); + ExtendedIsolationForestModel.ExtendedIsolationForestParameters p = + new ExtendedIsolationForestModel.ExtendedIsolationForestParameters(); + p._train = train._key; + p._seed = 0xDECAF; + p._ntrees = 100; + p._score_tree_interval = 30; + p._extension_level = train.numCols() - 1; + + ExtendedIsolationForest eif = new ExtendedIsolationForest(p); + ExtendedIsolationForestModel model = eif.trainModel().get(); + assertNotNull(model); + Scope.track_generic(model); + + LOG.info(model._output._scoring_history); + + p._ntrees = 30; + p._score_tree_interval = 0; + eif = new ExtendedIsolationForest(p); + ExtendedIsolationForestModel model30 = eif.trainModel().get(); + assertNotNull(model30); + Scope.track_generic(model30); + ModelMetricsAnomaly modelMetricsAnomaly30 = (ModelMetricsAnomaly) model30._output._training_metrics; + p._ntrees = 60; + eif = new ExtendedIsolationForest(p); + ExtendedIsolationForestModel model60 = eif.trainModel().get(); + assertNotNull(model60); + Scope.track_generic(model60); + ModelMetricsAnomaly modelMetricsAnomaly60 = (ModelMetricsAnomaly) model60._output._training_metrics; + p._ntrees = 90; + eif = new ExtendedIsolationForest(p); + ExtendedIsolationForestModel model90 = eif.trainModel().get(); + assertNotNull(model90); + Scope.track_generic(model90); + ModelMetricsAnomaly modelMetricsAnomaly90 = (ModelMetricsAnomaly) model90._output._training_metrics; + + assertEquals("Partial score is not correct", modelMetricsAnomaly30._mean_score, model._output._scoring_history.get(1, 3)); + assertEquals("Partial score is not correct", modelMetricsAnomaly30._mean_normalized_score, model._output._scoring_history.get(1, 4)); + assertEquals("Partial score is not correct", modelMetricsAnomaly60._mean_score, model._output._scoring_history.get(2, 3)); + assertEquals("Partial score is not correct", modelMetricsAnomaly60._mean_normalized_score, model._output._scoring_history.get(2, 4)); + assertEquals("Partial score is not correct", modelMetricsAnomaly90._mean_score, model._output._scoring_history.get(3, 3)); + assertEquals("Partial score is not correct", modelMetricsAnomaly90._mean_normalized_score, model._output._scoring_history.get(3, 4)); + } finally { + Scope.exit(); + } + } }