Skip to content

Commit

Permalink
add standard scaler and move heavy piece to not buildd in dev
Browse files Browse the repository at this point in the history
  • Loading branch information
vinicvaz committed Nov 22, 2023
1 parent ad0a254 commit d489280
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 1 deletion.
18 changes: 18 additions & 0 deletions pieces/StandardScalerPiece/metadata.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{
"name": "StandardScalerPiece",
"description": "Apply StandardScaler to the data",
"dependency": {
"requirements_file": "requirements_0.txt"
},
"tags": [
"preprocessing",
"scaler"
],
"style": {
"node_label": "Standard Scaler",
"node_style": {
"backgroundColor": "#b3cde8"
},
"icon_class_name": "icon-park-outline:split"
}
}
26 changes: 26 additions & 0 deletions pieces/StandardScalerPiece/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from pydantic import BaseModel, Field
from typing import List


class InputModel(BaseModel):
"""
Input data for TextSummarizerPiece
"""
train_data: List[dict] = Field(
title="Train Data",
description="The train data to be scaled.",
json_schema_extra={"from_upstream": "always"}
)
test_data: List[dict] = Field(
title="Test Data",
description="The test data to be scaled.",
json_schema_extra={"from_upstream": "always"}
)


class OutputModel(BaseModel):
"""
Output data for TextSummarizerPiece
"""
train_data: List[dict]
test_data: List[dict]
29 changes: 29 additions & 0 deletions pieces/StandardScalerPiece/piece.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from domino.base_piece import BasePiece
from .models import InputModel, OutputModel
import pandas as pd
from sklearn.preprocessing import StandardScaler


class StandardScalerPiece(BasePiece):

def piece_function(self, input_data: InputModel):
df_train = pd.DataFrame(input_data.train_data)
df_test = pd.DataFrame(input_data.test_data)

if "target" not in df_train.columns or "target" not in df_test.columns:
raise ValueError("Target column not found in data with name 'target'.")


scaler = StandardScaler()
scaler.fit(df_train.drop('target', axis=1))
X_train = scaler.transform(df_train.drop('target', axis=1))
X_test = scaler.transform(df_test.drop('target', axis=1))

df_train_scaled = pd.DataFrame(X_train, columns=df_train.drop('target', axis=1).columns)
df_train_scaled['target'] = df_train['target']
df_test_scaled = pd.DataFrame(X_test, columns=df_test.drop('target', axis=1).columns)
df_test_scaled['target'] = df_test['target']

return OutputModel(train_data=df_train_scaled.to_dict(orient='records'), test_data=df_test_scaled.to_dict(orient='records'))


2 changes: 1 addition & 1 deletion pieces/TrainTestSplitPiece/metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"requirements_file": "requirements_0.txt"
},
"tags": [
"default",
"preprocessing",
"datasets"
],
"style": {
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 comments on commit d489280

Please sign in to comment.