Skip to content

Commit

Permalink
read from file
Browse files Browse the repository at this point in the history
  • Loading branch information
vinicvaz committed Nov 22, 2023
1 parent ddd6c2a commit 6120024
Showing 2 changed files with 22 additions and 2 deletions.
10 changes: 8 additions & 2 deletions pieces/TrainTestSplitPiece/models.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
from pydantic import BaseModel, Field
from typing import List
from typing import List, Optional


class InputModel(BaseModel):
"""
Input data for TextSummarizerPiece
"""
data: List[dict] = Field(
data: Optional[List[dict]] = Field(
title="Data",
description="The data to be split.",
json_schema_extra={"from_upstream": "always"}
)
data_path: Optional[str] = Field(
title="Data Path",
default=None,
description="The path to the data to be split.",
json_schema_extra={"from_upstream": "always"}
)
test_data_size: float = Field(
default=0.8,
description="The size (%) of the test data.",
14 changes: 14 additions & 0 deletions pieces/TrainTestSplitPiece/piece.py
Original file line number Diff line number Diff line change
@@ -6,10 +6,24 @@

class TrainTestSplitPiece(BasePiece):

def read_data_from_file(self, path):
"""
Read data from a file.
"""
if path.endswith(".csv"):
return pd.read_csv(path).to_dict(orient='records')
elif path.endswith(".json"):
return pd.read_json(path).to_dict(orient='records')
else:
raise ValueError("File type not supported.")

def piece_function(self, input_data: InputModel):
"""
Split the data into training and test sets.
"""
if input_data.data_path is not None:
input_data.data = self.read_data_from_file(path=input_data.data_path)

df = pd.DataFrame(input_data.data)
if "target" not in df.columns:
raise ValueError("Target column not found in data with name 'target'.")

0 comments on commit 6120024

Please sign in to comment.