Skip to content

Commit

Permalink
GH-7118 - eif scoring and scoring history (#6037)
Browse files Browse the repository at this point in the history
* eif - scoring of train dataset is working and its ready to test

* eif - test final scoring output in java

* eif scoring history - implement scoring history and java API for _score_tree_interval

* eif scoring history - always add final scoring to the history

* eif scoring history - test score_each_iteration and score_tree_interval in java

* add posibility to disable training metrics

* test that metrics are empty
  • Loading branch information
valenad1 authored Nov 20, 2023
1 parent 008e9a4 commit 8d9304b
Show file tree
Hide file tree
Showing 4 changed files with 297 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@

import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ScoreKeeper;
import hex.tree.isoforextended.isolationtree.CompressedIsolationTree;
import hex.tree.isoforextended.isolationtree.IsolationTree;
import hex.tree.isoforextended.isolationtree.IsolationTreeStats;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;
import water.DKV;
import water.H2O;
import water.Job;
Expand Down Expand Up @@ -154,6 +158,10 @@ public void computeImpl() {

private void buildIsolationTreeEnsemble() {
_model._output._iTreeKeys = new Key[_parms._ntrees];
_model._output._scored_train = new ScoreKeeper[_parms._ntrees + 1];
_model._output._scored_train[0] = new ScoreKeeper();
_model._output._training_time_ms = new long[_parms._ntrees + 1];
_model._output._training_time_ms[0] = System.currentTimeMillis();

int heightLimit = (int) Math.ceil(MathUtils.log2(_parms._sample_size));

Expand All @@ -171,9 +179,22 @@ private void buildIsolationTreeEnsemble() {
DKV.put(compressedIsolationTree);
_job.update(1);
_model.update(_job);
LOG.info((tid + 1) + ". tree was built in " + timer.toString());
_model._output._training_time_ms[tid + 1] = System.currentTimeMillis();
LOG.info((tid + 1) + ". tree was built in " + timer);
isolationTreeStats.updateBy(isolationTree);

boolean manualInterval = _parms._score_tree_interval > 0 && (tid +1) % _parms._score_tree_interval == 0;
boolean finalScoring = _parms._ntrees == (tid + 1);

_model._output._scored_train[tid + 1] = new ScoreKeeper();
if ((_parms._score_each_iteration || manualInterval || finalScoring) && !_parms._disable_training_metrics) {
ModelMetrics.MetricBuilder metricsBuilder = new ScoreExtendedIsolationForestTask(_model).doAll(_train).getMetricsBuilder();
ModelMetrics modelMetrics = metricsBuilder.makeModelMetrics(_model, _parms.train(), null, null);
_model._output._training_metrics = modelMetrics;
_model._output._scored_train[tid + 1].fillFrom(modelMetrics);
}
}
_model._output._scoring_history = _parms._disable_training_metrics ? null : createScoringHistoryTable();
}
}

Expand Down Expand Up @@ -237,4 +258,50 @@ public TwoDimTable createModelSummaryTable() {
return table;
}

protected TwoDimTable createScoringHistoryTable() {
List<String> colHeaders = new ArrayList<>();
List<String> colTypes = new ArrayList<>();
List<String> colFormat = new ArrayList<>();
colHeaders.add("Timestamp"); colTypes.add("string"); colFormat.add("%s");
colHeaders.add("Duration"); colTypes.add("string"); colFormat.add("%s");
colHeaders.add("Number of Trees"); colTypes.add("long"); colFormat.add("%d");
colHeaders.add("Mean Tree Path Length"); colTypes.add("double"); colFormat.add("%.5f");
colHeaders.add("Mean Anomaly Score"); colTypes.add("double"); colFormat.add("%.5f");
if (_parms._custom_metric_func != null) {
colHeaders.add("Training Custom"); colTypes.add("double"); colFormat.add("%.5f");
}

ScoreKeeper[] sks = _model._output._scored_train;

int rows = 0;
for (int i = 0; i < sks.length; i++) {
if (i != 0 && Double.isNaN(sks[i]._anomaly_score)) continue;
rows++;
}
TwoDimTable table = new TwoDimTable(
"Scoring History", null,
new String[rows],
colHeaders.toArray(new String[0]),
colTypes.toArray(new String[0]),
colFormat.toArray(new String[0]),
"");
int row = 0;
for( int i = 0; i<sks.length; i++ ) {
if (i != 0 && Double.isNaN(sks[i]._anomaly_score)) continue;
int col = 0;
DateTimeFormatter fmt = DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss");
table.set(row, col++, fmt.print(_model._output._training_time_ms[i]));
table.set(row, col++, PrettyPrint.msecs(_model._output._training_time_ms[i] - _job.start_time(), true));
table.set(row, col++, i);
ScoreKeeper st = sks[i];
table.set(row, col++, st._anomaly_score);
table.set(row, col++, st._anomaly_score_normalized);
if (_parms._custom_metric_func != null) {
table.set(row, col++, st._custom_metric);
}
assert col == colHeaders.size();
row++;
}
return table;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
import hex.Model;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ScoreKeeper;
import hex.tree.isofor.ModelMetricsAnomaly;
import hex.tree.isoforextended.isolationtree.CompressedIsolationTree;
import org.apache.log4j.Logger;
import water.*;
import water.fvec.Frame;

import static hex.genmodel.algos.isoforextended.ExtendedIsolationForestMojoModel.anomalyScore;
import static hex.genmodel.algos.isoforextended.ExtendedIsolationForestMojoModel.averagePathLengthOfUnsuccessfulSearch;

/**
*
Expand Down Expand Up @@ -46,13 +46,16 @@ protected double[] score0(double[] data, double[] preds) {
assert _output._iTreeKeys != null : "Output has no trees, check if trees are properly set to the output.";
// compute score for given point
double pathLength = 0;
int numberOfTrees = 0;
for (Key<CompressedIsolationTree> iTreeKey : _output._iTreeKeys) {
if (iTreeKey == null) continue;
numberOfTrees++;
CompressedIsolationTree iTree = DKV.getGet(iTreeKey);
double iTreeScore = iTree.computePathLength(data);
pathLength += iTreeScore;
LOG.trace("iTreeScore " + iTreeScore);
}
pathLength = pathLength / _output._ntrees;
pathLength = pathLength / numberOfTrees;
LOG.trace("Path length " + pathLength);
double anomalyScore = anomalyScore(pathLength, _output._sample_size);
LOG.trace("Anomaly score " + anomalyScore);
Expand Down Expand Up @@ -98,18 +101,32 @@ public long progressUnits() {
*/
public int _sample_size;

/**
* Score every so many trees (no matter what)
*/
public int _score_tree_interval;

/**
* Disable calculating training metrics (expensive on large datasets).
*/
public boolean _disable_training_metrics;

public ExtendedIsolationForestParameters() {
super();
_ntrees = 100;
_sample_size = 256;
_extension_level = 0;
_score_tree_interval = 0;
_disable_training_metrics = true;
}
}

public static class ExtendedIsolationForestOutput extends Model.Output {

public int _ntrees;
public long _sample_size;
public ScoreKeeper[] _scored_train;
public long[] _training_time_ms;

public Key<CompressedIsolationTree>[] _iTreeKeys;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package hex.tree.isoforextended;

import hex.tree.isofor.ModelMetricsAnomaly;
import water.MRTask;
import water.fvec.Chunk;

class ScoreExtendedIsolationForestTask extends MRTask<ScoreExtendedIsolationForestTask> {
private ExtendedIsolationForestModel _model;

// output
private ModelMetricsAnomaly.MetricBuilderAnomaly _metricsBuilder;

public ScoreExtendedIsolationForestTask(ExtendedIsolationForestModel _model) {
this._model = _model;
}

@Override
public void map(Chunk[] cs) {
_metricsBuilder = (ModelMetricsAnomaly.MetricBuilderAnomaly) _model.makeMetricBuilder(null);
double [] preds = new double[2];
double [] tmp = new double[cs.length];
for (int row = 0; row < cs[0]._len; row++) {
preds = _model.score0(cs, 0, row, tmp, preds);
_metricsBuilder.perRow(preds, null, _model);
}
}

@Override
public void reduce(ScoreExtendedIsolationForestTask other) {
_metricsBuilder.reduce(other._metricsBuilder);
}

public ModelMetricsAnomaly.MetricBuilderAnomaly getMetricsBuilder() {
return _metricsBuilder;
}
}
Loading

0 comments on commit 8d9304b

Please sign in to comment.