Skip to content

Commit

Permalink
PUBDEV-8547 - eif scoring history - test score_each_iteration and sco…
Browse files Browse the repository at this point in the history
…re_tree_interval in java
  • Loading branch information
valenad1 committed Oct 26, 2023
1 parent 669aca2 commit e2f32ac
Showing 1 changed file with 110 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -382,4 +382,114 @@ public void reduce(MeanScoreTask other) {
totalMeanLength += other.totalMeanLength;
}
}

@Test
public void testScoreEachIteration() {
try {
Scope.enter();
Frame train = Scope.track(parseTestFile("smalldata/anomaly/single_blob.csv"));
ExtendedIsolationForestModel.ExtendedIsolationForestParameters p =
new ExtendedIsolationForestModel.ExtendedIsolationForestParameters();
p._train = train._key;
p._seed = 0xDECAF;
p._ntrees = 5;
p._score_each_iteration = true;
p._extension_level = train.numCols() - 1;

ExtendedIsolationForest eif = new ExtendedIsolationForest(p);
ExtendedIsolationForestModel model = eif.trainModel().get();
assertNotNull(model);
Scope.track_generic(model);

LOG.info(model._output._scoring_history);

assertEquals("Number of rows is not correct", 6, model._output._scoring_history.getRowDim());
for (int treeNum = 0; treeNum <= p._ntrees; treeNum++) {
assertEquals("Tree number is not correct", treeNum, model._output._scoring_history.get(treeNum, 2));
}
} finally {
Scope.exit();
}
}

@Test
public void testScoreTreeIntervalSmoke() {
try {
Scope.enter();
Frame train = Scope.track(parseTestFile("smalldata/anomaly/single_blob.csv"));
ExtendedIsolationForestModel.ExtendedIsolationForestParameters p =
new ExtendedIsolationForestModel.ExtendedIsolationForestParameters();
p._train = train._key;
p._seed = 0xDECAF;
p._ntrees = 5;
p._score_tree_interval = 2;
p._extension_level = train.numCols() - 1;

ExtendedIsolationForest eif = new ExtendedIsolationForest(p);
ExtendedIsolationForestModel model = eif.trainModel().get();
assertNotNull(model);
Scope.track_generic(model);

LOG.info(model._output._scoring_history);

assertEquals("Number of rows is not correct", 4, model._output._scoring_history.getRowDim());
assertEquals("Tree number is not correct", 0, model._output._scoring_history.get(0, 2));
assertEquals("Tree number is not correct", 2, model._output._scoring_history.get(1, 2));
assertEquals("Tree number is not correct", 4, model._output._scoring_history.get(2, 2));
assertEquals("Tree number is not correct", 5, model._output._scoring_history.get(3, 2));
} finally {
Scope.exit();
}
}

@Test
public void testScoreTreeInterval() {
try {
Scope.enter();
Frame train = Scope.track(parseTestFile("smalldata/anomaly/single_blob.csv"));
ExtendedIsolationForestModel.ExtendedIsolationForestParameters p =
new ExtendedIsolationForestModel.ExtendedIsolationForestParameters();
p._train = train._key;
p._seed = 0xDECAF;
p._ntrees = 100;
p._score_tree_interval = 30;
p._extension_level = train.numCols() - 1;

ExtendedIsolationForest eif = new ExtendedIsolationForest(p);
ExtendedIsolationForestModel model = eif.trainModel().get();
assertNotNull(model);
Scope.track_generic(model);

LOG.info(model._output._scoring_history);

p._ntrees = 30;
p._score_tree_interval = 0;
eif = new ExtendedIsolationForest(p);
ExtendedIsolationForestModel model30 = eif.trainModel().get();
assertNotNull(model30);
Scope.track_generic(model30);
ModelMetricsAnomaly modelMetricsAnomaly30 = (ModelMetricsAnomaly) model30._output._training_metrics;
p._ntrees = 60;
eif = new ExtendedIsolationForest(p);
ExtendedIsolationForestModel model60 = eif.trainModel().get();
assertNotNull(model60);
Scope.track_generic(model60);
ModelMetricsAnomaly modelMetricsAnomaly60 = (ModelMetricsAnomaly) model60._output._training_metrics;
p._ntrees = 90;
eif = new ExtendedIsolationForest(p);
ExtendedIsolationForestModel model90 = eif.trainModel().get();
assertNotNull(model90);
Scope.track_generic(model90);
ModelMetricsAnomaly modelMetricsAnomaly90 = (ModelMetricsAnomaly) model90._output._training_metrics;

assertEquals("Partial score is not correct", modelMetricsAnomaly30._mean_score, model._output._scoring_history.get(1, 3));
assertEquals("Partial score is not correct", modelMetricsAnomaly30._mean_normalized_score, model._output._scoring_history.get(1, 4));
assertEquals("Partial score is not correct", modelMetricsAnomaly60._mean_score, model._output._scoring_history.get(2, 3));
assertEquals("Partial score is not correct", modelMetricsAnomaly60._mean_normalized_score, model._output._scoring_history.get(2, 4));
assertEquals("Partial score is not correct", modelMetricsAnomaly90._mean_score, model._output._scoring_history.get(3, 3));
assertEquals("Partial score is not correct", modelMetricsAnomaly90._mean_normalized_score, model._output._scoring_history.get(3, 4));
} finally {
Scope.exit();
}
}
}

0 comments on commit e2f32ac

Please sign in to comment.