Skip to content

Commit

Permalink
Use ids to reference a field or a node. (#744)
Browse files Browse the repository at this point in the history
Names are optional (so they can also not be specified). We should use
ids to reference nodes.
  • Loading branch information
ccl-core authored Sep 24, 2024
1 parent 6c79dc0 commit 7574ded
Show file tree
Hide file tree
Showing 9 changed files with 32 additions and 16 deletions.
4 changes: 2 additions & 2 deletions datasets/1.0/audio_test/output/records.jsonl
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
{"audio": "(array([-2.8619270e-13, -1.7014803e-13, 2.7065091e-14, ...,\n -6.4091455e-06, -3.7976279e-06, 2.7510678e-06], dtype=float32), 22050)"}
{"audio": "(array([5.8726583e-14, 1.3397688e-13, 2.2199205e-13, ..., 4.2678180e-04,\n 1.9029720e-04, 2.7079385e-04], dtype=float32), 22050)"}
{"records/audio": "(array([-2.8619270e-13, -1.7014803e-13, 2.7065091e-14, ...,\n -6.4091455e-06, -3.7976279e-06, 2.7510678e-06], dtype=float32), 22050)"}
{"records/audio": "(array([5.8726583e-14, 1.3397688e-13, 2.2199205e-13, ..., 4.2678180e-04,\n 1.9029720e-04, 2.7079385e-04], dtype=float32), 22050)"}
6 changes: 3 additions & 3 deletions python/mlcroissant/mlcroissant/_src/core/graphs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def print_graph_traversal(graph: nx.Graph):
print("--- Graph traversal ---")
for start, end, _ in nx.edge_bfs(graph):
for node in [start, end]:
if node.name not in visited:
print(f"Visited: {node.name}")
visited[node.name] = True
if node.id not in visited:
print(f"Visited: {node.id}")
visited[node.id] = True
print("Done traversing the graph.")
3 changes: 2 additions & 1 deletion python/mlcroissant/mlcroissant/_src/datasets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ def load_records_and_test_equality(
):
filters_command = ""
if filters:
filters_command = f" --filters '{filters}'"
filters_command = str(filters).replace("'", '"')
filters_command = f" --filters '{filters_command}'"
print(
"If this test fails, update JSONL with: `mlcroissant load --jsonld"
f" ../../datasets/{version}/{dataset_name} --record_set"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def get_hash(url: str) -> str:
def get_download_filepath(node: FileObject) -> epath.Path:
"""Retrieves the download filepath of an URL."""
ctx = node.ctx
if node.name in ctx.mapping:
return ctx.mapping[node.name]
if node.id in ctx.mapping:
return ctx.mapping[node.id]
url = node.content_url
if url and not is_url(url) and not node.contained_in:
if ctx.folder is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def test_get_hash_obj_sha256():
def test_get_download_filepath():
ctx = Context()
# With mapping
ctx.mapping = {"foo": epath.Path("/bar/foo")}
ctx.mapping = {"file-object": epath.Path("/bar/foo")}
node = FileObject(
ctx=ctx, name="foo", id="file-object", content_url="http://foo", sha256="12345"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,10 @@ def _get_result(row):
]
else:
value = _cast_value(self.node.ctx, value, field.data_type)
result[field.name] = value
if self.node.ctx.is_v0():
result[field.name] = value
else:
result[field.id] = value
return result

chunk_size = 100
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,22 @@ def test_extract_lines(separator):
read_field = ReadFields(operations=Operations(), node=record_sets[0])
df = pd.DataFrame({FileProperty.filepath: [path]})
expected = [
{"line_number": 0, "line": b"bon jour ", "filename": b"file"},
{"line_number": 1, "line": b"", "filename": b"file"},
{"line_number": 2, "line": b" h\xc3\xa9llo ", "filename": b"file"},
{"line_number": 3, "line": b"hallo ", "filename": b"file"},
{
"main/line_number": 0,
"main/line": b"bon jour ",
"main/filename": b"file",
},
{"main/line_number": 1, "main/line": b"", "main/filename": b"file"},
{
"main/line_number": 2,
"main/line": b" h\xc3\xa9llo ",
"main/filename": b"file",
},
{
"main/line_number": 3,
"main/line": b"hallo ",
"main/filename": b"file",
},
]
result = list(read_field.call(df))
assert result == expected
Expand Down
2 changes: 1 addition & 1 deletion python/mlcroissant/mlcroissant/scripts/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def load(
raise ValueError("--filters should be a valid dict[str, str]") from e
dataset = mlc.Dataset(jsonld, debug=debug, mapping=file_mapping)
if record_set is None:
record_sets = ", ".join([f"`{rs.name}`" for rs in dataset.metadata.record_sets])
record_sets = ", ".join([f"`{rs.id}`" for rs in dataset.metadata.record_sets])
raise ValueError(f"--record_set flag should have a value in {record_sets}")
records = dataset.records(record_set, filters=parsed_filters)
generate_all_records = num_records == -1
Expand Down
2 changes: 1 addition & 1 deletion python/mlcroissant/recipes/bounding-boxes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@
"metadata": {},
"outputs": [],
"source": [
"image_id, bbox = record[\"image_id\"], record[\"bbox\"]\n",
"image_id, bbox = record[\"images_with_bounding_box/image_id\"], record[\"images_with_bounding_box/bbox\"]\n",
"url = f\"http://images.cocodataset.org/val2014/COCO_val2014_{image_id:012d}.jpg\"\n",
"\n",
"# Download the image\n",
Expand Down

0 comments on commit 7574ded

Please sign in to comment.