Skip to content

Commit

Permalink
added example of multiplier transformer in PipelineTest to show how c…
Browse files Browse the repository at this point in the history
…olumns transformations can be easily implemented and applied declaratively
  • Loading branch information
sebhrusen committed Feb 6, 2024
1 parent 009c161 commit 50c80df
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 8 deletions.
34 changes: 31 additions & 3 deletions h2o-core/src/test/java/hex/pipeline/DataTransformerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
import org.junit.Test;
import org.junit.runner.RunWith;
import water.Key;
import water.MRTask;
import water.Scope;
import water.fvec.Frame;
import water.fvec.TestFrameBuilder;
import water.fvec.Vec;
import water.fvec.*;
import water.logging.Logger;
import water.logging.LoggerFactory;
import water.runner.CloudSize;
Expand Down Expand Up @@ -44,6 +43,35 @@ public void test_transform() {
}
}

/**
 * Test transformer that multiplies every value of one numeric column by a fixed integer factor.
 * <p>
 * The transformation works on a new {@link Frame} constructed from the input frame and swaps the
 * target column for a freshly computed one; the column of the input frame is not written to.
 */
public static class MultiplyNumericColumnTransformer extends DataTransformer<MultiplyNumericColumnTransformer> {

  /** Name of the column to multiply. */
  private final String colName;

  /** Factor applied to each value of the column. */
  private final int multiplier;

  /**
   * @param colName    name of the numeric column to transform.
   * @param multiplier factor each value of that column is multiplied by.
   */
  public MultiplyNumericColumnTransformer(String colName, int multiplier) {
    this.colName = colName;
    this.multiplier = multiplier;
  }

  /**
   * Builds a new frame in which column {@code colName} holds the original values times {@code multiplier}.
   *
   * @param fr      the frame to transform.
   * @param type    the role of the frame in the pipeline (unused here).
   * @param context the pipeline context (unused here).
   * @return a new frame with the multiplied column in place of the original one.
   * @throws IllegalArgumentException if {@code fr} has no column named {@code colName}.
   */
  @Override
  protected Frame doTransform(Frame fr, FrameType type, PipelineContext context) {
    Frame tr = new Frame(fr);
    final int colIdx = tr.find(colName);
    // Fail fast with a readable message instead of an index error when the column is missing.
    if (colIdx < 0) {
      throw new IllegalArgumentException("Column '" + colName + "' not found in frame");
    }
    Vec col = tr.vec(colIdx);
    assert col.isNumeric() : "Column '" + colName + "' must be numeric";
    Vec multCol = new MRTask() {
      @Override
      public void map(Chunk[] cs, NewChunk[] ncs) {
        // Single input column and single output column: hoist the lookups out of the loop.
        final Chunk in = cs[0];
        final NewChunk out = ncs[0];
        for (int i = 0; i < in._len; i++) {
          out.addNum(multiplier * in.atd(i));
        }
      }
    }.doAll(Vec.T_NUM, new Frame(col)).outputFrame().vec(0);
    tr.replace(colIdx, multCol);
    return tr;
  }
}


public static class AddRandomColumnTransformer extends DataTransformer<AddRandomColumnTransformer> {

Expand Down
18 changes: 13 additions & 5 deletions h2o-core/src/test/java/hex/pipeline/PipelineTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@

import hex.Model;
import hex.ModelBuilder;
import hex.pipeline.DataTransformerTest.AddDummyCVColumnTransformer;
import hex.pipeline.DataTransformerTest.AddRandomColumnTransformer;
import hex.pipeline.DataTransformerTest.FrameCheckerAsTransformer;
import hex.pipeline.DataTransformerTest.FrameTrackerAsTransformer;
import hex.pipeline.DataTransformerTest.*;
import hex.pipeline.PipelineModel.PipelineOutput;
import hex.pipeline.PipelineModel.PipelineParameters;
import org.junit.Rule;
Expand Down Expand Up @@ -56,6 +53,7 @@ public void test_simple_transformation_pipeline() {
PipelineParameters pparams = new PipelineParameters();
FrameTrackerAsTransformer tracker = new FrameTrackerAsTransformer();
pparams._transformers = new DataTransformer[] {
new MultiplyNumericColumnTransformer("two", 5).id("mult_5"),
new AddRandomColumnTransformer("foo").id("add_foo"),
new AddRandomColumnTransformer("bar").id("add_bar"),
tracker.id("tracker")
Expand All @@ -67,6 +65,9 @@ public void test_simple_transformation_pipeline() {
.withDataForCol(1, ard(3, 2, 1))
.withDataForCol(2, ar("yes", "no", "yes"))
.build());

Vec notMult = fr.vec(1).makeCopy();
Vec mult = fr.vec(1).makeZero(); mult.set(0, 3*5); mult.set(1, 2*5); mult.set(2, 1*5); // a 5-mult vec of column "two" hand-made for test assertions.

pparams._train = fr._key;

Expand All @@ -78,9 +79,10 @@ public void test_simple_transformation_pipeline() {
assertNotNull(output);
assertNull(output._estimator);
assertNotNull(output._transformers);
assertEquals(3, output._transformers.length);
assertEquals(4, output._transformers.length);
assertEquals(0, tracker.transformations.size());
checkFrameState(fr);
assertVecEquals(notMult, fr.vec(1), 0);

Frame scored = Scope.track(pmodel.score(fr));
assertNotNull(scored);
Expand All @@ -89,6 +91,8 @@ public void test_simple_transformation_pipeline() {
assertArrayEquals(new String[] {"one", "two", "target", "foo", "bar"}, scored.names());
checkFrameState(fr);
checkFrameState(scored);
assertVecEquals(notMult, fr.vec(1), 0);
assertVecEquals(mult, scored.vec(1), 0);

Frame rescored = Scope.track(pmodel.score(fr));
TestUtil.printOutFrameAsTable(rescored);
Expand All @@ -97,6 +101,8 @@ public void test_simple_transformation_pipeline() {
assertFrameEquals(scored, rescored, 1.6);
checkFrameState(fr);
checkFrameState(rescored);
assertVecEquals(notMult, fr.vec(1), 0);
assertVecEquals(mult, rescored.vec(1), 0);

Frame transformed = Scope.track(pmodel.transform(fr));
TestUtil.printOutFrameAsTable(transformed);
Expand All @@ -105,6 +111,8 @@ public void test_simple_transformation_pipeline() {
assertFrameEquals(scored, transformed, 1.6);
checkFrameState(fr);
checkFrameState(transformed);
assertVecEquals(notMult, fr.vec(1), 0);
assertVecEquals(mult, transformed.vec(1), 0);
}

@Test
Expand Down

0 comments on commit 50c80df

Please sign in to comment.