Skip to content

Commit

Permalink
Release constraints on names for croissant >= 1.0 (#579)
Browse files Browse the repository at this point in the history
  • Loading branch information
ccl-core authored Mar 4, 2024
1 parent b4877e2 commit 39eb1dc
Show file tree
Hide file tree
Showing 24 changed files with 189 additions and 298 deletions.
23 changes: 18 additions & 5 deletions python/mlcroissant/mlcroissant/_src/datasets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,17 @@ def get_error_msg(folder: epath.Path):
# Distribution.
"distribution_bad_contained_in",
"distribution_bad_type",
# When the name is missing, the context should still appear without the name.
"distribution_missing_name",
"distribution_missing_encoding_format",
"distribution_missing_property_content_url",
# Metadata.
"metadata_bad_type",
"metadata_missing_property_name",
# ML field.
"mlfield_bad_source",
"mlfield_bad_type",
"mlfield_missing_property_name",
"mlfield_missing_source",
# Record set.
"recordset_bad_type",
"recordset_missing_context_for_datatype",
"recordset_missing_property_name",
"recordset_wrong_join",
],
)
Expand All @@ -51,6 +47,23 @@ def test_static_analysis(version, folder):
assert str(error_info.value) == get_error_msg(base_path / folder)


# The folders below exercise properties that were mandatory in Croissant 0.8
# but are no longer mandatory in 1.0, so they are only run against 0.8 graphs.
@pytest.mark.parametrize(
    "folder",
    [
        "distribution_missing_name",
        "metadata_missing_property_name",
        "mlfield_missing_property_name",
        "recordset_missing_property_name",
    ],
)
def test_static_analysis_0_8(folder):
    """Loading the 0.8 graph in `folder` raises the expected ValidationError."""
    graphs_dir = epath.Path(__file__).parent / "tests/graphs" / "0.8"
    metadata_file = graphs_dir / f"{folder}/metadata.json"
    with pytest.raises(ValidationError) as raised:
        datasets.Dataset(metadata_file)
    # The expected message is looked up from the same folder as the graph.
    expected_message = get_error_msg(graphs_dir / folder)
    assert str(raised.value) == expected_message


def load_records_and_test_equality(
version: str, dataset_name: str, record_set_name: str, num_records: int
):
Expand Down
19 changes: 12 additions & 7 deletions python/mlcroissant/mlcroissant/_src/structure_graph/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from mlcroissant._src.core.types import Json
from mlcroissant._src.core.uuid import generate_uuid

ID_REGEX = "[a-zA-Z0-9\\-_\\.]+"
_MAX_ID_LENGTH = 255
NAME_REGEX = "[a-zA-Z0-9\\-_\\.]+"
_MAX_NAME_LENGTH = 255


@dataclasses.dataclass(eq=False, repr=False)
Expand Down Expand Up @@ -44,7 +44,8 @@ class Node(abc.ABC):

def __post_init__(self):
"""Checks for common properties between all nodes."""
self.assert_has_mandatory_properties("name", "id")
uuid_field = "name" if self.ctx.is_v0() else "id"
self.assert_has_mandatory_properties(uuid_field)

