From 7574dedc6e77fe13097f9a2ae588bc3d55fcac7a Mon Sep 17 00:00:00 2001 From: ccl-core <91942859+ccl-core@users.noreply.github.com> Date: Tue, 24 Sep 2024 13:25:16 +0200 Subject: [PATCH] Use ids to reference a field or a node. (#744) Names are optional (so they can also not be specified). We should use ids to reference nodes. --- datasets/1.0/audio_test/output/records.jsonl | 4 ++-- .../mlcroissant/_src/core/graphs/utils.py | 6 +++--- .../mlcroissant/_src/datasets_test.py | 3 ++- .../operation_graph/operations/download.py | 4 ++-- .../operations/download_test.py | 2 +- .../_src/operation_graph/operations/field.py | 5 ++++- .../operation_graph/operations/field_test.py | 20 +++++++++++++++---- .../mlcroissant/mlcroissant/scripts/load.py | 2 +- .../mlcroissant/recipes/bounding-boxes.ipynb | 2 +- 9 files changed, 32 insertions(+), 16 deletions(-) diff --git a/datasets/1.0/audio_test/output/records.jsonl b/datasets/1.0/audio_test/output/records.jsonl index 75ed4697d..541e63afc 100644 --- a/datasets/1.0/audio_test/output/records.jsonl +++ b/datasets/1.0/audio_test/output/records.jsonl @@ -1,2 +1,2 @@ -{"audio": "(array([-2.8619270e-13, -1.7014803e-13, 2.7065091e-14, ...,\n -6.4091455e-06, -3.7976279e-06, 2.7510678e-06], dtype=float32), 22050)"} -{"audio": "(array([5.8726583e-14, 1.3397688e-13, 2.2199205e-13, ..., 4.2678180e-04,\n 1.9029720e-04, 2.7079385e-04], dtype=float32), 22050)"} +{"records/audio": "(array([-2.8619270e-13, -1.7014803e-13, 2.7065091e-14, ...,\n -6.4091455e-06, -3.7976279e-06, 2.7510678e-06], dtype=float32), 22050)"} +{"records/audio": "(array([5.8726583e-14, 1.3397688e-13, 2.2199205e-13, ..., 4.2678180e-04,\n 1.9029720e-04, 2.7079385e-04], dtype=float32), 22050)"} diff --git a/python/mlcroissant/mlcroissant/_src/core/graphs/utils.py b/python/mlcroissant/mlcroissant/_src/core/graphs/utils.py index d0382ea39..df387ffee 100644 --- a/python/mlcroissant/mlcroissant/_src/core/graphs/utils.py +++ b/python/mlcroissant/mlcroissant/_src/core/graphs/utils.py @@ -32,7 +32,7 @@ def print_graph_traversal(graph: nx.Graph): print("--- Graph traversal ---") for start, end, _ in nx.edge_bfs(graph): for node in [start, end]: - if node.name not in visited: - print(f"Visited: {node.name}") - visited[node.name] = True + if node.id not in visited: + print(f"Visited: {node.id}") + visited[node.id] = True print("Done traversing the graph.") diff --git a/python/mlcroissant/mlcroissant/_src/datasets_test.py b/python/mlcroissant/mlcroissant/_src/datasets_test.py index 0fae9a688..97cf22cbf 100644 --- a/python/mlcroissant/mlcroissant/_src/datasets_test.py +++ b/python/mlcroissant/mlcroissant/_src/datasets_test.py @@ -91,7 +91,8 @@ def load_records_and_test_equality( ): filters_command = "" if filters: - filters_command = f" --filters '{filters}'" + filters_command = str(filters).replace("'", '"') + filters_command = f" --filters '{filters_command}'" print( "If this test fails, update JSONL with: `mlcroissant load --jsonld" f" ../../datasets/{version}/{dataset_name} --record_set" diff --git a/python/mlcroissant/mlcroissant/_src/operation_graph/operations/download.py b/python/mlcroissant/mlcroissant/_src/operation_graph/operations/download.py index 24f9f6cad..8aafe3525 100644 --- a/python/mlcroissant/mlcroissant/_src/operation_graph/operations/download.py +++ b/python/mlcroissant/mlcroissant/_src/operation_graph/operations/download.py @@ -40,8 +40,8 @@ def get_hash(url: str) -> str: def get_download_filepath(node: FileObject) -> epath.Path: """Retrieves the download filepath of an URL.""" ctx = node.ctx - if node.name in ctx.mapping: - return ctx.mapping[node.name] + if node.id in ctx.mapping: + return ctx.mapping[node.id] url = node.content_url if url and not is_url(url) and not node.contained_in: if ctx.folder is None: diff --git a/python/mlcroissant/mlcroissant/_src/operation_graph/operations/download_test.py b/python/mlcroissant/mlcroissant/_src/operation_graph/operations/download_test.py index f4e02b366..7b35c7d5f 100644 --- a/python/mlcroissant/mlcroissant/_src/operation_graph/operations/download_test.py +++ b/python/mlcroissant/mlcroissant/_src/operation_graph/operations/download_test.py @@ -106,7 +106,7 @@ def test_get_hash_obj_sha256(): def test_get_download_filepath(): ctx = Context() # With mapping - ctx.mapping = {"foo": epath.Path("/bar/foo")} + ctx.mapping = {"file-object": epath.Path("/bar/foo")} node = FileObject( ctx=ctx, name="foo", id="file-object", content_url="http://foo", sha256="12345" ) diff --git a/python/mlcroissant/mlcroissant/_src/operation_graph/operations/field.py b/python/mlcroissant/mlcroissant/_src/operation_graph/operations/field.py index c69cba34f..274d876f2 100644 --- a/python/mlcroissant/mlcroissant/_src/operation_graph/operations/field.py +++ b/python/mlcroissant/mlcroissant/_src/operation_graph/operations/field.py @@ -185,7 +185,10 @@ def _get_result(row): ] else: value = _cast_value(self.node.ctx, value, field.data_type) - result[field.name] = value + if self.node.ctx.is_v0(): + result[field.name] = value + else: + result[field.id] = value return result chunk_size = 100 diff --git a/python/mlcroissant/mlcroissant/_src/operation_graph/operations/field_test.py b/python/mlcroissant/mlcroissant/_src/operation_graph/operations/field_test.py index 48b97e3ba..3ed6cdd12 100644 --- a/python/mlcroissant/mlcroissant/_src/operation_graph/operations/field_test.py +++ b/python/mlcroissant/mlcroissant/_src/operation_graph/operations/field_test.py @@ -160,10 +160,22 @@ def test_extract_lines(separator): read_field = ReadFields(operations=Operations(), node=record_sets[0]) df = pd.DataFrame({FileProperty.filepath: [path]}) expected = [ - {"line_number": 0, "line": b"bon jour ", "filename": b"file"}, - {"line_number": 1, "line": b"", "filename": b"file"}, - {"line_number": 2, "line": b" h\xc3\xa9llo ", "filename": b"file"}, - {"line_number": 3, "line": b"hallo ", "filename": b"file"}, + { + "main/line_number": 0, + "main/line": b"bon jour ", + "main/filename": b"file", + }, + {"main/line_number": 1, "main/line": b"", "main/filename": b"file"}, + { + "main/line_number": 2, + "main/line": b" h\xc3\xa9llo ", + "main/filename": b"file", + }, + { + "main/line_number": 3, + "main/line": b"hallo ", + "main/filename": b"file", + }, ] result = list(read_field.call(df)) assert result == expected diff --git a/python/mlcroissant/mlcroissant/scripts/load.py b/python/mlcroissant/mlcroissant/scripts/load.py index 358c84a32..8b0cb3076 100644 --- a/python/mlcroissant/mlcroissant/scripts/load.py +++ b/python/mlcroissant/mlcroissant/scripts/load.py @@ -121,7 +121,7 @@ def load( raise ValueError("--filters should be a valid dict[str, str]") from e dataset = mlc.Dataset(jsonld, debug=debug, mapping=file_mapping) if record_set is None: - record_sets = ", ".join([f"`{rs.name}`" for rs in dataset.metadata.record_sets]) + record_sets = ", ".join([f"`{rs.id}`" for rs in dataset.metadata.record_sets]) raise ValueError(f"--record_set flag should have a value in {record_sets}") records = dataset.records(record_set, filters=parsed_filters) generate_all_records = num_records == -1 diff --git a/python/mlcroissant/recipes/bounding-boxes.ipynb b/python/mlcroissant/recipes/bounding-boxes.ipynb index f43f52bf2..bf77968f4 100644 --- a/python/mlcroissant/recipes/bounding-boxes.ipynb +++ b/python/mlcroissant/recipes/bounding-boxes.ipynb @@ -165,7 +165,7 @@ "metadata": {}, "outputs": [], "source": [ - "image_id, bbox = record[\"image_id\"], record[\"bbox\"]\n", + "image_id, bbox = record[\"images_with_bounding_box/image_id\"], record[\"images_with_bounding_box/bbox\"]\n", "url = f\"http://images.cocodataset.org/val2014/COCO_val2014_{image_id:012d}.jpg\"\n", "\n", "# Download the image\n",