Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support geo.{intersects, distance, length} #50

Merged
merged 13 commits into from
Jan 23, 2024
Merged
8 changes: 5 additions & 3 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.7', '3.8', '3.9', '3.10']
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Install GIS libraries
run: sudo apt-get install -y binutils libproj-dev gdal-bin libsqlite3-mod-spatialite
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,5 @@ ENV/
.serverless/
node_modules/

tests/integration/django/db/*
tests/data/world_borders*
17 changes: 17 additions & 0 deletions odata_query/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ class Identifier(_Node):
name: str
namespace: Tuple[str, ...] = field(default_factory=tuple)

def full_name(self):
return ".".join(self.namespace + (self.name,))


@dataclass(frozen=True)
class Attribute(_Node):
Expand Down Expand Up @@ -81,6 +84,14 @@ def py_val(self) -> str:
return self.val


@dataclass(frozen=True)
class Geography(_Literal):
val: str

def wkt(self):
return self.val


@dataclass(frozen=True)
class Date(_Literal):
val: str
Expand Down Expand Up @@ -309,6 +320,12 @@ class UnaryOp(_Node):
###############################################################################
# Function calls
###############################################################################
@dataclass(frozen=True)
class NamedParam(_Node):
name: Identifier
param: _Node


@dataclass(frozen=True)
class Call(_Node):
func: Identifier
Expand Down
55 changes: 52 additions & 3 deletions odata_query/django/django_q.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import operator
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Type, Union
from uuid import UUID

Expand All @@ -17,6 +18,17 @@
)
from django.db.models.expressions import Expression

try:
# Django gis requires system level libraries, which not every user needs.
from django.contrib.gis.db.models import functions as gis_functions
from django.contrib.gis.geos import GEOSGeometry

_gis_error = None
except Exception as e:
gis_functions = None
GEOSGeometry = None
_gis_error = e

from odata_query import ast, exceptions as ex, typing, utils, visitor

from .django_q_ext import NotEqual
Expand All @@ -40,6 +52,18 @@
}


def requires_gis(func):
@wraps(func)
def wrapper(*args, **kwargs):
if not gis_functions:
raise ImportError(
"Cannot use geography functions because GeoDjango failed to load."
) from _gis_error
return func(*args, **kwargs)

return wrapper


class AstToDjangoQVisitor(visitor.NodeVisitor):
"""
:class:`NodeVisitor` that transforms an :term:`AST` into a Django Q
Expand Down Expand Up @@ -254,12 +278,23 @@ def visit_UnaryOp(self, node: ast.UnaryOp) -> str:

def visit_Call(self, node: ast.Call) -> Union[Expression, Q]:
":meta private:"

func_name = node.func.full_name().replace(".", "__")

try:
q_gen = getattr(self, "djangofunc_" + node.func.name.lower())
q_gen = getattr(self, "djangofunc_" + func_name.lower())
except AttributeError:
raise ex.UnsupportedFunctionException(node.func.name)
raise ex.UnsupportedFunctionException(func_name)

args = []
kwargs = {}
for arg in node.args:
if isinstance(arg, ast.NamedParam):
kwargs[arg.name.name] = arg.param
else:
args.append(arg)

res = q_gen(*node.args)
res = q_gen(*args, **kwargs)
return res

def visit_CollectionLambda(self, node: ast.CollectionLambda) -> Q:
Expand Down Expand Up @@ -303,6 +338,20 @@ def visit_CollectionLambda(self, node: ast.CollectionLambda) -> Q:
else:
raise NotImplementedError()

@requires_gis
def djangofunc_geo__intersects(
self, field: ast.Identifier, geography: ast.Geography
):
return Q(**{field.name + "__" + "intersects": GEOSGeometry(geography.wkt())})

@requires_gis
def djangofunc_geo__distance(self, field: ast.Identifier, geography: ast.Geography):
return gis_functions.Distance(field.name, GEOSGeometry(geography.wkt()))

@requires_gis
def djangofunc_geo__length(self, field: ast.Identifier):
return gis_functions.Length(field.name)

def djangofunc_contains(
self, field: ast._Node, substr: ast._Node
) -> lookups.Contains:
Expand Down
79 changes: 59 additions & 20 deletions odata_query/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
"floor": 1,
"ceiling": 1,
# Geo functions
"geo.distance": 1,
"geo.distance": 2,
"geo.length": 1,
"geo.intersects": 2,
# Set functions
Expand All @@ -67,6 +67,7 @@ class ODataLexer(Lexer):
"ODATA_IDENTIFIER",
"NULL",
"STRING",
"GEOGRAPHY",
"GUID",
"DATETIME",
"DATE",
Expand Down Expand Up @@ -95,7 +96,7 @@ class ODataLexer(Lexer):
"ALL",
"WS",
}
literals = {"(", ")", ",", "/", ":"}
literals = {"(", ")", ",", "/", ":", "="}
reflags = re.I

# Ensure MyPy doesn't lose its mind:
Expand Down Expand Up @@ -143,6 +144,13 @@ def STRING(self, t):
t.value = ast.String(val)
return t

@_(r"geography'(?:[^']|'')*'")
def GEOGRAPHY(self, t):
":meta private:"

t.value = ast.Geography(t.value[10:-1])
return t

@_(r"[\da-f]{8}-[\da-f]{4}-[\da-f]{4}-[\da-f]{4}-[\da-f]{12}")
def GUID(self, t):
":meta private:"
Expand Down Expand Up @@ -375,6 +383,7 @@ def common_expr(self, p):
"INTEGER",
"DECIMAL",
"STRING",
"GEOGRAPHY",
"BOOLEAN",
"GUID",
"DATE",
Expand Down Expand Up @@ -551,24 +560,27 @@ def common_expr(self, p):
####################################################################################
def _function_call(self, func: ast.Identifier, args: List[ast._Node]):
":meta private:"
func_name = func.name
try:
n_args_exp = ODATA_FUNCTIONS[func_name]
except KeyError:
raise exceptions.UnknownFunctionException(func_name)

n_args_given = len(args)
if isinstance(n_args_exp, int) and n_args_given != n_args_exp:
raise exceptions.ArgumentCountException(
func_name, n_args_exp, n_args_exp, n_args_given
)

if isinstance(n_args_exp, tuple) and (
n_args_given < n_args_exp[0] or n_args_given > n_args_exp[1]
):
raise exceptions.ArgumentCountException(
func_name, n_args_exp[0], n_args_exp[1], n_args_given
)

func_name = func.full_name()

if func.namespace in ((), ("geo",)):
try:
n_args_exp = ODATA_FUNCTIONS[func_name]
except KeyError:
raise exceptions.UnknownFunctionException(func_name)

n_args_given = len(args)
if isinstance(n_args_exp, int) and n_args_given != n_args_exp:
raise exceptions.ArgumentCountException(
func_name, n_args_exp, n_args_exp, n_args_given
)

if isinstance(n_args_exp, tuple) and (
n_args_given < n_args_exp[0] or n_args_given > n_args_exp[1]
):
raise exceptions.ArgumentCountException(
func_name, n_args_exp[0], n_args_exp[1], n_args_given
)

return ast.Call(func, args)

Expand All @@ -590,6 +602,33 @@ def common_expr(self, p):
args = p[1].val
return self._function_call(p[0], args)

@_('ODATA_IDENTIFIER "=" common_expr') # type:ignore[no-redef]
def named_param(self, p):
":meta private:"
return ast.NamedParam(p[0], p.common_expr)

@_('named_param BWS "," BWS named_param')
def list_named_param(self, p):
":meta private:"
return [p[0], p[4]]

@_('list_named_param BWS "," BWS named_param') # type:ignore[no-redef]
def list_named_param(self, p):
":meta private:"
return p.list_items + [p.named_param]

@_('ODATA_IDENTIFIER "(" BWS named_param BWS ")"') # type:ignore[no-redef]
def common_expr(self, p):
":meta private:"
args = [p.named_param]
return self._function_call(p[0], args)

@_('ODATA_IDENTIFIER "(" BWS list_named_param BWS ")"') # type:ignore[no-redef]
def common_expr(self, p):
":meta private:"
args = p.list_named_param
return self._function_call(p[0], args)

####################################################################################
# Misc
####################################################################################
Expand Down
2 changes: 1 addition & 1 deletion odata_query/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def infer_return_type(node: ast.Call) -> Optional[Type[ast._Node]]:
Returns:
The inferred type or ``None`` if unable to infer.
"""
func = node.func.name
func = node.func.full_name()

if func in (
"contains",
Expand Down
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

import pytest

from odata_query.grammar import ODataLexer, ODataParser
Expand All @@ -11,3 +13,11 @@ def lexer():
@pytest.fixture
def parser():
return ODataParser()


@pytest.fixture(scope="session")
def data_dir():
data_dir_path = Path(__file__).parent / "data"
data_dir_path.mkdir(exist_ok=True)

return data_dir_path
29 changes: 29 additions & 0 deletions tests/integration/django/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,32 @@ class ODataQueryConfig(AppConfig):
name = "tests.integration.django"
verbose_name = "OData Query Django test app"
default = True


class DbRouter:
"""
Ensure that GeoDjango models go to the SpatiaLite database, while other
models use the default SQLite database.
"""

GEO_APP = "django_geo"

def db_for_read(self, model, **hints):
if model._meta.app_label == self.GEO_APP:
return "geo"
return None

def db_for_write(self, model, **hints):
if model._meta.app_label == self.GEO_APP:
return "geo"
return None

def allow_relation(self, obj1, obj2, **hints):
return obj1._meta.app_label == obj2._meta.app_label

def allow_migrate(self, db: str, app_label: str, model_name=None, **hints):
if app_label != self.GEO_APP and db == "default":
return True
if app_label == self.GEO_APP and db == "geo":
return True
return False
21 changes: 18 additions & 3 deletions tests/integration/django/settings.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,23 @@
from pathlib import Path

DB_DIR = Path(__file__).parent / "db"
DB_DIR.mkdir(exist_ok=True)

DATABASES = {
"default": {
"ENGINE": "django.db.backends.sqlite3",
"NAME": "odata-query",
}
"NAME": str(DB_DIR / "odata-query"),
},
"geo": {
"ENGINE": "django.contrib.gis.db.backends.spatialite",
"NAME": str(DB_DIR / "odata-query-geo"),
},
}
DATABASE_ROUTERS = ["tests.integration.django.apps.DbRouter"]
DEBUG = True
INSTALLED_APPS = ["tests.integration.django.apps.ODataQueryConfig"]
INSTALLED_APPS = [
"tests.integration.django.apps.ODataQueryConfig",
# GEO:
"django.contrib.gis",
"tests.integration.django_geo.apps.ODataQueryConfig",
]
Empty file.
7 changes: 7 additions & 0 deletions tests/integration/django_geo/apps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from django.apps import AppConfig


class ODataQueryConfig(AppConfig):
name = "tests.integration.django_geo"
verbose_name = "OData Query GeoDjango test app"
default = True
Loading