From f61607b26823edb4f2e2dbdda3bceb2b3e2eea57 Mon Sep 17 00:00:00 2001 From: Marcus Rosti Date: Wed, 27 Nov 2024 14:29:14 -0800 Subject: [PATCH 1/2] Appends prediction columns to transform schema --- .../relevance/isolationforest/IsolationForest.scala | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForest.scala b/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForest.scala index e885339..7429573 100644 --- a/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForest.scala +++ b/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForest.scala @@ -7,7 +7,7 @@ import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} import org.apache.spark.ml.Estimator -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DoubleType, StructField, StructType} import org.apache.spark.sql.Dataset import org.apache.spark.{HashPartitioner, TaskContext} @@ -200,7 +200,16 @@ class IsolationForest(override val uid: String) extends Estimator[IsolationFores require(schema($(featuresCol)).dataType == VectorType, s"Input column ${$(featuresCol)} is not of required type ${VectorType}") - val outputFields = schema.fields + val outputFields: Array[StructField] = schema.fields ++ Array( + StructField( + name = s"$predictionCol", + dataType = DoubleType + ), + StructField( + name = s"$scoreCol", + dataType = DoubleType + ) + ) StructType(outputFields) } From 4d49dca7b1b6292c291a3d876c4465b4ba776ef1 Mon Sep 17 00:00:00 2001 From: Marcus Rosti Date: Wed, 27 Nov 2024 14:32:06 -0800 Subject: [PATCH 2/2] fixes the comment --- .../linkedin/relevance/isolationforest/IsolationForest.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForest.scala b/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForest.scala index 7429573..1864926 100644 --- a/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForest.scala +++ b/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForest.scala @@ -187,8 +187,8 @@ class IsolationForest(override val uid: String) extends Estimator[IsolationFores /** * Validates the input schema and transforms it into the output schema. It validates that the - * input DataFrame has a $(featuresCol) of the correct type. In this case, the output schema is - * identical to the input schema. + * input DataFrame has a $(featuresCol) of the correct type. In this case, the output schema appends + * the output columns to the input schema. * * @param schema The schema of the DataFrame containing the data to be fit. * @return The schema of the DataFrame containing the data to be fit.