Skip to content

Commit

Permalink
PUBDEV-8547 - eif scoring history - implement scoring history and jav…
Browse files Browse the repository at this point in the history
…a API for _score_tree_interval
  • Loading branch information
valenad1 committed Oct 30, 2023
1 parent 312ae87 commit 1d24726
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@

import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ScoreKeeper;
import hex.tree.isoforextended.isolationtree.CompressedIsolationTree;
import hex.tree.isoforextended.isolationtree.IsolationTree;
import hex.tree.isoforextended.isolationtree.IsolationTreeStats;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;
import water.DKV;
import water.H2O;
import water.Job;
Expand Down Expand Up @@ -154,6 +158,10 @@ public void computeImpl() {

private void buildIsolationTreeEnsemble() {
_model._output._iTreeKeys = new Key[_parms._ntrees];
_model._output._scored_train = new ScoreKeeper[_parms._ntrees + 1];
_model._output._scored_train[0] = new ScoreKeeper();
_model._output._training_time_ms = new long[_parms._ntrees + 1];
_model._output._training_time_ms[0] = System.currentTimeMillis();

int heightLimit = (int) Math.ceil(MathUtils.log2(_parms._sample_size));

Expand All @@ -171,10 +179,20 @@ private void buildIsolationTreeEnsemble() {
DKV.put(compressedIsolationTree);
_job.update(1);
_model.update(_job);
LOG.info((tid + 1) + ". tree was built in " + timer.toString());
_model._output._training_time_ms[tid + 1] = System.currentTimeMillis();
LOG.info((tid + 1) + ". tree was built in " + timer);
isolationTreeStats.updateBy(isolationTree);

boolean manualInterval = _parms._score_tree_interval > 0 && (tid +1) % _parms._score_tree_interval == 0;

_model._output._scored_train[tid + 1] = new ScoreKeeper();
if (_parms._score_each_iteration || manualInterval) {
ModelMetrics.MetricBuilder mb = new ScoreExtendedIsolationForestTask(_model).doAll(_train).getMetricsBuilder();
_model._output._scored_train[tid + 1].fillFrom(mb.makeModelMetrics(_model, _parms.train(), null, null));
}
}
_model._output._training_metrics = new ScoreExtendedIsolationForestTask(_model).doAll(_train).getMetricsBuilder().makeModelMetrics(_model, _parms.train(), null, null);
_model._output._scoring_history = createScoringHistoryTable();
}
}

Expand Down Expand Up @@ -238,4 +256,50 @@ public TwoDimTable createModelSummaryTable() {
return table;
}

protected TwoDimTable createScoringHistoryTable() {
List<String> colHeaders = new ArrayList<>();
List<String> colTypes = new ArrayList<>();
List<String> colFormat = new ArrayList<>();
colHeaders.add("Timestamp"); colTypes.add("string"); colFormat.add("%s");
colHeaders.add("Duration"); colTypes.add("string"); colFormat.add("%s");
colHeaders.add("Number of Trees"); colTypes.add("long"); colFormat.add("%d");
colHeaders.add("Mean Tree Path Length"); colTypes.add("double"); colFormat.add("%.5f");
colHeaders.add("Mean Anomaly Score"); colTypes.add("double"); colFormat.add("%.5f");
if (_parms._custom_metric_func != null) {
colHeaders.add("Training Custom"); colTypes.add("double"); colFormat.add("%.5f");
}

ScoreKeeper[] sks = _model._output._scored_train;

int rows = 0;
for (int i = 0; i < sks.length; i++) {
if (i != 0 && Double.isNaN(sks[i]._anomaly_score)) continue;
rows++;
}
TwoDimTable table = new TwoDimTable(
"Scoring History", null,
new String[rows],
colHeaders.toArray(new String[0]),
colTypes.toArray(new String[0]),
colFormat.toArray(new String[0]),
"");
int row = 0;
for( int i = 0; i<sks.length; i++ ) {
if (i != 0 && Double.isNaN(sks[i]._anomaly_score)) continue;
int col = 0;
DateTimeFormatter fmt = DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss");
table.set(row, col++, fmt.print(_model._output._training_time_ms[i]));
table.set(row, col++, PrettyPrint.msecs(_model._output._training_time_ms[i] - _job.start_time(), true));
table.set(row, col++, i);
ScoreKeeper st = sks[i];
table.set(row, col++, st._anomaly_score);
table.set(row, col++, st._anomaly_score_normalized);
if (_parms._custom_metric_func != null) {
table.set(row, col++, st._custom_metric);
}
assert col == colHeaders.size();
row++;
}
return table;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
import hex.Model;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ScoreKeeper;
import hex.tree.isofor.ModelMetricsAnomaly;
import hex.tree.isoforextended.isolationtree.CompressedIsolationTree;
import org.apache.log4j.Logger;
import water.*;
import water.fvec.Frame;

import static hex.genmodel.algos.isoforextended.ExtendedIsolationForestMojoModel.anomalyScore;
import static hex.genmodel.algos.isoforextended.ExtendedIsolationForestMojoModel.averagePathLengthOfUnsuccessfulSearch;

/**
*
Expand Down Expand Up @@ -101,6 +101,11 @@ public long progressUnits() {
*/
public int _sample_size;

/**
* Score every so many trees (no matter what)
*/
public int _score_tree_interval = 0;

public ExtendedIsolationForestParameters() {
super();
_ntrees = 100;
Expand All @@ -113,6 +118,8 @@ public static class ExtendedIsolationForestOutput extends Model.Output {

public int _ntrees;
public long _sample_size;
public ScoreKeeper[] _scored_train;
public long[] _training_time_ms;

public Key<CompressedIsolationTree>[] _iTreeKeys;

Expand Down

0 comments on commit 1d24726

Please sign in to comment.