Skip to content

Commit

Permalink
feat(datatypes): add support for fixed length arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Jan 28, 2025
1 parent 02628e8 commit eb88844
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 17 deletions.
2 changes: 1 addition & 1 deletion ibis/backends/duckdb/tests/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
("UUID", dt.uuid),
("VARCHAR", dt.string),
("INTEGER[]", dt.Array(dt.int32)),
("INTEGER[3]", dt.Array(dt.int32)),
("INTEGER[3]", dt.Array(dt.int32, length=3)),
("MAP(VARCHAR, BIGINT)", dt.Map(dt.string, dt.int64)),
(
"STRUCT(a INTEGER, b VARCHAR, c MAP(VARCHAR, DOUBLE[])[])",
Expand Down
29 changes: 29 additions & 0 deletions ibis/backends/duckdb/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import ibis
from ibis import udf
from ibis.util import gen_name

Check warning on line 9 in ibis/backends/duckdb/tests/test_udf.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/duckdb/tests/test_udf.py#L9

Added line #L9 was not covered by tests


@udf.scalar.builtin
Expand Down Expand Up @@ -149,3 +150,31 @@ def regexp_extract(s, pattern, groups): ...
e = regexp_extract("2023-04-15", r"(\d+)-(\d+)-(\d+)", ["y", "m", "d"])
sql = str(ibis.to_sql(e, dialect="duckdb"))
assert r"REGEXP_EXTRACT('2023-04-15', '(\d+)-(\d+)-(\d+)', ['y', 'm', 'd'])" in sql


@pytest.fixture(scope="module")
def array_cosine_t(con):
    """Temporary table with one fixed-length and one variable-length array column."""
    data = {"fixed": [[1, 2, 3]], "varlen": [[1, 2, 3]]}
    # "fixed" gets a known length of 3; "varlen" stays unbounded.
    column_types = {"fixed": "array<double, 3>", "varlen": "array<double>"}
    return con.create_table(
        gen_name("array_cosine_t"),
        obj=data,
        schema=column_types,
        temp=True,
    )


@pytest.mark.parametrize(
    ("column", "expr_fn"),
    [
        ("fixed", lambda c: c),
        ("varlen", lambda c: c.cast("array<float, 3>")),
    ],
    ids=["no-cast", "cast"],
)
def test_builtin_fixed_length_array_udf(array_cosine_t, column, expr_fn):
    """Cosine similarity of a fixed-length array with itself must be exactly 1."""

    @udf.scalar.builtin
    def array_cosine_similarity(a, b) -> float: ...

    arr = expr_fn(array_cosine_t[column])
    similarity = array_cosine_similarity(arr, arr)
    result = similarity.execute()
    assert result.iat[0] == 1.0

Check warning on line 180 in ibis/backends/duckdb/tests/test_udf.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/duckdb/tests/test_udf.py#L177-L180

Added lines #L177 - L180 were not covered by tests
46 changes: 39 additions & 7 deletions ibis/backends/sql/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,14 @@ def to_ibis(cls, typ: sge.DataType, nullable: bool | None = None) -> dt.DataType
"nullable", nullable if nullable is not None else cls.default_nullable
)
if method := getattr(cls, f"_from_sqlglot_{typecode.name}", None):
dtype = method(*typ.expressions, nullable=nullable)
if typecode == sge.DataType.Type.ARRAY:
dtype = method(
*typ.expressions,
*(typ.args.get("values", ()) or ()),
nullable=nullable,
)
else:
dtype = method(*typ.expressions, nullable=nullable)
elif (known_typ := _from_sqlglot_types.get(typecode)) is not None:
dtype = known_typ(nullable=nullable)
else:
Expand Down Expand Up @@ -222,9 +229,16 @@ def to_string(cls, dtype: dt.DataType) -> str:

@classmethod
def _from_sqlglot_ARRAY(
    cls,
    value_type: sge.DataType,
    length: sge.Literal | None = None,
    nullable: bool | None = None,
) -> dt.Array:
    """Convert a sqlglot ARRAY type to an ibis array dtype.

    ``length`` is the optional fixed-size literal from types such as
    ``INTEGER[3]``; when absent the array is variable-length.
    """
    size = int(length.this) if length is not None else None
    return dt.Array(cls.to_ibis(value_type), length=size, nullable=nullable)

