diff --git a/merlin/dtypes/shape.py b/merlin/dtypes/shape.py index 13f4ec1e8..5503a3486 100644 --- a/merlin/dtypes/shape.py +++ b/merlin/dtypes/shape.py @@ -20,12 +20,41 @@ class DefaultShapes(Enum): - LIST = (None, None) - SCALAR = (None,) + LIST = (-1, None) + SCALAR = (-1,) + + +def Dimension(size=None, index=None): + """Create a dimension from a size. + + A size can be one of: + - None : a ragged dimension of unknown size + - int : a fixed dimension of some size (-1 = unknown) + - 2-tuple : the bounds of a ragged dimension (fixed if min == max) + """ + if isinstance(size, (UniformDimension, RaggedDimension)): + return size + elif isinstance(size, tuple) and len(size) == 2: + if size[0] == size[1] or index == 0: + return UniformDimension(size[0], size[1]) + return RaggedDimension(size[0], size[1]) + elif isinstance(size, int): + if size == -1: + return UniformDimension() + return UniformDimension(size, size) + elif size is None: + if index == 0: + return UniformDimension() + return RaggedDimension() + else: + raise ValueError( + f"Invalid dimension format: {size}. Each dimension is expected " + " to be None, a single integer, or a tuple with length 2." + ) @dataclass(frozen=True) -class Dimension: +class BaseDimension: """ The range of potential sizes for a single dimension of a field or column """ @@ -37,6 +66,12 @@ def __post_init__(self): if self.min is None: raise ValueError("The minimum size of a dimension cannot be None. ") + if not isinstance(self.min, int): + raise ValueError("The minimmum size must be an integer. " f"Provided min: {self.min}") + + if self.max and not isinstance(self.max, int): + raise ValueError("The maximum size must be an integer. " f"Provided max: {self.max}") + if self.min < 0: raise ValueError( "The minimum size of a dimension must be non-negative. " f"Provided min: {self.min}" @@ -72,14 +107,17 @@ def __int__(self): @property def is_bounded(self): + """Is the dimension bounded in size?""" return self.max is not None @property def is_fixed(self): + """Is the dimension fixed in size?""" return self.is_bounded and self.min == self.max @property def is_variable(self): + """Can the size of the dimension vary between instances of tensors.""" return not self.is_fixed @property @@ -93,6 +131,37 @@ def with_max(self, value): return replace(self, max=value) +class RaggedDimension(BaseDimension): + @property + def is_uniform(self): + return False + + @property + def is_ragged(self): + return True + + @property + def size(self): + return None + + +class UniformDimension(BaseDimension): + @property + def is_uniform(self): + return True + + @property + def is_ragged(self): + return False + + @property + def size(self): + if self.is_fixed: + return self.max + else: + return -1 + + @dataclass(frozen=True) class Shape: """ @@ -111,19 +180,7 @@ def __post_init__(self): if self.dims is not None: new_dims = [] for i, dim in enumerate(self.dims): - if isinstance(dim, Dimension): - new_dim = dim - elif isinstance(dim, tuple) and len(dim) == 2: - new_dim = Dimension(dim[0], dim[1]) - elif isinstance(dim, int): - new_dim = Dimension(dim, dim) - elif dim is None: - new_dim = Dimension() - else: - raise ValueError( - f"Invalid shape tuple format: {self.dims}. Each dimension is expected " - " to be None, a single integer, or a tuple with length 2." - ) + new_dim = Dimension(dim, index=i) new_dims.append(new_dim) object.__setattr__(self, "dims", tuple(new_dims)) @@ -155,10 +212,16 @@ def with_dim(self, index, value): return replace(self, dims=tuple(new_dims)) def with_dim_min(self, index, value): - return self.with_dim(index, self.dims[index].with_min(value)) + new_dim = self.dims[index].with_min(value) + if new_dim.is_uniform: + new_dim = Dimension(value) + return self.with_dim(index, new_dim) def with_dim_max(self, index, value): - return self.with_dim(index, self.dims[index].with_max(value)) + new_dim = self.dims[index].with_max(value) + if new_dim.is_uniform: + new_dim = Dimension(value) + return self.with_dim(index, new_dim) @property def min(self) -> Tuple: @@ -176,6 +239,10 @@ def is_bounded(self): def is_fixed(self): return all(dim.is_fixed for dim in self.dims) + @property + def is_uniform(self): + return all(dim.is_uniform for dim in self.dims) + @property def is_variable(self): return not self.is_fixed @@ -186,14 +253,16 @@ def is_list(self): @property def is_ragged(self): - return self.is_list and any(dim.min != dim.max for dim in self.dims[1:]) + return self.is_list and any(dim.is_ragged for dim in self.dims[1:]) @property def as_tuple(self): if not self.dims: return None - return tuple(((dim.min, dim.max) if dim.min != dim.max else dim.max for dim in self.dims)) + return tuple( + dim.size if dim.is_fixed or dim.is_uniform else (dim.min, dim.max) for dim in self.dims + ) @property def is_unknown(self): diff --git a/merlin/schema/io/tensorflow_metadata.py b/merlin/schema/io/tensorflow_metadata.py index 43f6e42c7..3361956f4 100644 --- a/merlin/schema/io/tensorflow_metadata.py +++ b/merlin/schema/io/tensorflow_metadata.py @@ -19,6 +19,7 @@ import fsspec import merlin.dtypes as md +from merlin.dtypes.shape import RaggedDimension, UniformDimension from merlin.schema.io import proto_utils, schema_bp from merlin.schema.io.schema_bp import Feature, FeatureType, FloatDomain, IntDomain from merlin.schema.io.schema_bp import Schema as ProtoSchema @@ -273,9 +274,15 @@ def _pb_extra_metadata(column_schema): properties = { k: v for k, v in column_schema.properties.items() if k not in ("domain", "value_count") } - properties["_dims"] = list( - list(dim) if isinstance(dim, tuple) else dim for dim in column_schema.shape.as_tuple or [] + _dims = ( + list( + {"min": dim.min, "max": dim.max, "is_uniform": dim.is_uniform} + for dim in column_schema.shape.dims + ) + if column_schema.shape.dims + else [] ) + properties["_dims"] = _dims properties["is_list"] = column_schema.is_list properties["is_ragged"] = column_schema.is_ragged if column_schema.dtype.element_size: @@ -423,10 +430,18 @@ def _merlin_dtype(feature, properties): for dim in dims_list: if isinstance(dim, list): dims.append(tuple(int(d) if isinstance(d, float) else d for d in dim)) - elif dim is not None: + elif isinstance(dim, (int, float)): dims.append(int(dim)) - else: + elif dim is None: dims.append(dim) + elif isinstance(dim, dict): + _min = int(dim["min"]) if isinstance(dim["min"], float) else dim["min"] + _max = int(dim["max"]) if isinstance(dim["max"], float) else dim["max"] + if dim["is_uniform"]: + dims.append(UniformDimension(_min, _max)) + else: + dims.append(RaggedDimension(_min, _max)) + dtype = dtype.with_shape(tuple(dims)) # If we found dims, avoid overwriting that shape with one inferred from counts or flags @@ -452,10 +467,8 @@ def _merlin_column(feature): if Tags.CATEGORICAL not in tags: tags.append(Tags.CATEGORICAL) - dims = dtype.shape.as_tuple - - if dims: - return ColumnSchema(name, tags, properties, dtype, dims=dims) + if dtype.shape.dims: + return ColumnSchema(name, tags, properties, dtype, dims=dtype.shape.dims) else: return ColumnSchema(name, tags, properties, dtype, is_list=is_list, is_ragged=is_ragged) diff --git a/merlin/schema/schema.py b/merlin/schema/schema.py index 98299ebcf..74cb597df 100644 --- a/merlin/schema/schema.py +++ b/merlin/schema/schema.py @@ -96,8 +96,10 @@ def __post_init__(self, dims): new_shape = dtype.shape elif value_counts: new_shape = self._shape_from_counts(Domain(**value_counts)) + elif self.is_list and self.is_ragged is False: + new_shape = Shape((-1, -1)) elif self.is_list: - new_shape = self._shape_from_flags(self.is_list) + new_shape = Shape((-1, None)) else: new_shape = Shape() @@ -115,11 +117,8 @@ def __post_init__(self, dims): object.__setattr__(self, "properties", properties) - def _shape_from_flags(self, is_list): - return Shape(((0, None), (0, None))) if is_list else None - def _shape_from_counts(self, value_count): - return Shape(((0, None), (value_count.min or 0, value_count.max))) + return Shape((-1, (value_count.min or 0, value_count.max))) @property def shape(self): diff --git a/tests/unit/dtypes/test_shape.py b/tests/unit/dtypes/test_shape.py index 5dd9a164a..0bbeed068 100644 --- a/tests/unit/dtypes/test_shape.py +++ b/tests/unit/dtypes/test_shape.py @@ -28,7 +28,7 @@ def test_empty_dimension(): def test_min_max_val_dimension(): - dim = Dimension(2, 3) + dim = Dimension((2, 3)) assert dim.min == 2 assert dim.max == 3 @@ -36,32 +36,29 @@ def test_min_max_val_dimension(): def test_fixed_min_with_unbounded_max(): dim = Dimension(2) assert dim.min == 2 - assert dim.max is None + assert dim.max == 2 - dim = Dimension(2, None) + dim = Dimension((2, None)) assert dim.min == 2 assert dim.max is None def test_min_is_none_raises_error(): with pytest.raises(ValueError): - Dimension(None) - - with pytest.raises(ValueError): - Dimension(None, 1) + Dimension((None, 1)) def test_bounds_must_be_non_negative(): with pytest.raises(ValueError): - Dimension(-1, 2) + Dimension((-1, 2)) with pytest.raises(ValueError): - Dimension(2, -1) + Dimension((2, -1)) def test_max_less_than_min(): with pytest.raises(ValueError): - Dimension(2, 1) + Dimension((2, 1)) def test_is_bounded(): @@ -69,15 +66,15 @@ def test_is_bounded(): assert dim.is_bounded is False dim = Dimension(2) - assert dim.is_bounded is False + assert dim.is_bounded is True - dim = Dimension(2, 2) + dim = Dimension((2, 2)) assert dim.is_bounded is True - dim = Dimension(2, 4) + dim = Dimension((2, 4)) assert dim.is_bounded is True - dim = Dimension(2, None) + dim = Dimension((2, None)) assert dim.is_bounded is False @@ -86,15 +83,15 @@ def test_is_fixed(): assert dim.is_fixed is False dim = Dimension(2) - assert dim.is_fixed is False + assert dim.is_fixed is True - dim = Dimension(2, 2) + dim = Dimension((2, 2)) assert dim.is_fixed is True - dim = Dimension(2, 4) + dim = Dimension((2, 4)) assert dim.is_fixed is False - dim = Dimension(2, None) + dim = Dimension((2, None)) assert dim.is_fixed is False @@ -103,15 +100,15 @@ def test_is_variable(): assert dim.is_variable is True dim = Dimension(2) - assert dim.is_variable is True + assert dim.is_variable is False - dim = Dimension(2, 2) + dim = Dimension((2, 2)) assert dim.is_variable is False - dim = Dimension(2, 4) + dim = Dimension((2, 4)) assert dim.is_variable is True - dim = Dimension(2, None) + dim = Dimension((2, None)) assert dim.is_variable is True diff --git a/tests/unit/schema/test_column_schemas.py b/tests/unit/schema/test_column_schemas.py index daac73fdd..5b8bb4ebb 100644 --- a/tests/unit/schema/test_column_schemas.py +++ b/tests/unit/schema/test_column_schemas.py @@ -192,11 +192,9 @@ def test_list_column_attributes(): assert col3_schema.is_list assert col3_schema.is_ragged - # TODO: Re-enable this test case once we've addressed cases - # like this in downstream libraries - - # with pytest.raises(ValueError): - # ColumnSchema("col4", is_list=True, is_ragged=False) + col4_schema = ColumnSchema("col4", is_list=True, is_ragged=False) + assert col4_schema.is_list + assert col4_schema.is_ragged is False with pytest.raises(ValueError): ColumnSchema("col5", is_list=False, is_ragged=True) @@ -257,18 +255,18 @@ def test_setting_partial_value_count(value_count): ) assert col_schema.is_list assert not col_schema.is_ragged - assert col_schema.shape == Shape((None, 10)) + assert col_schema.shape == Shape((-1, 10)) assert col_schema.properties["value_count"] == {"min": 10, "max": 10} def test_setting_value_counts_updates_shape_and_flags(): - col_schema = ColumnSchema("col", dims=(None,)) + col_schema = ColumnSchema("col", dims=(-1,)) counts = {"min": 4, "max": 5} updated_schema = col_schema.with_properties({"value_count": counts}) assert updated_schema.properties["value_count"] == counts - assert updated_schema.shape == Shape((None, (4, 5))) + assert updated_schema.shape == Shape((-1, (4, 5))) assert updated_schema.is_list assert updated_schema.is_ragged @@ -287,7 +285,7 @@ def test_setting_flags_updates_shape_and_value_counts(): col_schema = ColumnSchema("col") updated_schema = col_schema.with_dtype(md.int64, is_list=True, is_ragged=True) - assert updated_schema.shape == Shape((None, None)) + assert updated_schema.shape == Shape((-1, None)) assert updated_schema.properties["value_count"] == {"min": 0, "max": None} assert updated_schema.is_list assert updated_schema.is_ragged diff --git a/tests/unit/schema/test_schema_io.py b/tests/unit/schema/test_schema_io.py index 3770f4b37..618948bf6 100644 --- a/tests/unit/schema/test_schema_io.py +++ b/tests/unit/schema/test_schema_io.py @@ -205,7 +205,7 @@ def test_schema_with_shape_to_tensorflow_metadata_json(): ragged_dim = loaded_schema["col"].shape[1] assert isinstance(ragged_dim.max, int) assert isinstance(ragged_dim.min, int) - assert ragged_dim == Dimension(min=1, max=5) + assert ragged_dim == Dimension((1, 5)) def test_tensorflow_metadata_from_json():