From 86d51a0efe784e54d748e22717d3e0ff84918def Mon Sep 17 00:00:00 2001 From: ccl-core Date: Fri, 29 Nov 2024 14:54:22 +0000 Subject: [PATCH] Fixing errors with nested subfields --- .../_src/operation_graph/operations/field.py | 33 +++++++++++-------- .../_src/structure_graph/nodes/metadata.py | 15 ++++++--- 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/python/mlcroissant/mlcroissant/_src/operation_graph/operations/field.py b/python/mlcroissant/mlcroissant/_src/operation_graph/operations/field.py index d945c1f8d..964ead78f 100644 --- a/python/mlcroissant/mlcroissant/_src/operation_graph/operations/field.py +++ b/python/mlcroissant/mlcroissant/_src/operation_graph/operations/field.py @@ -53,7 +53,9 @@ def _apply_transform_fn(value: Any, transform: Transform, field: Field) -> Any: return group elif transform.json_path is not None: jsonpath_expression = _parse_jsonpath(transform.json_path) - return next(match.value for match in jsonpath_expression.find(value)) + if matches := jsonpath_expression.find(value): + return next(match.value for match in matches) + return None elif transform.format is not None: if field.data_type is pd.Timestamp: return pd.Timestamp(value).strftime(transform.format) @@ -105,7 +107,6 @@ def _cast_value(ctx: Context, value: Any, data_type: type | term.URIRef | None): else: return data_type(value) - def _to_bytes(value: Any) -> bytes: """Casts the value `value` to bytes.""" if isinstance(value, bytes): @@ -185,16 +186,15 @@ class ReadFields(Operation): node: RecordSet - def _fields(self) -> list[Field]: - """Extracts all fields (i.e., direct fields without subFields and subFields).""" - fields: list[Field] = [] - for field in self.node.fields: + def _get_fields(self, fields: list[Field]): + """Extracts all leaves fields (i.e., including subFields).""" + all_fields: list[Field] = [] + for field in fields: if field.sub_fields: - for sub_field in field.sub_fields: - fields.append(sub_field) + all_fields.extend(self._get_fields(field.sub_fields)) else: - fields.append(field) - return fields + all_fields.append(field) + return all_fields def call(self, df: pd.DataFrame) -> Iterator[dict[str, Any]]: """See class' docstring.""" @@ -203,7 +203,9 @@ def call(self, df: pd.DataFrame) -> Iterator[dict[str, Any]]: for _, row in df.iterrows(): yield dict(row) return - fields = self._fields() + fields = self._get_fields(self.node.fields) + if not fields: + raise ValueError(f"RecordSet {self.node.uuid} has no fields!") for field in fields: df = _extract_value(df, field) @@ -226,9 +228,12 @@ def _get_result(row): if _is_na(value): value = None elif is_repeated: - value = [ - _cast_value(self.node.ctx, v, field.data_type) for v in value - ] + try: + value = [ + _cast_value(self.node.ctx, v, field.data_type) for v in value + ] + except TypeError: + value = value else: value = _cast_value(self.node.ctx, value, field.data_type) diff --git a/python/mlcroissant/mlcroissant/_src/structure_graph/nodes/metadata.py b/python/mlcroissant/mlcroissant/_src/structure_graph/nodes/metadata.py index 8d81ed0e2..70b0c69b0 100644 --- a/python/mlcroissant/mlcroissant/_src/structure_graph/nodes/metadata.py +++ b/python/mlcroissant/mlcroissant/_src/structure_graph/nodes/metadata.py @@ -5,6 +5,7 @@ from etils import epath from rdflib.namespace import SDO import requests +from typing import Any from typing_extensions import Self from mlcroissant._src.core import constants @@ -299,6 +300,15 @@ class Metadata(Node): url=constants.ML_COMMONS_RAI_DATA_RELEASE_MAINTENANCE_PLAN, ) + def _define_field_parents(self, fields: list[Field], parents: list[Any]): + """Recursively populate the field's and subfield's parents.""" + for field in fields: + field.parents = [self] + parents + if field.sub_fields: + field_parents = parents + [field] + self._define_field_parents(fields = field.sub_fields, parents=field_parents) + + def __post_init__(self): """Checks arguments of the node and setup ID.""" Node.__post_init__(self) @@ -307,10 +317,7 @@ def __post_init__(self): node.parents = [self] for record_set in self.record_sets: record_set.parents = [self] - for field in record_set.fields: - field.parents = [self, record_set] - for sub_field in field.sub_fields: - sub_field.parents = [self, record_set, field] + self._define_field_parents(record_set.fields, parents=[record_set]) # Back-fill the graph in every node. self.ctx.graph = from_nodes_to_graph(self)