@classmethod
def _from_sqlglot_MAP(
Expand Down Expand Up @@ -380,7 +394,12 @@ def _from_ibis_Interval(cls, dtype: dt.Interval) -> sge.DataType:
@classmethod
def _from_ibis_Array(cls, dtype: dt.Array) -> sge.DataType:
    """Convert an ibis array dtype to a sqlglot ARRAY type."""
    element = cls.from_ibis(dtype.value_type)
    # A known fixed length rides along in sqlglot's ``values`` arg
    # (rendered as e.g. DOUBLE[3]); None means variable length.
    fixed = [sge.convert(dtype.length)] if dtype.length is not None else None
    return sge.DataType(
        this=typecode.ARRAY,
        expressions=[element],
        values=fixed,
        nested=True,
    )

@classmethod
def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType:
Expand Down Expand Up @@ -775,7 +794,10 @@ def _from_sqlglot_DECIMAL(

@classmethod
def _from_sqlglot_ARRAY(
    cls,
    value_type: sge.DataType | None = None,
    length: sge.Literal | None = None,
    nullable: bool | None = None,
) -> dt.Array:
    """Map a bare ARRAY type to an array of JSON.

    This override never receives an element type (nor a length); the
    invariant is asserted below.
    """
    assert value_type is None
    return dt.Array(dt.json, nullable=nullable)
Expand Down Expand Up @@ -1050,7 +1072,12 @@ def _from_ibis_Timestamp(cls, dtype: dt.Timestamp) -> sge.DataType:
return sge.DataType(this=code)

@classmethod
def _from_sqlglot_ARRAY(
    cls,
    value_type: sge.DataType,
    length: sge.Literal | None = None,
    nullable: bool | None = None,
) -> NoReturn:
    """Always raise: this backend has no array type.

    The parameters mirror the base-class hook so dispatch stays uniform;
    they are unused.
    """
    raise com.UnsupportedBackendType("Arrays not supported in Exasol")

@classmethod
Expand Down Expand Up @@ -1105,7 +1132,12 @@ def _from_ibis_Struct(cls, dtype: dt.String) -> sge.DataType:
raise com.UnsupportedBackendType("SQL Server does not support structs")

@classmethod
def _from_sqlglot_ARRAY(
    cls,
    value_type: sge.DataType,
    length: sge.Literal | None = None,
    nullable: bool | None = None,
) -> NoReturn:
    """Always raise: this backend has no array type.

    The parameters mirror the base-class hook so dispatch stays uniform;
    they are unused.
    """
    raise com.UnsupportedBackendType("SQL Server does not support arrays")

@classmethod
Expand Down
21 changes: 17 additions & 4 deletions ibis/expr/datatypes/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from collections.abc import Iterable, Iterator, Mapping, Sequence
from numbers import Integral, Real
from typing import (
Annotated,
Any,
Generic,
Literal,
Expand All @@ -27,7 +28,7 @@
from ibis.common.collections import FrozenOrderedDict, MapSet
from ibis.common.dispatch import lazy_singledispatch
from ibis.common.grounds import Concrete, Singleton
from ibis.common.patterns import Coercible, CoercionError
from ibis.common.patterns import Between, Coercible, CoercionError
from ibis.common.temporal import IntervalUnit, TimestampUnit


Expand All @@ -50,14 +51,14 @@ def dtype(value: Any, nullable: bool = True) -> DataType:
>>> ibis.dtype("int32")
Int32(nullable=True)
>>> ibis.dtype("array<float>")
Array(value_type=Float64(nullable=True), nullable=True)
Array(value_type=Float64(nullable=True), length=None, nullable=True)
DataType objects may also be created from Python types:
>>> ibis.dtype(int)
Int64(nullable=True)
>>> ibis.dtype(list[float])
Array(value_type=Float64(nullable=True), nullable=True)
Array(value_type=Float64(nullable=True), length=None, nullable=True)
Or other type systems, like numpy/pandas/pyarrow types:
Expand Down Expand Up @@ -309,6 +310,10 @@ def is_enum(self) -> bool:
"""Return True if an instance of an Enum type."""
return isinstance(self, Enum)

def is_fixed_length_array(self) -> bool:
    """Return whether this is an ``Array`` dtype with a known, fixed length."""
    if not isinstance(self, Array):
        return False
    return self.length is not None

def is_float16(self) -> bool:
"""Return True if an instance of a Float16 type."""
return isinstance(self, Float16)
Expand Down Expand Up @@ -904,13 +909,19 @@ class Array(Variadic, Parametric, Generic[T]):
"""Array values."""

value_type: T
"""Element type of the array."""
length: Annotated[int, Between(lower=0)] | None = None
"""The length of the array if known."""

scalar = "ArrayScalar"
column = "ArrayColumn"

@property
def _pretty_piece(self) -> str:
    # Rendered as <T, N> for fixed-length arrays and <T> otherwise.
    if self.length is not None:
        return f"<{self.value_type}, {self.length:d}>"
    return f"<{self.value_type}>"


K = TypeVar("K", bound=DataType, covariant=True)
Expand All @@ -922,7 +933,9 @@ class Map(Variadic, Parametric, Generic[K, V]):
"""Associative array values."""

key_type: K
"""Map key type."""
value_type: V
"""Map value type."""

scalar = "MapScalar"
column = "MapColumn"
Expand Down
11 changes: 8 additions & 3 deletions ibis/expr/datatypes/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def parse(
>>> import ibis
>>> import ibis.expr.datatypes as dt
>>> dt.parse("array<int64>")
Array(value_type=Int64(nullable=True), nullable=True)
Array(value_type=Int64(nullable=True), length=None, nullable=True)
You can avoid parsing altogether by constructing objects directly
Expand Down Expand Up @@ -182,8 +182,13 @@ def geotype_parser(typ: type[dt.DataType]) -> dt.DataType:
)

ty = parsy.forward_declaration()
angle_type = LANGLE.then(ty).skip(RANGLE)
array = spaceless_string("array").then(angle_type).map(dt.Array)

# array<T> or array<T, N>; the optional integer suffix is a fixed length.
_array_body = parsy.seq(ty, COMMA.then(LENGTH).optional()).combine(dt.Array)
array = spaceless_string("array").then(LANGLE).then(_array_body).skip(RANGLE)

map = (
spaceless_string("map")
Expand Down
8 changes: 8 additions & 0 deletions ibis/expr/datatypes/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,14 @@ def test_is_methods(dtype_class):
assert is_dtype is True


def test_is_fixed_length_array():
    variable = dt.Array(dt.int8)
    fixed = dt.Array(dt.int8, 10)

    assert not variable.is_fixed_length_array()
    assert fixed.is_fixed_length_array()


def test_is_array():
assert dt.Array(dt.string).is_array()
assert not dt.string.is_array()
Expand Down
5 changes: 3 additions & 2 deletions ibis/tests/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ def primitive_dtypes(nullable=_nullable):


def array_dtypes(value_type=_item_strategy, nullable=_nullable, length=None):
    """Strategy generating ibis ``Array`` dtypes.

    Parameters
    ----------
    value_type
        Strategy for the element dtype.
    nullable
        Strategy for the dtype's nullability flag.
    length
        Strategy for the fixed length. Defaults to drawing either ``None``
        (variable-length array) or a non-negative integer, mirroring the
        other parameters' pattern of being caller-overridable.
    """
    if length is None:
        length = st.one_of(st.none(), st.integers(min_value=0))
    return st.builds(dt.Array, value_type=value_type, nullable=nullable, length=length)


def map_dtypes(key_type=_item_strategy, value_type=_item_strategy, nullable=_nullable):
Expand Down Expand Up @@ -180,7 +181,7 @@ def struct_dtypes(

def geospatial_dtypes(nullable=_nullable):
geotype = st.one_of(st.just("geography"), st.just("geometry"))
srid = st.one_of(st.just(None), st.integers(min_value=0))
srid = st.one_of(st.none(), st.integers(min_value=0))
return st.one_of(
st.builds(dt.Point, geotype=geotype, nullable=nullable, srid=srid),
st.builds(dt.LineString, geotype=geotype, nullable=nullable, srid=srid),
Expand Down

0 comments on commit eb88844

Please sign in to comment.