def assert_has_mandatory_properties(self, *mandatory_properties: str):
"""Checks a node in the graph for existing properties with constraints.
Expand Down Expand Up @@ -222,13 +223,17 @@ def validate_name(self):
return
if not name:
# This case is already checked for in every node's __post_init__ as `name`
# is a mandatory parameter.
# is a mandatory parameter for Croissant 0.8
return
if len(name) > _MAX_ID_LENGTH:
# For Croissant >= 1.0 compliant datasets, we don't enforce any more constraints
# on names.
if not self.ctx.is_v0():
return
if len(name) > _MAX_NAME_LENGTH:
self.add_error(
f'The name "{name}" is too long (>{_MAX_ID_LENGTH} characters).'
f'The name "{name}" is too long (>{_MAX_NAME_LENGTH} characters).'
)
regex = re.compile(rf"^{ID_REGEX}$")
regex = re.compile(rf"^{NAME_REGEX}$")
if not regex.match(name):
self.add_error(f'The name "{name}" contains forbidden characters.')

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import pytest

from mlcroissant._src.core.context import Context
from mlcroissant._src.core.context import CroissantVersion
from mlcroissant._src.structure_graph import base_node
from mlcroissant._src.tests.nodes import create_test_node

Expand Down Expand Up @@ -33,40 +35,56 @@ def test_there_exists_at_least_one_property():


@pytest.mark.parametrize(
["name", "expected_errors"],
["name", "expected_errors", "conforms_to"],
[
[
"a-regular-id",
[],
],
[
"a" * 256,
[
"The name"
' "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"'
" is too long (>255 characters)."
],
CroissantVersion.V_0_8,
],
[
"this is not valid",
['The name "this is not valid" contains forbidden characters.'],
CroissantVersion.V_0_8,
],
[
{"not": {"a": {"string"}}},
["The name should be a string. Got: <class 'dict'>."],
CroissantVersion.V_1_0,
],
],
)
def test_validate_name(name, expected_errors):
node = create_test_node(
Node,
name=name,
)
def test_validate_name(name, expected_errors, conforms_to):
node = create_test_node(Node, name=name, ctx=Context(conforms_to=conforms_to))
node.validate_name()
assert node.ctx.issues.errors
for expected_error, error in zip(expected_errors, node.ctx.issues.errors):
assert expected_error in error


@pytest.mark.parametrize(
    "conforms_to", [CroissantVersion.V_0_8, CroissantVersion.V_1_0]
)
def test_validate_correct_name(conforms_to):
    """A well-formed name produces no errors under either Croissant version."""
    versioned_ctx = Context(conforms_to=conforms_to)
    node = create_test_node(Node, name="a-regular-id", ctx=versioned_ctx)
    node.validate_name()
    assert not node.ctx.issues.errors


def test_validate_name_1_0():
    """Under Croissant 1.0, a name containing spaces raises no error."""
    # The same name is reported as containing forbidden characters under 0.8.
    ctx_1_0 = Context(conforms_to=CroissantVersion.V_1_0)
    node = create_test_node(Node, name="this is not valid", ctx=ctx_1_0)
    node.validate_name()
    assert not node.ctx.issues.errors


def test_eq():
node1 = create_test_node(Node, id="node1", name="node1")
node2 = create_test_node(Node, id="node2", name="node2")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,9 @@ class Field(Node):

def __post_init__(self):
"""Checks arguments of the node."""
uuid_field = "name" if self.ctx.is_v0() else "id"
self.validate_name()
self.assert_has_mandatory_properties("name", "id")
self.assert_has_mandatory_properties(uuid_field)
self.assert_has_optional_properties("description")
self.source.check_source(self.add_error)
self._standardize_data_types()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,33 @@

from unittest import mock

import pytest
from rdflib import term

from mlcroissant._src.core.constants import DataType
from mlcroissant._src.core.context import Context
from mlcroissant._src.core.context import CroissantVersion
from mlcroissant._src.structure_graph.base_node import Node
from mlcroissant._src.structure_graph.nodes.field import Field
from mlcroissant._src.tests.nodes import create_test_field
from mlcroissant._src.tests.nodes import create_test_node


def test_checks_are_performed():
@pytest.mark.parametrize(
["conforms_to", "field_uuid"],
[[CroissantVersion.V_0_8, "name"], [CroissantVersion.V_1_0, "id"]],
)
def test_checks_are_performed(conforms_to, field_uuid):
with mock.patch.object(
Node, "assert_has_mandatory_properties"
) as mandatory_mock, mock.patch.object(
Node, "assert_has_optional_properties"
) as optional_mock, mock.patch.object(
Node, "validate_name"
) as validate_name_mock:
create_test_node(Field)
mandatory_mock.assert_called_once_with("name", "id")
ctx = Context(conforms_to=conforms_to)
create_test_node(Field, ctx=ctx)
mandatory_mock.assert_called_once_with(field_uuid)
optional_mock.assert_called_once_with("description")
validate_name_mock.assert_called_once()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ class FileObject(Node):
def __post_init__(self):
"""Checks arguments of the node."""
self.validate_name()
self.assert_has_mandatory_properties("encoding_format", "name", "id")
uuid_field = "name" if self.ctx.is_v0() else "id"
self.assert_has_mandatory_properties("encoding_format", uuid_field)

if not self.contained_in:
self.assert_has_mandatory_properties("content_url")
if self.ctx and not self.ctx.is_live_dataset:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,17 @@

from mlcroissant._src.core import constants
from mlcroissant._src.core.context import Context
from mlcroissant._src.core.context import CroissantVersion
from mlcroissant._src.structure_graph.base_node import Node
from mlcroissant._src.structure_graph.nodes.file_object import FileObject
from mlcroissant._src.tests.nodes import create_test_node


def test_checks_are_performed():
@pytest.mark.parametrize(
["conforms_to", "field_uuid"],
[[CroissantVersion.V_0_8, "name"], [CroissantVersion.V_1_0, "id"]],
)
def test_checks_are_performed(conforms_to, field_uuid):
with mock.patch.object(
Node, "assert_has_mandatory_properties"
) as mandatory_mock, mock.patch.object(
Expand All @@ -22,15 +27,20 @@ def test_checks_are_performed():
) as validate_name_mock, mock.patch.object(
Node, "assert_has_exclusive_properties"
) as exclusive_mock:
create_test_node(FileObject)
ctx = Context(conforms_to=conforms_to)
create_test_node(FileObject, ctx=ctx)
mandatory_mock.assert_has_calls([
mock.call("encoding_format", "name", "id"), mock.call("content_url")
mock.call("encoding_format", field_uuid), mock.call("content_url")
])
exclusive_mock.assert_called_once_with(["md5", "sha256"])
validate_name_mock.assert_called_once()


def test_checks_not_performed_for_live_dataset():
@pytest.mark.parametrize(
["conforms_to", "field_uuid"],
[[CroissantVersion.V_0_8, "name"], [CroissantVersion.V_1_0, "id"]],
)
def test_checks_not_performed_for_live_dataset(conforms_to, field_uuid):
with mock.patch.object(
Node, "assert_has_mandatory_properties"
) as mandatory_mock, mock.patch.object(
Expand All @@ -40,10 +50,10 @@ def test_checks_not_performed_for_live_dataset():
) as validate_name_mock, mock.patch.object(
Node, "assert_has_exclusive_properties"
) as exclusive_mock:
ctx = Context(is_live_dataset=True)
ctx = Context(is_live_dataset=True, conforms_to=conforms_to)
create_test_node(FileObject, ctx=ctx)
mandatory_mock.assert_has_calls([
mock.call("encoding_format", "name", "id"), mock.call("content_url")
mock.call("encoding_format", field_uuid), mock.call("content_url")
])
exclusive_mock.assert_not_called()
validate_name_mock.assert_called_once()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,9 @@ class FileSet(Node):

def __post_init__(self):
"""Checks arguments of the node."""
uuid_field = "name" if self.ctx.is_v0() else "id"
self.validate_name()
self.assert_has_mandatory_properties(
"includes", "encoding_format", "name", "id"
)
self.assert_has_mandatory_properties("includes", "encoding_format", uuid_field)

def to_json(self) -> Json:
"""Converts the `FileSet` to JSON."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,34 @@

