Skip to content

Commit

Permalink
Fixing errors with nested subfields
Browse files Browse the repository at this point in the history
  • Loading branch information
ccl-core committed Nov 29, 2024
1 parent 4e30d0d commit 86d51a0
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def _apply_transform_fn(value: Any, transform: Transform, field: Field) -> Any:
return group
elif transform.json_path is not None:
jsonpath_expression = _parse_jsonpath(transform.json_path)
return next(match.value for match in jsonpath_expression.find(value))
if matches := jsonpath_expression.find(value):
return next(match.value for match in matches)
return None
elif transform.format is not None:
if field.data_type is pd.Timestamp:
return pd.Timestamp(value).strftime(transform.format)
Expand Down Expand Up @@ -105,7 +107,6 @@ def _cast_value(ctx: Context, value: Any, data_type: type | term.URIRef | None):
else:
return data_type(value)


def _to_bytes(value: Any) -> bytes:
"""Casts the value `value` to bytes."""
if isinstance(value, bytes):
Expand Down Expand Up @@ -185,16 +186,15 @@ class ReadFields(Operation):

node: RecordSet

def _fields(self) -> list[Field]:
"""Extracts all fields (i.e., direct fields without subFields and subFields)."""
fields: list[Field] = []
for field in self.node.fields:
def _get_fields(self, fields: list[Field]):
"""Extracts all leaves fields (i.e., including subFields)."""
all_fields: list[Field] = []
for field in fields:
if field.sub_fields:
for sub_field in field.sub_fields:
fields.append(sub_field)
all_fields.extend(self._get_fields(field.sub_fields))
else:
fields.append(field)
return fields
all_fields.append(field)
return all_fields

def call(self, df: pd.DataFrame) -> Iterator[dict[str, Any]]:
"""See class' docstring."""
Expand All @@ -203,7 +203,9 @@ def call(self, df: pd.DataFrame) -> Iterator[dict[str, Any]]:
for _, row in df.iterrows():
yield dict(row)
return
fields = self._fields()
fields = self._get_fields(self.node.fields)
if not fields:
raise ValueError(f"RecordSet {self.node.uuid} has no fields!")
for field in fields:
df = _extract_value(df, field)

Expand All @@ -226,9 +228,12 @@ def _get_result(row):
if _is_na(value):
value = None
elif is_repeated:
value = [
_cast_value(self.node.ctx, v, field.data_type) for v in value
]
try:
value = [
_cast_value(self.node.ctx, v, field.data_type) for v in value
]
except TypeError:
value = value
else:
value = _cast_value(self.node.ctx, value, field.data_type)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from etils import epath
from rdflib.namespace import SDO
import requests
from typing import Any
from typing_extensions import Self

from mlcroissant._src.core import constants
Expand Down Expand Up @@ -299,6 +300,15 @@ class Metadata(Node):
url=constants.ML_COMMONS_RAI_DATA_RELEASE_MAINTENANCE_PLAN,
)

def _define_field_parents(self, fields: list[Field], parents: list[Any]):
"""Recursively populate the field's and subfield's parents."""
for field in fields:
field.parents = [self] + parents
if field.sub_fields:
field_parents = parents + [field]
self._define_field_parents(fields = field.sub_fields, parents=field_parents)


def __post_init__(self):
"""Checks arguments of the node and setup ID."""
Node.__post_init__(self)
Expand All @@ -307,10 +317,7 @@ def __post_init__(self):
node.parents = [self]
for record_set in self.record_sets:
record_set.parents = [self]
for field in record_set.fields:
field.parents = [self, record_set]
for sub_field in field.sub_fields:
sub_field.parents = [self, record_set, field]
self._define_field_parents(record_set.fields, parents=[record_set])

# Back-fill the graph in every node.
self.ctx.graph = from_nodes_to_graph(self)
Expand Down

0 comments on commit 86d51a0

Please sign in to comment.