diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index bb44c8c..245ded4 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -7,12 +7,14 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.7', '3.8', '3.9', '3.10'] + python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 + - name: Install GIS libraries + run: sudo apt-get install -y binutils libproj-dev gdal-bin libsqlite3-mod-spatialite - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies diff --git a/.gitignore b/.gitignore index 02ac43e..0676574 100644 --- a/.gitignore +++ b/.gitignore @@ -106,3 +106,5 @@ ENV/ .serverless/ node_modules/ +tests/integration/django/db/* +tests/data/world_borders* diff --git a/odata_query/ast.py b/odata_query/ast.py index 7015334..0e268d5 100644 --- a/odata_query/ast.py +++ b/odata_query/ast.py @@ -19,6 +19,9 @@ class Identifier(_Node): name: str namespace: Tuple[str, ...] = field(default_factory=tuple) + def full_name(self): + return ".".join(self.namespace + (self.name,)) + @dataclass(frozen=True) class Attribute(_Node): @@ -81,6 +84,14 @@ def py_val(self) -> str: return self.val +@dataclass(frozen=True) +class Geography(_Literal): + val: str + + def wkt(self): + return self.val + + @dataclass(frozen=True) class Date(_Literal): val: str @@ -309,6 +320,12 @@ class UnaryOp(_Node): ############################################################################### # Function calls ############################################################################### +@dataclass(frozen=True) +class NamedParam(_Node): + name: Identifier + param: _Node + + @dataclass(frozen=True) class Call(_Node): func: Identifier diff --git a/odata_query/django/django_q.py b/odata_query/django/django_q.py index d923003..ab0a29e 100644 --- a/odata_query/django/django_q.py +++ b/odata_query/django/django_q.py @@ -1,4 +1,5 @@ import operator +from functools import wraps from typing import Any, Callable, Dict, List, Optional, Type, Union from uuid import UUID @@ -17,6 +18,17 @@ ) from django.db.models.expressions import Expression +try: + # Django gis requires system level libraries, which not every user needs. + from django.contrib.gis.db.models import functions as gis_functions + from django.contrib.gis.geos import GEOSGeometry + + _gis_error = None +except Exception as e: + gis_functions = None + GEOSGeometry = None + _gis_error = e + from odata_query import ast, exceptions as ex, typing, utils, visitor from .django_q_ext import NotEqual @@ -40,6 +52,18 @@ } +def requires_gis(func): + @wraps(func) + def wrapper(*args, **kwargs): + if not gis_functions: + raise ImportError( + "Cannot use geography functions because GeoDjango failed to load." + ) from _gis_error + return func(*args, **kwargs) + + return wrapper + + class AstToDjangoQVisitor(visitor.NodeVisitor): """ :class:`NodeVisitor` that transforms an :term:`AST` into a Django Q @@ -254,12 +278,23 @@ def visit_UnaryOp(self, node: ast.UnaryOp) -> str: def visit_Call(self, node: ast.Call) -> Union[Expression, Q]: ":meta private:" + + func_name = node.func.full_name().replace(".", "__") + try: - q_gen = getattr(self, "djangofunc_" + node.func.name.lower()) + q_gen = getattr(self, "djangofunc_" + func_name.lower()) except AttributeError: - raise ex.UnsupportedFunctionException(node.func.name) + raise ex.UnsupportedFunctionException(func_name) + + args = [] + kwargs = {} + for arg in node.args: + if isinstance(arg, ast.NamedParam): + kwargs[arg.name.name] = arg.param + else: + args.append(arg) - res = q_gen(*node.args) + res = q_gen(*args, **kwargs) return res def visit_CollectionLambda(self, node: ast.CollectionLambda) -> Q: @@ -303,6 +338,20 @@ def visit_CollectionLambda(self, node: ast.CollectionLambda) -> Q: else: raise NotImplementedError() + @requires_gis + def djangofunc_geo__intersects( + self, field: ast.Identifier, geography: ast.Geography + ): + return Q(**{field.name + "__" + "intersects": GEOSGeometry(geography.wkt())}) + + @requires_gis + def djangofunc_geo__distance(self, field: ast.Identifier, geography: ast.Geography): + return gis_functions.Distance(field.name, GEOSGeometry(geography.wkt())) + + @requires_gis + def djangofunc_geo__length(self, field: ast.Identifier): + return gis_functions.Length(field.name) + def djangofunc_contains( self, field: ast._Node, substr: ast._Node ) -> lookups.Contains: diff --git a/odata_query/grammar.py b/odata_query/grammar.py index 92c30b0..3f88cbc 100644 --- a/odata_query/grammar.py +++ b/odata_query/grammar.py @@ -53,7 +53,7 @@ "floor": 1, "ceiling": 1, # Geo functions - "geo.distance": 1, + "geo.distance": 2, "geo.length": 1, "geo.intersects": 2, # Set functions @@ -67,6 +67,7 @@ class ODataLexer(Lexer): "ODATA_IDENTIFIER", "NULL", "STRING", + "GEOGRAPHY", "GUID", "DATETIME", "DATE", @@ -95,7 +96,7 @@ class ODataLexer(Lexer): "ALL", "WS", } - literals = {"(", ")", ",", "/", ":"} + literals = {"(", ")", ",", "/", ":", "="} reflags = re.I # Ensure MyPy doesn't lose its mind: @@ -143,6 +144,13 @@ def STRING(self, t): t.value = ast.String(val) return t + @_(r"geography'(?:[^']|'')*'") + def GEOGRAPHY(self, t): + ":meta private:" + + t.value = ast.Geography(t.value[10:-1]) + return t + @_(r"[\da-f]{8}-[\da-f]{4}-[\da-f]{4}-[\da-f]{4}-[\da-f]{12}") def GUID(self, t): ":meta private:" @@ -375,6 +383,7 @@ def common_expr(self, p): "INTEGER", "DECIMAL", "STRING", + "GEOGRAPHY", "BOOLEAN", "GUID", "DATE", @@ -551,24 +560,27 @@ def common_expr(self, p): #################################################################################### def _function_call(self, func: ast.Identifier, args: List[ast._Node]): ":meta private:" - func_name = func.name - try: - n_args_exp = ODATA_FUNCTIONS[func_name] - except KeyError: - raise exceptions.UnknownFunctionException(func_name) - - n_args_given = len(args) - if isinstance(n_args_exp, int) and n_args_given != n_args_exp: - raise exceptions.ArgumentCountException( - func_name, n_args_exp, n_args_exp, n_args_given - ) - - if isinstance(n_args_exp, tuple) and ( - n_args_given < n_args_exp[0] or n_args_given > n_args_exp[1] - ): - raise exceptions.ArgumentCountException( - func_name, n_args_exp[0], n_args_exp[1], n_args_given - ) + + func_name = func.full_name() + + if func.namespace in ((), ("geo",)): + try: + n_args_exp = ODATA_FUNCTIONS[func_name] + except KeyError: + raise exceptions.UnknownFunctionException(func_name) + + n_args_given = len(args) + if isinstance(n_args_exp, int) and n_args_given != n_args_exp: + raise exceptions.ArgumentCountException( + func_name, n_args_exp, n_args_exp, n_args_given + ) + + if isinstance(n_args_exp, tuple) and ( + n_args_given < n_args_exp[0] or n_args_given > n_args_exp[1] + ): + raise exceptions.ArgumentCountException( + func_name, n_args_exp[0], n_args_exp[1], n_args_given + ) return ast.Call(func, args) @@ -590,6 +602,33 @@ def common_expr(self, p): args = p[1].val return self._function_call(p[0], args) + @_('ODATA_IDENTIFIER "=" common_expr') # type:ignore[no-redef] + def named_param(self, p): + ":meta private:" + return ast.NamedParam(p[0], p.common_expr) + + @_('named_param BWS "," BWS named_param') + def list_named_param(self, p): + ":meta private:" + return [p[0], p[4]] + + @_('list_named_param BWS "," BWS named_param') # type:ignore[no-redef] + def list_named_param(self, p): + ":meta private:" + return p.list_items + [p.named_param] + + @_('ODATA_IDENTIFIER "(" BWS named_param BWS ")"') # type:ignore[no-redef] + def common_expr(self, p): + ":meta private:" + args = [p.named_param] + return self._function_call(p[0], args) + + @_('ODATA_IDENTIFIER "(" BWS list_named_param BWS ")"') # type:ignore[no-redef] + def common_expr(self, p): + ":meta private:" + args = p.list_named_param + return self._function_call(p[0], args) + #################################################################################### # Misc #################################################################################### diff --git a/odata_query/typing.py b/odata_query/typing.py index c37c7f4..393bf8b 100644 --- a/odata_query/typing.py +++ b/odata_query/typing.py @@ -64,7 +64,7 @@ def infer_return_type(node: ast.Call) -> Optional[Type[ast._Node]]: Returns: The inferred type or ``None`` if unable to infer. """ - func = node.func.name + func = node.func.full_name() if func in ( "contains", diff --git a/tests/conftest.py b/tests/conftest.py index e98484f..4d708ed 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,5 @@ +from pathlib import Path + import pytest from odata_query.grammar import ODataLexer, ODataParser @@ -11,3 +13,11 @@ def lexer(): @pytest.fixture def parser(): return ODataParser() + + +@pytest.fixture(scope="session") +def data_dir(): + data_dir_path = Path(__file__).parent / "data" + data_dir_path.mkdir(exist_ok=True) + + return data_dir_path diff --git a/tests/integration/django/apps.py b/tests/integration/django/apps.py index a8ea17f..16c068d 100644 --- a/tests/integration/django/apps.py +++ b/tests/integration/django/apps.py @@ -5,3 +5,32 @@ class ODataQueryConfig(AppConfig): name = "tests.integration.django" verbose_name = "OData Query Django test app" default = True + + +class DbRouter: + """ + Ensure that GeoDjango models go to the SpatiaLite database, while other + models use the default SQLite database. + """ + + GEO_APP = "django_geo" + + def db_for_read(self, model, **hints): + if model._meta.app_label == self.GEO_APP: + return "geo" + return None + + def db_for_write(self, model, **hints): + if model._meta.app_label == self.GEO_APP: + return "geo" + return None + + def allow_relation(self, obj1, obj2, **hints): + return obj1._meta.app_label == obj2._meta.app_label + + def allow_migrate(self, db: str, app_label: str, model_name=None, **hints): + if app_label != self.GEO_APP and db == "default": + return True + if app_label == self.GEO_APP and db == "geo": + return True + return False diff --git a/tests/integration/django/settings.py b/tests/integration/django/settings.py index 8dec28e..ed9b09c 100644 --- a/tests/integration/django/settings.py +++ b/tests/integration/django/settings.py @@ -1,8 +1,23 @@ +from pathlib import Path + +DB_DIR = Path(__file__).parent / "db" +DB_DIR.mkdir(exist_ok=True) + DATABASES = { "default": { "ENGINE": "django.db.backends.sqlite3", - "NAME": "odata-query", - } + "NAME": str(DB_DIR / "odata-query"), + }, + "geo": { + "ENGINE": "django.contrib.gis.db.backends.spatialite", + "NAME": str(DB_DIR / "odata-query-geo"), + }, } +DATABASE_ROUTERS = ["tests.integration.django.apps.DbRouter"] DEBUG = True -INSTALLED_APPS = ["tests.integration.django.apps.ODataQueryConfig"] +INSTALLED_APPS = [ + "tests.integration.django.apps.ODataQueryConfig", + # GEO: + "django.contrib.gis", + "tests.integration.django_geo.apps.ODataQueryConfig", +] diff --git a/tests/integration/django_geo/__init__.py b/tests/integration/django_geo/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/django_geo/apps.py b/tests/integration/django_geo/apps.py new file mode 100644 index 0000000..a705475 --- /dev/null +++ b/tests/integration/django_geo/apps.py @@ -0,0 +1,7 @@ +from django.apps import AppConfig + + +class ODataQueryConfig(AppConfig): + name = "tests.integration.django_geo" + verbose_name = "OData Query GeoDjango test app" + default = True diff --git a/tests/integration/django_geo/conftest.py b/tests/integration/django_geo/conftest.py new file mode 100644 index 0000000..bfd9f50 --- /dev/null +++ b/tests/integration/django_geo/conftest.py @@ -0,0 +1,36 @@ +import urllib.request as req +from pathlib import Path +from zipfile import ZipFile + +import pytest +from django.core import management + + +@pytest.fixture(scope="session") +def django_db(): + management.call_command("migrate", "--run-syncdb", "--database", "geo") + + +@pytest.fixture(scope="session") +def world_borders_dataset(data_dir: Path): + target_dir = data_dir / "world_borders" + + if target_dir.exists(): + return target_dir + + filename_zip = target_dir.with_suffix(".zip") + + if not filename_zip.exists(): + opener = req.build_opener() + opener.addheaders = [("Accept", "application/zip")] + req.install_opener(opener) + req.urlretrieve( + "https://thematicmapping.org/downloads/TM_WORLD_BORDERS-0.3.zip", + filename_zip, + ) + assert filename_zip.exists() + + with ZipFile(filename_zip, "r") as z: + z.extractall(target_dir) + + return target_dir diff --git a/tests/integration/django_geo/models.py b/tests/integration/django_geo/models.py new file mode 100644 index 0000000..b2135a0 --- /dev/null +++ b/tests/integration/django_geo/models.py @@ -0,0 +1,33 @@ +# https://docs.djangoproject.com/en/4.2/ref/contrib/gis/tutorial/#geographic-models + +from django.core.exceptions import ImproperlyConfigured +from django.db import models + +# This file needs to be importable even without Geo system libraries installed. +# Tests using these libraries will be skipped using pytest.skip +try: + from django.contrib.gis.db.models import MultiPolygonField +except (ImportError, ImproperlyConfigured): + MultiPolygonField = models.CharField + + +class WorldBorder(models.Model): + # Regular Django fields corresponding to the attributes in the + # world borders shapefile. + name = models.CharField(max_length=50) + area = models.IntegerField() + pop2005 = models.IntegerField("Population 2005") + fips = models.CharField("FIPS Code", max_length=2, null=True) + iso2 = models.CharField("2 Digit ISO", max_length=2) + iso3 = models.CharField("3 Digit ISO", max_length=3) + un = models.IntegerField("United Nations Code") + region = models.IntegerField("Region Code") + subregion = models.IntegerField("Sub-Region Code") + lon = models.FloatField() + lat = models.FloatField() + + # GeoDjango-specific: a geometry field (MultiPolygonField) + mpoly = MultiPolygonField() + + def __str__(self): + return self.name diff --git a/tests/integration/django_geo/test_querying.py b/tests/integration/django_geo/test_querying.py new file mode 100644 index 0000000..67f0261 --- /dev/null +++ b/tests/integration/django_geo/test_querying.py @@ -0,0 +1,69 @@ +from pathlib import Path +from typing import Type + +import pytest +from django.core.exceptions import ImproperlyConfigured + +from odata_query.django import apply_odata_query + +from .models import WorldBorder + +try: + from django.contrib.gis.db import models + from django.contrib.gis.utils import LayerMapping +except (ImportError, ImproperlyConfigured): + pytest.skip(allow_module_level=True, reason="Could not load GIS libraries") + +# The default spatial reference system for geometry fields is WGS84 +# (meaning the SRID is 4326) +SRID = "SRID=4326" + + +@pytest.fixture(scope="session") +def sample_data_sess(django_db, world_borders_dataset: Path): + world_mapping = { + "fips": "FIPS", + "iso2": "ISO2", + "iso3": "ISO3", + "un": "UN", + "name": "NAME", + "area": "AREA", + "pop2005": "POP2005", + "region": "REGION", + "subregion": "SUBREGION", + "lon": "LON", + "lat": "LAT", + "mpoly": "MULTIPOLYGON", + } + + world_shp = world_borders_dataset / "TM_WORLD_BORDERS-0.3.shp" + lm = LayerMapping(WorldBorder, world_shp, world_mapping, transform=False) + lm.save(strict=True, verbose=True) + yield + WorldBorder.objects.all().delete() + + +@pytest.mark.parametrize( + "model, query, exp_results", + [ + ( + WorldBorder, + "geo.length(mpoly) gt 1000000", + 154, + ), + ( + WorldBorder, + f"geo.intersects(mpoly, geography'{SRID};Point(-95.3385 29.7245)')", + 1, + ), + ], +) +def test_query_with_odata( + model: Type[models.Model], + query: str, + exp_results: int, + sample_data_sess, +): + q = apply_odata_query(model.objects, query) + results = q.all() + assert len(results) == exp_results diff --git a/tests/unit/test_odata_parser.py b/tests/unit/test_odata_parser.py index 6f2003b..a040fca 100644 --- a/tests/unit/test_odata_parser.py +++ b/tests/unit/test_odata_parser.py @@ -110,6 +110,22 @@ def test_duration_parsing(value: str, expected_unpacked: tuple): assert res.unpack() == expected_unpacked +@pytest.mark.parametrize( + "value, expected", + [ + ( + "geography'SRID=0;Point(142.1 64.1)'", + ast.Geography("SRID=0;Point(142.1 64.1)"), + ) + ], +) +def test_geography_literal_parsing(value: str, expected: str): + res = parse(value, "primitive_literal") + + assert isinstance(res, ast.Geography) + assert res == expected + + @pytest.mark.parametrize( "odata_val, exp_py_val", [ @@ -422,6 +438,13 @@ def test_bool_common_expr(expression: str, expected_ast: ast._Node): "concat('abc', 'def')", ast.Call(ast.Identifier("concat"), [ast.String("abc"), ast.String("def")]), ), + ( + "geo.distance(home,geography'SRID=0;Point(142.1 64.1)')", + ast.Call( + ast.Identifier("distance", namespace=("geo",)), + [ast.Identifier("home"), ast.Geography("SRID=0;Point(142.1 64.1)")], + ), + ), ], ) def test_common_expr(expression: str, expected_ast): diff --git a/tox.ini b/tox.ini index c247f99..529c0e1 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py37-django3, py{38,39,310}-django{3,4}, linting, docs +envlist = py37-django3, py{38,39,310,311}-django{3,4}, linting, docs skip_missing_interpreters = True isolated_build = True @@ -9,6 +9,7 @@ python = 3.8: py38, linting, docs 3.9: py39 3.10: py310 + 3.11: py311 [testenv:linting] basepython = python3.8 @@ -45,4 +46,4 @@ setenv = passenv = PYTHONBREAKPOINT commands = - pytest {posargs:tests/unit/ tests/integration/} + pytest {posargs:tests/unit/ tests/integration/} -r fEs