diff --git a/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForest.scala b/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForest.scala index 1864926..42864da 100644 --- a/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForest.scala +++ b/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForest.scala @@ -187,11 +187,13 @@ class IsolationForest(override val uid: String) extends Estimator[IsolationFores /** * Validates the input schema and transforms it into the output schema. It validates that the - * input DataFrame has a $(featuresCol) of the correct type. In this case, the output schema appends - * the output columns to the input schema. + * input DataFrame has a $(featuresCol) of the correct type and appends the output columns to + * the input schema. It also ensures that the input DataFrame does not already have + * $(predictionCol) or $(scoreCol) columns, as they will be created during the fitting process. * * @param schema The schema of the DataFrame containing the data to be fit. - * @return The schema of the DataFrame containing the data to be fit. + * @return The schema of the DataFrame containing the data to be fit, with the additional + * $(predictionCol) and $(scoreCol) columns added. */ override def transformSchema(schema: StructType): StructType = { @@ -200,16 +202,14 @@ class IsolationForest(override val uid: String) extends Estimator[IsolationFores require(schema($(featuresCol)).dataType == VectorType, s"Input column ${$(featuresCol)} is not of required type ${VectorType}") - val outputFields: Array[StructField] = schema.fields ++ Array( - StructField( - name = s"$predictionCol", - dataType = DoubleType - ), - StructField( - name = s"$scoreCol", - dataType = DoubleType - ) - ) + require(!schema.fieldNames.contains($(predictionCol)), + s"Output column ${$(predictionCol)} already exists.") + require(!schema.fieldNames.contains($(scoreCol)), + s"Output column ${$(scoreCol)} already exists.") + + val outputFields = schema.fields :+ + StructField($(predictionCol), DoubleType, nullable = false) :+ + StructField($(scoreCol), DoubleType, nullable = false) StructType(outputFields) } diff --git a/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForestModel.scala b/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForestModel.scala index 30c8c29..11b7518 100644 --- a/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForestModel.scala +++ b/isolation-forest/src/main/scala/com/linkedin/relevance/isolationforest/IsolationForestModel.scala @@ -89,9 +89,9 @@ class IsolationForestModel( /** * Validates the input schema and transforms it into the output schema. It validates that the - * input DataFrame has a $(featuresCol) of the correct type. It also ensures that the input - * DataFrame does not already have $(predictionCol) or $(scoreCol) columns, as they will be - * created during the fitting process. + * input DataFrame has a $(featuresCol) of the correct type and appends the output columns to + * the input schema. It also ensures that the input DataFrame does not already have + * $(predictionCol) or $(scoreCol) columns, as they will be created during the fitting process. * * @param schema The schema of the DataFrame containing the data to be fit. * @return The schema of the DataFrame containing the data to be fit, with the additional