Made minor edits to match the proposed Estimator transformSchema method to the existing Model transformSchema method. #61

Merged (1 commit) on Dec 17, 2024
@@ -187,11 +187,13 @@ class IsolationForest(override val uid: String) extends Estimator[IsolationFores

   /**
    * Validates the input schema and transforms it into the output schema. It validates that the
-   * input DataFrame has a $(featuresCol) of the correct type. In this case, the output schema appends
-   * the output columns to the input schema.
+   * input DataFrame has a $(featuresCol) of the correct type and appends the output columns to
+   * the input schema. It also ensures that the input DataFrame does not already have
+   * $(predictionCol) or $(scoreCol) columns, as they will be created during the fitting process.
    *
    * @param schema The schema of the DataFrame containing the data to be fit.
-   * @return The schema of the DataFrame containing the data to be fit.
+   * @return The schema of the DataFrame containing the data to be fit, with the additional
+   * $(predictionCol) and $(scoreCol) columns added.
    */
   override def transformSchema(schema: StructType): StructType = {

@@ -200,16 +202,14 @@ class IsolationForest(override val uid: String) extends Estimator[IsolationFores
     require(schema($(featuresCol)).dataType == VectorType,
       s"Input column ${$(featuresCol)} is not of required type ${VectorType}")

-    val outputFields: Array[StructField] = schema.fields ++ Array(
-      StructField(
-        name = s"$predictionCol",
-        dataType = DoubleType
-      ),
-      StructField(
-        name = s"$scoreCol",
-        dataType = DoubleType
-      )
-    )
+    require(!schema.fieldNames.contains($(predictionCol)),
+      s"Output column ${$(predictionCol)} already exists.")
+    require(!schema.fieldNames.contains($(scoreCol)),
+      s"Output column ${$(scoreCol)} already exists.")
+
+    val outputFields = schema.fields :+
+      StructField($(predictionCol), DoubleType, nullable = false) :+
+      StructField($(scoreCol), DoubleType, nullable = false)

     StructType(outputFields)
   }
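
To make the effect of the added checks concrete, here is a minimal sketch of the same validation as a stand-alone function. It assumes Spark's `spark-sql` and `spark-mllib` are on the classpath. The object name `TransformSchemaSketch`, the helper `appendOutputColumns`, and the column names in `main` are hypothetical and not part of this PR; plain `String` arguments stand in for the `$(...)` Param lookups used in the real method.

```scala
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
import scala.util.Try

// Hypothetical stand-alone sketch; the real logic lives in transformSchema.
object TransformSchemaSketch {

  // Plain strings replace the $(featuresCol), $(predictionCol), $(scoreCol) lookups.
  def appendOutputColumns(
      schema: StructType,
      featuresCol: String,
      predictionCol: String,
      scoreCol: String): StructType = {

    // The features column must hold ML vectors.
    require(schema(featuresCol).dataType == VectorType,
      s"Input column $featuresCol is not of required type $VectorType")

    // Reject schemas that already contain either output column.
    require(!schema.fieldNames.contains(predictionCol),
      s"Output column $predictionCol already exists.")
    require(!schema.fieldNames.contains(scoreCol),
      s"Output column $scoreCol already exists.")

    // Append the two output columns to the input fields.
    StructType(
      schema.fields :+
        StructField(predictionCol, DoubleType, nullable = false) :+
        StructField(scoreCol, DoubleType, nullable = false))
  }

  def main(args: Array[String]): Unit = {
    val input = StructType(Seq(StructField("features", VectorType, nullable = false)))

    // First pass: the two output columns are appended after "features".
    val output = appendOutputColumns(input, "features", "predictedLabel", "outlierScore")
    println(output.fieldNames.mkString(", "))

    // Second pass on the already-extended schema: require throws IllegalArgumentException.
    val rerun = Try(appendOutputColumns(output, "features", "predictedLabel", "outlierScore"))
    println(rerun.isFailure)
  }
}
```

Failing at schema-validation time means a name collision surfaces before any data is processed, which appears to be the intent of aligning the Estimator with the Model here.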
@@ -89,9 +89,9 @@ class IsolationForestModel(

   /**
    * Validates the input schema and transforms it into the output schema. It validates that the
-   * input DataFrame has a $(featuresCol) of the correct type. It also ensures that the input
-   * DataFrame does not already have $(predictionCol) or $(scoreCol) columns, as they will be
-   * created during the fitting process.
+   * input DataFrame has a $(featuresCol) of the correct type and appends the output columns to
+   * the input schema. It also ensures that the input DataFrame does not already have
+   * $(predictionCol) or $(scoreCol) columns, as they will be created during the fitting process.
    *
    * @param schema The schema of the DataFrame containing the data to be fit.
    * @return The schema of the DataFrame containing the data to be fit, with the additional