Skip to content

Commit

Permalink
Made minor edits to match the proposed Estimator transformSchema meth…
Browse files Browse the repository at this point in the history
…od to the existing Model transformSchema method. (#61)
  • Loading branch information
jverbus authored Dec 17, 2024
1 parent bc58518 commit bd9ed50
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,13 @@ class IsolationForest(override val uid: String) extends Estimator[IsolationFores

/**
* Validates the input schema and transforms it into the output schema. It validates that the
* input DataFrame has a $(featuresCol) of the correct type. In this case, the output schema appends
* the output columns to the input schema.
* input DataFrame has a $(featuresCol) of the correct type and appends the output columns to
* the input schema. It also ensures that the input DataFrame does not already have
* $(predictionCol) or $(scoreCol) columns, as they will be created during the fitting process.
*
* @param schema The schema of the DataFrame containing the data to be fit.
* @return The schema of the DataFrame containing the data to be fit.
* @return The schema of the DataFrame containing the data to be fit, with the additional
* $(predictionCol) and $(scoreCol) columns added.
*/
override def transformSchema(schema: StructType): StructType = {

Expand All @@ -200,16 +202,14 @@ class IsolationForest(override val uid: String) extends Estimator[IsolationFores
require(schema($(featuresCol)).dataType == VectorType,
s"Input column ${$(featuresCol)} is not of required type ${VectorType}")

val outputFields: Array[StructField] = schema.fields ++ Array(
StructField(
name = s"$predictionCol",
dataType = DoubleType
),
StructField(
name = s"$scoreCol",
dataType = DoubleType
)
)
require(!schema.fieldNames.contains($(predictionCol)),
s"Output column ${$(predictionCol)} already exists.")
require(!schema.fieldNames.contains($(scoreCol)),
s"Output column ${$(scoreCol)} already exists.")

val outputFields = schema.fields :+
StructField($(predictionCol), DoubleType, nullable = false) :+
StructField($(scoreCol), DoubleType, nullable = false)

StructType(outputFields)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ class IsolationForestModel(

/**
* Validates the input schema and transforms it into the output schema. It validates that the
* input DataFrame has a $(featuresCol) of the correct type. It also ensures that the input
* DataFrame does not already have $(predictionCol) or $(scoreCol) columns, as they will be
* created during the fitting process.
* input DataFrame has a $(featuresCol) of the correct type and appends the output columns to
* the input schema. It also ensures that the input DataFrame does not already have
* $(predictionCol) or $(scoreCol) columns, as they will be created during the fitting process.
*
* @param schema The schema of the DataFrame containing the data to be fit.
* @return The schema of the DataFrame containing the data to be fit, with the additional
Expand Down

0 comments on commit bd9ed50

Please sign in to comment.