from unittest import mock

import pytest

from mlcroissant._src.core import constants
from mlcroissant._src.core.context import Context
from mlcroissant._src.core.context import CroissantVersion
from mlcroissant._src.core.uuid import formatted_uuid_to_json
from mlcroissant._src.structure_graph.base_node import Node
from mlcroissant._src.structure_graph.nodes.file_set import FileSet
from mlcroissant._src.tests.nodes import create_test_node
from mlcroissant._src.tests.versions import parametrize_conforms_to


def test_checks_are_performed():
@pytest.mark.parametrize(
["conforms_to", "field_uuid"],
[[CroissantVersion.V_0_8, "name"], [CroissantVersion.V_1_0, "id"]],
)
def test_checks_are_performed(conforms_to, field_uuid):
with mock.patch.object(
Node, "assert_has_mandatory_properties"
) as mandatory_mock, mock.patch.object(
Node, "assert_has_optional_properties"
) as optional_mock, mock.patch.object(
Node, "validate_name"
) as validate_name_mock:
create_test_node(FileSet)
ctx = Context(conforms_to=conforms_to)
create_test_node(FileSet, ctx=ctx)
mandatory_mock.assert_called_once_with(
"includes", "encoding_format", "name", "id"
"includes", "encoding_format", field_uuid
)
optional_mock.assert_not_called()
validate_name_mock.assert_called_once()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ def __post_init__(self):
self.date_created = self.validate_date(self.date_created)
self.date_modified = self.validate_date(self.date_modified)
self.date_published = self.validate_date(self.date_published)
self.assert_has_mandatory_properties("name")
if self.ctx.is_v0():
self.assert_has_mandatory_properties("name")
self.assert_has_optional_properties(
"cite_as", "date_published", "license", "version"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def test_checks_are_performed():
) as optional_mock, mock.patch.object(
Node, "validate_name"
) as validate_name_mock:
create_test_node(Metadata, name="field_name")
ctx = Context(conforms_to=CroissantVersion.V_0_8)
create_test_node(Metadata, name="field_name", id="field_id", ctx=ctx)
mandatory_mock.assert_called_once_with("name")
optional_mock.assert_called_once_with(
"cite_as", "date_published", "license", "version"
Expand Down Expand Up @@ -120,13 +121,16 @@ def test_valid_version(version, expected_version):

def test_issues_in_metadata_are_shared_with_children():
with pytest.raises(ValidationError, match="is mandatory, but does not exist"):
ctx = Context(conforms_to=CroissantVersion.V_0_8)
Metadata(
name="name",
description="description",
url="https://mlcommons.org",
version="1.0.0",
# We did not specify the RecordSet's name. Hence the exception above:
record_sets=[RecordSet(id="record-set", description="description")],
record_sets=[
RecordSet(id="record-set", description="description", ctx=ctx)
],
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ class RecordSet(Node):

def __post_init__(self):
"""Checks arguments of the node."""
uuid_field = "name" if self.ctx.is_v0() else "id"
self.validate_name()
self.assert_has_mandatory_properties("name", "id")
self.assert_has_mandatory_properties(uuid_field)
self.assert_has_optional_properties("description")

if self.data is not None:
Expand Down
Loading

0 comments on commit 39eb1dc

Please sign in to comment.