Skip to content

Commit

Permalink
[TypeTransformer] Support frozen dataclasses (#2823)
Browse files Browse the repository at this point in the history
* [TypeTransformer] Support frozen dataclasses

Signed-off-by: Future-Outlier <[email protected]>

* break tests

Signed-off-by: Future-Outlier <[email protected]>

* use the base class object's magic method, Suggested by Eduardo, this is great

Signed-off-by: Future-Outlier <[email protected]>
Co-authored-by: Eduardo Apolinario  <[email protected]>

* add tests

Signed-off-by: Future-Outlier <[email protected]>

---------

Signed-off-by: Future-Outlier <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
Future-Outlier and eapolinario authored Oct 23, 2024
1 parent 3fc51af commit f79f51d
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 3 deletions.
6 changes: 3 additions & 3 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,7 @@ def _fix_structured_dataset_type(self, python_type: Type[T], python_val: typing.
elif dataclasses.is_dataclass(python_type):
for field in dataclasses.fields(python_type):
val = python_val.__getattribute__(field.name)
python_val.__setattr__(field.name, self._fix_structured_dataset_type(field.type, val))
object.__setattr__(python_val, field.name, self._fix_structured_dataset_type(field.type, val))
return python_val

def _make_dataclass_serializable(self, python_val: T, python_type: Type[T]) -> typing.Any:
Expand Down Expand Up @@ -718,7 +718,7 @@ def _make_dataclass_serializable(self, python_val: T, python_type: Type[T]) -> t
dataclass_attributes = typing.get_type_hints(python_type)
for n, t in dataclass_attributes.items():
val = python_val.__getattribute__(n)
python_val.__setattr__(n, self._make_dataclass_serializable(val, t))
object.__setattr__(python_val, n, self._make_dataclass_serializable(val, t))
return python_val

def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any:
Expand Down Expand Up @@ -761,7 +761,7 @@ def _fix_dataclass_int(self, dc_type: Type[dataclasses.dataclass], dc: typing.An
# Thus we will have to walk the given dataclass and typecast values to int, where expected.
for f in dataclasses.fields(dc_type):
val = getattr(dc, f.name)
setattr(dc, f.name, self._fix_val_int(f.type, val))
object.__setattr__(dc, f.name, self._fix_val_int(f.type, val))

return dc

Expand Down
167 changes: 167 additions & 0 deletions tests/flytekit/unit/core/test_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,3 +951,170 @@ def my_task(dc: DC) -> DC:
return dc

my_task(dc=DC())

def test_frozen_dataclass():
@dataclass(frozen=True)
class FrozenDataclass:
a: int = 1
b: float = 2.0
c: bool = True
d: str = "hello"

@task
def t1(dc: FrozenDataclass) -> (int, float, bool, str):
return dc.a, dc.b, dc.c, dc.d

a, b, c, d = t1(dc=FrozenDataclass())
assert a == 1
assert b == 2.0
assert c == True
assert d == "hello"

def test_pure_frozen_dataclasses_with_python_types():
@dataclass(frozen=True)
class DC:
string: Optional[str] = None

@dataclass(frozen=True)
class DCWithOptional:
string: Optional[str] = None
dc: Optional[DC] = None
list_dc: Optional[List[DC]] = None
list_list_dc: Optional[List[List[DC]]] = None
dict_dc: Optional[Dict[str, DC]] = None
dict_dict_dc: Optional[Dict[str, Dict[str, DC]]] = None
dict_list_dc: Optional[Dict[str, List[DC]]] = None
list_dict_dc: Optional[List[Dict[str, DC]]] = None

@task
def t1() -> DCWithOptional:
return DCWithOptional(string="a", dc=DC(string="b"),
list_dc=[DC(string="c"), DC(string="d")],
list_list_dc=[[DC(string="e"), DC(string="f")]],
list_dict_dc=[{"g": DC(string="h"), "i": DC(string="j")},
{"k": DC(string="l"), "m": DC(string="n")}],
dict_dc={"o": DC(string="p"), "q": DC(string="r")},
dict_dict_dc={"s": {"t": DC(string="u"), "v": DC(string="w")}},
dict_list_dc={"x": [DC(string="y"), DC(string="z")],
"aa": [DC(string="bb"), DC(string="cc")]},)

@task
def t2() -> DCWithOptional:
return DCWithOptional()

output = DCWithOptional(string="a", dc=DC(string="b"),
list_dc=[DC(string="c"), DC(string="d")],
list_list_dc=[[DC(string="e"), DC(string="f")]],
list_dict_dc=[{"g": DC(string="h"), "i": DC(string="j")},
{"k": DC(string="l"), "m": DC(string="n")}],
dict_dc={"o": DC(string="p"), "q": DC(string="r")},
dict_dict_dc={"s": {"t": DC(string="u"), "v": DC(string="w")}},
dict_list_dc={"x": [DC(string="y"), DC(string="z")],
"aa": [DC(string="bb"), DC(string="cc")]}, )

dc1 = t1()
dc2 = t2()

assert dc1 == output
assert dc2.string is None
assert dc2.dc is None

DataclassTransformer().assert_type(DCWithOptional, dc1)
DataclassTransformer().assert_type(DCWithOptional, dc2)

def test_pure_frozen_dataclasses_with_flyte_types(local_dummy_txt_file, local_dummy_directory):
@dataclass(frozen=True)
class FlyteTypes:
flytefile: Optional[FlyteFile] = None
flytedir: Optional[FlyteDirectory] = None
structured_dataset: Optional[StructuredDataset] = None

@dataclass(frozen=True)
class NestedFlyteTypes:
flytefile: Optional[FlyteFile] = None
flytedir: Optional[FlyteDirectory] = None
structured_dataset: Optional[StructuredDataset] = None
flyte_types: Optional[FlyteTypes] = None
list_flyte_types: Optional[List[FlyteTypes]] = None
dict_flyte_types: Optional[Dict[str, FlyteTypes]] = None
optional_flyte_types: Optional[FlyteTypes] = None

@task
def pass_and_return_flyte_types(nested_flyte_types: NestedFlyteTypes) -> NestedFlyteTypes:
return nested_flyte_types

@task
def generate_sd() -> StructuredDataset:
return StructuredDataset(
uri="s3://my-s3-bucket/data/test_sd",
file_format="parquet")

@task
def create_local_dir(path: str) -> FlyteDirectory:
return FlyteDirectory(path=path)

@task
def create_local_dir_by_str(path: str) -> FlyteDirectory:
return path

@task
def create_local_file(path: str) -> FlyteFile:
return FlyteFile(path=path)

@task
def create_local_file_with_str(path: str) -> FlyteFile:
return path

@task
def generate_nested_flyte_types(local_file: FlyteFile, local_dir: FlyteDirectory, sd: StructuredDataset,
local_file_by_str: FlyteFile,
local_dir_by_str: FlyteDirectory, ) -> NestedFlyteTypes:
ft = FlyteTypes(
flytefile=local_file,
flytedir=local_dir,
structured_dataset=sd,
)

return NestedFlyteTypes(
flytefile=local_file,
flytedir=local_dir,
structured_dataset=sd,
flyte_types=FlyteTypes(
flytefile=local_file_by_str,
flytedir=local_dir_by_str,
structured_dataset=sd,
),
list_flyte_types=[ft, ft, ft],
dict_flyte_types={"a": ft, "b": ft, "c": ft},
)

@workflow
def nested_dc_wf(txt_path: str, dir_path: str) -> NestedFlyteTypes:
local_file = create_local_file(path=txt_path)
local_dir = create_local_dir(path=dir_path)
local_file_by_str = create_local_file_with_str(path=txt_path)
local_dir_by_str = create_local_dir_by_str(path=dir_path)
sd = generate_sd()
nested_flyte_types = generate_nested_flyte_types(
local_file=local_file,
local_dir=local_dir,
local_file_by_str=local_file_by_str,
local_dir_by_str=local_dir_by_str,
sd=sd
)
old_flyte_types = pass_and_return_flyte_types(nested_flyte_types=nested_flyte_types)
return pass_and_return_flyte_types(nested_flyte_types=old_flyte_types)

@task
def get_empty_nested_type() -> NestedFlyteTypes:
return NestedFlyteTypes()

@workflow
def empty_nested_dc_wf() -> NestedFlyteTypes:
return get_empty_nested_type()

nested_flyte_types = nested_dc_wf(txt_path=local_dummy_txt_file, dir_path=local_dummy_directory)
DataclassTransformer().assert_type(NestedFlyteTypes, nested_flyte_types)

empty_nested_flyte_types = empty_nested_dc_wf()
DataclassTransformer().assert_type(NestedFlyteTypes, empty_nested_flyte_types)

0 comments on commit f79f51d

Please sign in to comment.