From 8d7c3703ff9b1f161ca76b25c5c9d59ef9b48f3a Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Wed, 21 Aug 2024 11:09:57 +0200 Subject: [PATCH 01/20] Fix some mypy errors --- neomodel/async_/core.py | 27 +++++++++++++++++-------- neomodel/async_/relationship.py | 6 +++++- neomodel/async_/relationship_manager.py | 10 ++------- neomodel/properties.py | 6 +++--- neomodel/sync_/core.py | 23 +++++++++++++-------- neomodel/sync_/relationship.py | 4 +++- neomodel/sync_/relationship_manager.py | 10 ++------- requirements-dev.txt | 5 ++++- 8 files changed, 53 insertions(+), 38 deletions(-) diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py index 5773da12..66b58fcb 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/core.py @@ -9,7 +9,7 @@ from functools import wraps from itertools import combinations from threading import local -from typing import Any, Optional, Sequence +from typing import Any, Optional, Sequence, Type from urllib.parse import quote, unquote, urlparse from neo4j import ( @@ -90,8 +90,8 @@ class AsyncDatabase(local): A singleton object via which all operations from neomodel to the Neo4j backend are handled with. """ - _NODE_CLASS_REGISTRY = {} - _DB_SPECIFIC_CLASS_REGISTRY = {} + _NODE_CLASS_REGISTRY: dict[frozenset, type] = {} + _DB_SPECIFIC_CLASS_REGISTRY: dict[str, dict] = {} def __init__(self): self._active_transaction = None @@ -105,7 +105,9 @@ def __init__(self): self._database_edition = None self.impersonated_user = None - async def set_connection(self, url: str = None, driver: AsyncDriver = None): + async def set_connection( + self, url: str | None = None, driver: AsyncDriver | None = None + ): """ Sets the connection up and relevant internal. This can be done using a Neo4j URL or a driver instance. 
@@ -188,8 +190,9 @@ def _parse_driver_from_url(self, url: str) -> None: options["encrypted"] = config.ENCRYPTED options["trusted_certificates"] = config.TRUSTED_CERTIFICATES + # Ignore the type error because the workaround would be duplicating code self.driver = AsyncGraphDatabase.driver( - parsed_url.scheme + "://" + hostname, **options + parsed_url.scheme + "://" + hostname, **options # type: ignore[arg-type] ) self.url = url # The database name can be provided through the url or the config @@ -503,12 +506,18 @@ async def _run_cypher_query( except ClientError as e: if e.code == "Neo.ClientError.Schema.ConstraintValidationFailed": - if "already exists with label" in e.message and handle_unique: + if ( + hasattr(e, "message") + and e.message is not None + and "already exists with label" in e.message + and handle_unique + ): raise UniqueProperty(e.message) from e raise ConstraintValidationFailed(e.message) from e exc_info = sys.exc_info() - raise exc_info[1].with_traceback(exc_info[2]) + if exc_info[1] is not None and exc_info[2] is not None: + raise exc_info[1].with_traceback(exc_info[2]) except SessionExpired: if retry_on_session_expire: await self.set_connection(url=self.url) @@ -1340,7 +1349,9 @@ def build_class_registry(cls): ) -NodeBase = NodeMeta("NodeBase", (AsyncPropertyManager,), {"__abstract_node__": True}) +NodeBase: Type = NodeMeta( + "NodeBase", (AsyncPropertyManager,), {"__abstract_node__": True} +) class AsyncStructuredNode(NodeBase): diff --git a/neomodel/async_/relationship.py b/neomodel/async_/relationship.py index 818bc3a2..365dd132 100644 --- a/neomodel/async_/relationship.py +++ b/neomodel/async_/relationship.py @@ -1,3 +1,5 @@ +from typing import Type + from neomodel.async_.core import adb from neomodel.async_.property_manager import AsyncPropertyManager from neomodel.hooks import hooks @@ -38,7 +40,9 @@ def __new__(mcs, name, bases, dct): return inst -StructuredRelBase = RelationshipMeta("RelationshipBase", (AsyncPropertyManager,), {}) 
+StructuredRelBase: Type = RelationshipMeta( + "RelationshipBase", (AsyncPropertyManager,), {} +) class AsyncStructuredRel(StructuredRelBase): diff --git a/neomodel/async_/relationship_manager.py b/neomodel/async_/relationship_manager.py index 35c07f0c..d26423a7 100644 --- a/neomodel/async_/relationship_manager.py +++ b/neomodel/async_/relationship_manager.py @@ -20,12 +20,6 @@ get_graph_entity_properties, ) -# basestring python 3.x fallback -try: - basestring -except NameError: - basestring = str - # check source node is saved and not deleted def check_source(fn): @@ -469,14 +463,14 @@ def __init__( adb._NODE_CLASS_REGISTRY[label_set] = model def _validate_class(self, cls_name, model): - if not isinstance(cls_name, (basestring, object)): + if not isinstance(cls_name, (str, object)): raise ValueError("Expected class name or class got " + repr(cls_name)) if model and not issubclass(model, (AsyncStructuredRel,)): raise ValueError("model must be a StructuredRel") def lookup_node_class(self): - if not isinstance(self._raw_class, basestring): + if not isinstance(self._raw_class, str): self.definition["node_class"] = self._raw_class else: name = self._raw_class diff --git a/neomodel/properties.py b/neomodel/properties.py index d4a91885..029b3712 100644 --- a/neomodel/properties.py +++ b/neomodel/properties.py @@ -106,8 +106,8 @@ def __init__( self, unique_index=False, index=False, - fulltext_index: FulltextIndex = None, - vector_index: VectorIndex = None, + fulltext_index: FulltextIndex | None = None, + vector_index: VectorIndex | None = None, required=False, default=None, db_property=None, @@ -192,7 +192,7 @@ class RegexProperty(NormalizedProperty): form_field_class = "RegexField" - expression = None + expression: str | None = None def __init__(self, expression=None, **kwargs): """ diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index 6c72908a..b3854603 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -9,7 +9,7 @@ from functools import 
wraps from itertools import combinations from threading import local -from typing import Any, Optional, Sequence +from typing import Any, Optional, Sequence, Type from urllib.parse import quote, unquote, urlparse from neo4j import ( @@ -90,8 +90,8 @@ class Database(local): A singleton object via which all operations from neomodel to the Neo4j backend are handled with. """ - _NODE_CLASS_REGISTRY = {} - _DB_SPECIFIC_CLASS_REGISTRY = {} + _NODE_CLASS_REGISTRY: dict[frozenset, type] = {} + _DB_SPECIFIC_CLASS_REGISTRY: dict[str, dict] = {} def __init__(self): self._active_transaction = None @@ -105,7 +105,7 @@ def __init__(self): self._database_edition = None self.impersonated_user = None - def set_connection(self, url: str = None, driver: Driver = None): + def set_connection(self, url: str | None = None, driver: Driver | None = None): """ Sets the connection up and relevant internal. This can be done using a Neo4j URL or a driver instance. @@ -188,8 +188,9 @@ def _parse_driver_from_url(self, url: str) -> None: options["encrypted"] = config.ENCRYPTED options["trusted_certificates"] = config.TRUSTED_CERTIFICATES + # Ignore the type error because the workaround would be duplicating code self.driver = GraphDatabase.driver( - parsed_url.scheme + "://" + hostname, **options + parsed_url.scheme + "://" + hostname, **options # type: ignore[arg-type] ) self.url = url # The database name can be provided through the url or the config @@ -501,12 +502,18 @@ def _run_cypher_query( except ClientError as e: if e.code == "Neo.ClientError.Schema.ConstraintValidationFailed": - if "already exists with label" in e.message and handle_unique: + if ( + hasattr(e, "message") + and e.message is not None + and "already exists with label" in e.message + and handle_unique + ): raise UniqueProperty(e.message) from e raise ConstraintValidationFailed(e.message) from e exc_info = sys.exc_info() - raise exc_info[1].with_traceback(exc_info[2]) + if exc_info[1] is not None and exc_info[2] is not None: + 
raise exc_info[1].with_traceback(exc_info[2]) except SessionExpired: if retry_on_session_expire: self.set_connection(url=self.url) @@ -1334,7 +1341,7 @@ def build_class_registry(cls): ) -NodeBase = NodeMeta("NodeBase", (PropertyManager,), {"__abstract_node__": True}) +NodeBase: Type = NodeMeta("NodeBase", (PropertyManager,), {"__abstract_node__": True}) class StructuredNode(NodeBase): diff --git a/neomodel/sync_/relationship.py b/neomodel/sync_/relationship.py index a673fb8b..9b3bdf95 100644 --- a/neomodel/sync_/relationship.py +++ b/neomodel/sync_/relationship.py @@ -1,3 +1,5 @@ +from typing import Type + from neomodel.hooks import hooks from neomodel.properties import Property from neomodel.sync_.core import db @@ -38,7 +40,7 @@ def __new__(mcs, name, bases, dct): return inst -StructuredRelBase = RelationshipMeta("RelationshipBase", (PropertyManager,), {}) +StructuredRelBase: Type = RelationshipMeta("RelationshipBase", (PropertyManager,), {}) class StructuredRel(StructuredRelBase): diff --git a/neomodel/sync_/relationship_manager.py b/neomodel/sync_/relationship_manager.py index fadaee99..14323bcc 100644 --- a/neomodel/sync_/relationship_manager.py +++ b/neomodel/sync_/relationship_manager.py @@ -15,12 +15,6 @@ get_graph_entity_properties, ) -# basestring python 3.x fallback -try: - basestring -except NameError: - basestring = str - # check source node is saved and not deleted def check_source(fn): @@ -452,14 +446,14 @@ def __init__( db._NODE_CLASS_REGISTRY[label_set] = model def _validate_class(self, cls_name, model): - if not isinstance(cls_name, (basestring, object)): + if not isinstance(cls_name, (str, object)): raise ValueError("Expected class name or class got " + repr(cls_name)) if model and not issubclass(model, (StructuredRel,)): raise ValueError("model must be a StructuredRel") def lookup_node_class(self): - if not isinstance(self._raw_class, basestring): + if not isinstance(self._raw_class, str): self.definition["node_class"] = self._raw_class else: 
name = self._raw_class diff --git a/requirements-dev.txt b/requirements-dev.txt index bf3fa116..d540f037 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -7,4 +7,7 @@ pytest-cov>=4.0 pre-commit black isort -Shapely>=2.0.0 \ No newline at end of file +Shapely>=2.0.0 +mypy>=1.11 +pandas-stubs +types-pytz \ No newline at end of file From 394860d5adad854b1dc9504d0929e07b777146d9 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Fri, 20 Sep 2024 17:14:57 +0200 Subject: [PATCH 02/20] Implement some mypy fixes --- neomodel/async_/core.py | 40 +++++++++++++++------- neomodel/async_/match.py | 24 +++++++++----- neomodel/async_/path.py | 39 ++++++++++++++++++---- neomodel/async_/relationship_manager.py | 9 ++++- neomodel/exceptions.py | 7 ++-- neomodel/properties.py | 44 ++++++++++++++++--------- neomodel/sync_/core.py | 40 ++++++++++++++++------ neomodel/sync_/match.py | 24 +++++++++----- neomodel/sync_/path.py | 39 ++++++++++++++++++---- neomodel/sync_/relationship_manager.py | 9 ++++- test/async_/test_paths.py | 25 +++++++++----- test/async_/test_properties.py | 2 +- test/sync_/test_paths.py | 25 +++++++++----- test/sync_/test_properties.py | 2 +- 14 files changed, 239 insertions(+), 90 deletions(-) diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py index 66b58fcb..cf482b58 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/core.py @@ -90,14 +90,14 @@ class AsyncDatabase(local): A singleton object via which all operations from neomodel to the Neo4j backend are handled with. 
""" - _NODE_CLASS_REGISTRY: dict[frozenset, type] = {} + _NODE_CLASS_REGISTRY: dict[frozenset, Any] = {} _DB_SPECIFIC_CLASS_REGISTRY: dict[str, dict] = {} def __init__(self): - self._active_transaction = None + self._active_transaction: Optional[AsyncTransaction] = None self.url = None self.driver = None - self._session = None + self._session: Optional[AsyncSession] = None self._pid = None self._database_name = DEFAULT_DATABASE self.protocol_version = None @@ -106,7 +106,7 @@ def __init__(self): self.impersonated_user = None async def set_connection( - self, url: str | None = None, driver: AsyncDriver | None = None + self, url: Optional[str] = None, driver: Optional[AsyncDriver] = None ): """ Sets the connection up and relevant internal. This can be done using a Neo4j URL or a driver instance. @@ -210,8 +210,9 @@ async def close_connection(self): self._database_version = None self._database_edition = None self._database_name = None - await self.driver.close() - self.driver = None + if self.driver is not None: + await self.driver.close() + self.driver = None @property async def database_version(self): @@ -268,15 +269,18 @@ async def begin(self, access_mode=None, **parameters): and self._active_transaction is not None ): raise SystemError("Transaction in progress") - self._session: AsyncSession = self.driver.session( + + assert self.driver is not None, "Driver has not been created" + + self._session = self.driver.session( default_access_mode=access_mode, database=self._database_name, impersonated_user=self.impersonated_user, **parameters, ) - self._active_transaction: AsyncTransaction = ( - await self._session.begin_transaction() - ) + + assert self._session is not None, "Session has not been created" + self._active_transaction = await self._session.begin_transaction() @ensure_connection async def commit(self): @@ -286,14 +290,21 @@ async def commit(self): :return: last_bookmarks """ try: + assert self._active_transaction is not None, "No transaction in progress" 
await self._active_transaction.commit() + + assert self._session is not None, "No session open" last_bookmarks: Bookmarks = await self._session.last_bookmarks() finally: - # In case when something went wrong during + # In case something went wrong during # committing changes to the database # we have to close an active transaction and session. + assert self._active_transaction is not None, "No transaction in progress" await self._active_transaction.close() + + assert self._session is not None, "No session open" await self._session.close() + self._active_transaction = None self._session = None @@ -305,12 +316,17 @@ async def rollback(self): Rolls back the current transaction and closes its session """ try: + assert self._active_transaction is not None, "No transaction in progress" await self._active_transaction.rollback() finally: # In case when something went wrong during changes rollback, # we have to close an active transaction and session + assert self._active_transaction is not None, "No transaction in progress" await self._active_transaction.close() + + assert self._session is not None, "No session open" await self._session.close() + self._active_transaction = None self._session = None @@ -459,7 +475,7 @@ async def cypher_query( """ if self._active_transaction: - # Use current session is a transaction is currently active + # Use current transaction if a transaction is currently active results, meta = await self._run_cypher_query( self._active_transaction, query, diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 7f0435fe..9f6b4ee4 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -2,7 +2,7 @@ import re from collections import defaultdict from dataclasses import dataclass -from typing import Optional +from typing import Any, Optional from neomodel.async_.core import AsyncStructuredNode, adb from neomodel.exceptions import MultipleNodesReturned @@ -772,10 +772,11 @@ async def _execute(self, lazy=False): f"{await 
adb.get_id_method()}({self._ast.return_clause})" ) else: - self._ast.additional_return = [ - f"{await adb.get_id_method()}({item})" - for item in self._ast.additional_return - ] + if self._ast.additional_return is not None: + self._ast.additional_return = [ + f"{await adb.get_id_method()}({item})" + for item in self._ast.additional_return + ] query = self.build_query() results, _ = await adb.cypher_query( query, self._query_params, resolve_objects=True @@ -875,7 +876,7 @@ async def get_item(self, key): @dataclass -class Optional: +class Optional: # type: ignore[no-redef] """Simple relation qualifier.""" relation: str @@ -1066,7 +1067,7 @@ def fetch_relations(self, *relation_names): """Specify a set of relations to return.""" relations = [] for relation_name in relation_names: - if isinstance(relation_name, Optional): + if isinstance(relation_name, Optional): # type: ignore[arg-type] item = {"path": relation_name.relation, "optional": True} else: item = {"path": relation_name} @@ -1090,10 +1091,17 @@ class AsyncTraversal(AsyncBaseSet): :type defintion: :class:`dict` """ + definition: dict + source: Any + source_class: Any + target_class: Any + name: str + filters: list + def __await__(self): return self.all().__await__() - def __init__(self, source, name, definition): + def __init__(self, source: Any, name: str, definition: dict): """ Create a traversal diff --git a/neomodel/async_/path.py b/neomodel/async_/path.py index 6128347e..e04cf202 100644 --- a/neomodel/async_/path.py +++ b/neomodel/async_/path.py @@ -1,10 +1,12 @@ +import typing as t + from neo4j.graph import Path -from neomodel.async_.core import adb +from neomodel.async_.core import AsyncStructuredNode, adb from neomodel.async_.relationship import AsyncStructuredRel -class AsyncNeomodelPath(Path): +class AsyncNeomodelPath(object): """ Represents paths within neomodel. 
@@ -26,9 +28,9 @@ class AsyncNeomodelPath(Path): :type relationships: List[StructuredRel] """ - def __init__(self, a_neopath): - self._nodes = [] - self._relationships = [] + def __init__(self, a_neopath: Path): + self._nodes: list[AsyncStructuredNode] = [] + self._relationships: list[AsyncStructuredRel] = [] for a_node in a_neopath.nodes: self._nodes.append(adb._object_resolution(a_node)) @@ -44,10 +46,33 @@ def __init__(self, a_neopath): new_rel = AsyncStructuredRel.inflate(a_relationship) self._relationships.append(new_rel) + def __repr__(self) -> str: + return "" % ( + self.start_node, + self.end_node, + len(self), + ) + + def __len__(self) -> int: + return len(self._relationships) + + def __iter__(self) -> t.Iterator[AsyncStructuredRel]: + return iter(self._relationships) + @property - def nodes(self): + def nodes(self) -> list[AsyncStructuredNode]: return self._nodes @property - def relationships(self): + def start_node(self) -> AsyncStructuredNode: + """The first :class:`.StructuredNode` in this path.""" + return self._nodes[0] + + @property + def end_node(self) -> AsyncStructuredNode: + """The last :class:`.StructuredNode` in this path.""" + return self._nodes[-1] + + @property + def relationships(self) -> list[AsyncStructuredRel]: return self._relationships diff --git a/neomodel/async_/relationship_manager.py b/neomodel/async_/relationship_manager.py index d26423a7..a3d9b27a 100644 --- a/neomodel/async_/relationship_manager.py +++ b/neomodel/async_/relationship_manager.py @@ -2,6 +2,7 @@ import inspect import sys from importlib import import_module +from typing import Any from neomodel.async_.core import adb from neomodel.async_.match import ( @@ -48,7 +49,13 @@ class AsyncRelationshipManager(object): I.e the 'friends' object in `user.friends.all()` """ - def __init__(self, source, key, definition): + source: Any + source_class: Any + name: str + definition: dict + description: str = "relationship" + + def __init__(self, source: Any, key: str, definition: 
dict): self.source = source self.source_class = source.__class__ self.name = key diff --git a/neomodel/exceptions.py b/neomodel/exceptions.py index cd66f962..3959c347 100644 --- a/neomodel/exceptions.py +++ b/neomodel/exceptions.py @@ -1,3 +1,6 @@ +from typing import Optional, Type + + class NeomodelException(Exception): """ A base class that identifies all exceptions raised by :mod:`neomodel`. @@ -182,7 +185,7 @@ def __str__(self): class DoesNotExist(NeomodelException): - _model_class = None + _model_class: Optional[Type] = None """ This class property refers the model class that a subclass of this class belongs to. It is set by :class:`~neomodel.core.NodeMeta`. @@ -191,7 +194,7 @@ class DoesNotExist(NeomodelException): def __init__(self, msg): if self._model_class is None: raise RuntimeError("This class hasn't been setup properly.") - self.message = msg + self.message: str = msg super().__init__(self, msg) def __reduce__(self): diff --git a/neomodel/properties.py b/neomodel/properties.py index 029b3712..fd0dd1ba 100644 --- a/neomodel/properties.py +++ b/neomodel/properties.py @@ -1,26 +1,22 @@ import functools import json import re -import sys import uuid from datetime import date, datetime +from typing import Any, Optional import neo4j.time import pytz from neomodel import config -from neomodel.exceptions import DeflateError, InflateError +from neomodel.exceptions import DeflateError, InflateError, NeomodelException TOO_MANY_DEFAULTS = "too many defaults" def validator(fn): fn_name = fn.func_name if hasattr(fn, "func_name") else fn.__name__ - if fn_name == "inflate": - exc_class = InflateError - elif fn_name == "deflate": - exc_class = DeflateError - else: + if fn_name not in ["inflate", "deflate"]: raise ValueError("Unknown Property method " + fn_name) @functools.wraps(fn) @@ -29,7 +25,12 @@ def _validator(self, value, obj=None, rethrow=True): try: return fn(self, value) except Exception as e: - raise exc_class(self.name, self.owner, str(e), obj) from e + if 
fn_name == "inflate": + raise InflateError(self.name, self.owner, str(e), obj) from e + elif fn_name == "deflate": + raise DeflateError(self.name, self.owner, str(e), obj) from e + else: + raise NeomodelException("Unknown Property method " + fn_name) from e else: # For using with ArrayProperty where we don't want an Inflate/Deflate error. return fn(self, value) @@ -100,14 +101,27 @@ class Property: """ form_field_class = "CharField" + name: Optional[str] = None + owner: Optional[Any] = None + unique_index: bool = False + index: bool = False + fulltext_index: Optional[FulltextIndex] = None + vector_index: Optional[VectorIndex] = None + required: bool = False + default: Any = None + db_property: Optional[str] = None + label: Optional[str] = None + help_text: Optional[str] = None # pylint:disable=unused-argument def __init__( self, + name: Optional[str] = None, + owner: Optional[Any] = None, unique_index=False, index=False, - fulltext_index: FulltextIndex | None = None, - vector_index: VectorIndex | None = None, + fulltext_index: Optional[FulltextIndex] = None, + vector_index: Optional[VectorIndex] = None, required=False, default=None, db_property=None, @@ -192,7 +206,7 @@ class RegexProperty(NormalizedProperty): form_field_class = "RegexField" - expression: str | None = None + expression: str def __init__(self, expression=None, **kwargs): """ @@ -201,10 +215,7 @@ def __init__(self, expression=None, **kwargs): :param str expression: regular expression validating this property """ super().__init__(**kwargs) - actual_re = expression or self.expression - if actual_re is None: - raise ValueError("expression is undefined") - self.expression = actual_re + self.expression = expression or self.expression def normalize(self, value): normal = str(value) @@ -553,6 +564,7 @@ def __init__(self, to=None): :param to: name of property aliasing :type: str """ + super().__init__() self.target = to self.required = False self.has_default = False @@ -560,7 +572,7 @@ def __init__(self, 
to=None): def aliased_to(self): return self.target - def __get__(self, obj, cls): + def __get__(self, obj: Any, _type: Optional[Any] = None): return getattr(obj, self.aliased_to()) if obj else self def __set__(self, obj, value): diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index b3854603..ca2fe816 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -90,14 +90,14 @@ class Database(local): A singleton object via which all operations from neomodel to the Neo4j backend are handled with. """ - _NODE_CLASS_REGISTRY: dict[frozenset, type] = {} + _NODE_CLASS_REGISTRY: dict[frozenset, Any] = {} _DB_SPECIFIC_CLASS_REGISTRY: dict[str, dict] = {} def __init__(self): - self._active_transaction = None + self._active_transaction: Optional[Transaction] = None self.url = None self.driver = None - self._session = None + self._session: Optional[Session] = None self._pid = None self._database_name = DEFAULT_DATABASE self.protocol_version = None @@ -105,7 +105,9 @@ def __init__(self): self._database_edition = None self.impersonated_user = None - def set_connection(self, url: str | None = None, driver: Driver | None = None): + def set_connection( + self, url: Optional[str] = None, driver: Optional[Driver] = None + ): """ Sets the connection up and relevant internal. This can be done using a Neo4j URL or a driver instance. 
@@ -208,8 +210,9 @@ def close_connection(self): self._database_version = None self._database_edition = None self._database_name = None - self.driver.close() - self.driver = None + if self.driver is not None: + self.driver.close() + self.driver = None @property def database_version(self): @@ -266,13 +269,18 @@ def begin(self, access_mode=None, **parameters): and self._active_transaction is not None ): raise SystemError("Transaction in progress") - self._session: Session = self.driver.session( + + assert self.driver is not None, "Driver has not been created" + + self._session = self.driver.session( default_access_mode=access_mode, database=self._database_name, impersonated_user=self.impersonated_user, **parameters, ) - self._active_transaction: Transaction = self._session.begin_transaction() + + assert self._session is not None, "Session has not been created" + self._active_transaction = self._session.begin_transaction() @ensure_connection def commit(self): @@ -282,14 +290,21 @@ def commit(self): :return: last_bookmarks """ try: + assert self._active_transaction is not None, "No transaction in progress" self._active_transaction.commit() + + assert self._session is not None, "No session open" last_bookmarks: Bookmarks = self._session.last_bookmarks() finally: - # In case when something went wrong during + # In case something went wrong during # committing changes to the database # we have to close an active transaction and session. 
+ assert self._active_transaction is not None, "No transaction in progress" self._active_transaction.close() + + assert self._session is not None, "No session open" self._session.close() + self._active_transaction = None self._session = None @@ -301,12 +316,17 @@ def rollback(self): Rolls back the current transaction and closes its session """ try: + assert self._active_transaction is not None, "No transaction in progress" self._active_transaction.rollback() finally: # In case when something went wrong during changes rollback, # we have to close an active transaction and session + assert self._active_transaction is not None, "No transaction in progress" self._active_transaction.close() + + assert self._session is not None, "No session open" self._session.close() + self._active_transaction = None self._session = None @@ -455,7 +475,7 @@ def cypher_query( """ if self._active_transaction: - # Use current session is a transaction is currently active + # Use current transaction if a transaction is currently active results, meta = self._run_cypher_query( self._active_transaction, query, diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 928842c2..59c74e8e 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -2,7 +2,7 @@ import re from collections import defaultdict from dataclasses import dataclass -from typing import Optional +from typing import Any, Optional from neomodel.exceptions import MultipleNodesReturned from neomodel.match_q import Q, QBase @@ -770,10 +770,11 @@ def _execute(self, lazy=False): f"{db.get_id_method()}({self._ast.return_clause})" ) else: - self._ast.additional_return = [ - f"{db.get_id_method()}({item})" - for item in self._ast.additional_return - ] + if self._ast.additional_return is not None: + self._ast.additional_return = [ + f"{db.get_id_method()}({item})" + for item in self._ast.additional_return + ] query = self.build_query() results, _ = db.cypher_query(query, self._query_params, resolve_objects=True) # 
The following is not as elegant as it could be but had to be copied from the @@ -871,7 +872,7 @@ def __getitem__(self, key): @dataclass -class Optional: +class Optional: # type: ignore[no-redef] """Simple relation qualifier.""" relation: str @@ -1062,7 +1063,7 @@ def fetch_relations(self, *relation_names): """Specify a set of relations to return.""" relations = [] for relation_name in relation_names: - if isinstance(relation_name, Optional): + if isinstance(relation_name, Optional): # type: ignore[arg-type] item = {"path": relation_name.relation, "optional": True} else: item = {"path": relation_name} @@ -1086,10 +1087,17 @@ class Traversal(BaseSet): :type defintion: :class:`dict` """ + definition: dict + source: Any + source_class: Any + target_class: Any + name: str + filters: list + def __await__(self): return self.all().__await__() - def __init__(self, source, name, definition): + def __init__(self, source: Any, name: str, definition: dict): """ Create a traversal diff --git a/neomodel/sync_/path.py b/neomodel/sync_/path.py index 62a49fe7..90058dd2 100644 --- a/neomodel/sync_/path.py +++ b/neomodel/sync_/path.py @@ -1,10 +1,12 @@ +import typing as t + from neo4j.graph import Path -from neomodel.sync_.core import db +from neomodel.sync_.core import StructuredNode, db from neomodel.sync_.relationship import StructuredRel -class NeomodelPath(Path): +class NeomodelPath(object): """ Represents paths within neomodel. 
@@ -26,9 +28,9 @@ class NeomodelPath(Path): :type relationships: List[StructuredRel] """ - def __init__(self, a_neopath): - self._nodes = [] - self._relationships = [] + def __init__(self, a_neopath: Path): + self._nodes: list[StructuredNode] = [] + self._relationships: list[StructuredRel] = [] for a_node in a_neopath.nodes: self._nodes.append(db._object_resolution(a_node)) @@ -44,10 +46,33 @@ def __init__(self, a_neopath): new_rel = StructuredRel.inflate(a_relationship) self._relationships.append(new_rel) + def __repr__(self) -> str: + return "<Path start=%r end=%r size=%s>" % ( + self.start_node, + self.end_node, + len(self), + ) + + def __len__(self) -> int: + return len(self._relationships) + + def __iter__(self) -> t.Iterator[StructuredRel]: + return iter(self._relationships) + @property - def nodes(self): + def nodes(self) -> list[StructuredNode]: return self._nodes @property - def relationships(self): + def start_node(self) -> StructuredNode: + """The first :class:`.StructuredNode` in this path.""" + return self._nodes[0] + + @property + def end_node(self) -> StructuredNode: + """The last :class:`.StructuredNode` in this path.""" + return self._nodes[-1] + + @property + def relationships(self) -> list[StructuredRel]: return self._relationships diff --git a/neomodel/sync_/relationship_manager.py b/neomodel/sync_/relationship_manager.py index 14323bcc..a4a3cc69 100644 --- a/neomodel/sync_/relationship_manager.py +++ b/neomodel/sync_/relationship_manager.py @@ -2,6 +2,7 @@ import inspect import sys from importlib import import_module +from typing import Any from neomodel.exceptions import NotConnected, RelationshipClassRedefined from neomodel.sync_.core import db @@ -43,7 +44,13 @@ class RelationshipManager(object): I.e the 'friends' object in `user.friends.all()` """ - def __init__(self, source, key, definition): + source: Any + source_class: Any + name: str + definition: dict + description: str = "relationship" + + def __init__(self, source: Any, key: str, definition: dict): self.source =
source self.source_class = source.__class__ self.name = key diff --git a/test/async_/test_paths.py b/test/async_/test_paths.py index 59a5e385..723bf8e1 100644 --- a/test/async_/test_paths.py +++ b/test/async_/test_paths.py @@ -70,7 +70,7 @@ async def test_path_instantiation(): # Retrieve a single path q = await adb.cypher_query( - "MATCH p=(:CityOfResidence)<-[:LIVES_IN]-(:PersonOfInterest)-[:IS_FROM]->(:CountryOfOrigin) RETURN p LIMIT 1", + "MATCH p=(:CityOfResidence{name:'Athens'})<-[:LIVES_IN]-(:PersonOfInterest)-[:IS_FROM]->(:CountryOfOrigin) RETURN p LIMIT 1", resolve_objects=True, ) @@ -78,13 +78,22 @@ async def test_path_instantiation(): path_nodes = path_object.nodes path_rels = path_object.relationships - assert type(path_object) is AsyncNeomodelPath - assert type(path_nodes[0]) is CityOfResidence - assert type(path_nodes[1]) is PersonOfInterest - assert type(path_nodes[2]) is CountryOfOrigin - - assert type(path_rels[0]) is PersonLivesInCity - assert type(path_rels[1]) is AsyncStructuredRel + assert isinstance(path_object, AsyncNeomodelPath) + assert isinstance(path_nodes[0], CityOfResidence) + assert isinstance(path_nodes[1], PersonOfInterest) + assert isinstance(path_nodes[2], CountryOfOrigin) + assert isinstance(path_object.start_node, CityOfResidence) + assert isinstance(path_object.end_node, CountryOfOrigin) + + assert isinstance(path_rels[0], PersonLivesInCity) + assert isinstance(path_rels[1], AsyncStructuredRel) + + path_string = str(path_object) + assert path_string.startswith("") + assert len(path_object) == 2 + for rel in path_object: + assert isinstance(rel, AsyncStructuredRel) await c1.delete() await c2.delete() diff --git a/test/async_/test_properties.py b/test/async_/test_properties.py index 4f3eab2d..e512eb83 100644 --- a/test/async_/test_properties.py +++ b/test/async_/test_properties.py @@ -510,7 +510,7 @@ def test_regex_property(): class MissingExpression(RegexProperty): pass - with raises(ValueError): + with raises(AttributeError): 
MissingExpression() class TestProperty(RegexProperty): diff --git a/test/sync_/test_paths.py b/test/sync_/test_paths.py index 8e0ccf90..048343ab 100644 --- a/test/sync_/test_paths.py +++ b/test/sync_/test_paths.py @@ -70,7 +70,7 @@ def test_path_instantiation(): # Retrieve a single path q = db.cypher_query( - "MATCH p=(:CityOfResidence)<-[:LIVES_IN]-(:PersonOfInterest)-[:IS_FROM]->(:CountryOfOrigin) RETURN p LIMIT 1", + "MATCH p=(:CityOfResidence{name:'Athens'})<-[:LIVES_IN]-(:PersonOfInterest)-[:IS_FROM]->(:CountryOfOrigin) RETURN p LIMIT 1", resolve_objects=True, ) @@ -78,13 +78,22 @@ def test_path_instantiation(): path_nodes = path_object.nodes path_rels = path_object.relationships - assert type(path_object) is NeomodelPath - assert type(path_nodes[0]) is CityOfResidence - assert type(path_nodes[1]) is PersonOfInterest - assert type(path_nodes[2]) is CountryOfOrigin - - assert type(path_rels[0]) is PersonLivesInCity - assert type(path_rels[1]) is StructuredRel + assert isinstance(path_object, NeomodelPath) + assert isinstance(path_nodes[0], CityOfResidence) + assert isinstance(path_nodes[1], PersonOfInterest) + assert isinstance(path_nodes[2], CountryOfOrigin) + assert isinstance(path_object.start_node, CityOfResidence) + assert isinstance(path_object.end_node, CountryOfOrigin) + + assert isinstance(path_rels[0], PersonLivesInCity) + assert isinstance(path_rels[1], StructuredRel) + + path_string = str(path_object) + assert path_string.startswith("") + assert len(path_object) == 2 + for rel in path_object: + assert isinstance(rel, StructuredRel) c1.delete() c2.delete() diff --git a/test/sync_/test_properties.py b/test/sync_/test_properties.py index 1afe52a2..d30fa4fd 100644 --- a/test/sync_/test_properties.py +++ b/test/sync_/test_properties.py @@ -500,7 +500,7 @@ def test_regex_property(): class MissingExpression(RegexProperty): pass - with raises(ValueError): + with raises(AttributeError): MissingExpression() class TestProperty(RegexProperty): From 
4ed282a2eab57cc27b030bacd894ead101712a7e Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Wed, 4 Dec 2024 17:45:09 +0100 Subject: [PATCH 03/20] Add more type fixes --- neomodel/async_/core.py | 66 ++++++++++++++++++++++++---------------- neomodel/async_/match.py | 25 ++++++++++----- neomodel/sync_/core.py | 66 ++++++++++++++++++++++++---------------- neomodel/sync_/match.py | 25 ++++++++++----- 4 files changed, 114 insertions(+), 68 deletions(-) diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py index 8988d6cc..6fd7fe64 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/core.py @@ -9,7 +9,7 @@ from functools import wraps from itertools import combinations from threading import local -from typing import Any, Optional, Sequence, Type +from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union from urllib.parse import quote, unquote, urlparse from neo4j import ( @@ -95,16 +95,15 @@ class AsyncDatabase(local): def __init__(self): self._active_transaction: Optional[AsyncTransaction] = None - self.url = None - self.driver = None + self.url: Optional[str] = None + self.driver: Optional[AsyncDriver] = None self._session: Optional[AsyncSession] = None - self._pid = None - self._database_name = DEFAULT_DATABASE - self.protocol_version = None - self._database_version = None - self._database_edition = None - self.impersonated_user = None - self._parallel_runtime = False + self._pid: Optional[int] = None + self._database_name: Optional[str] = DEFAULT_DATABASE + self._database_version: Optional[str] = None + self._database_edition: Optional[str] = None + self.impersonated_user: Optional[str] = None + self._parallel_runtime: Optional[bool] = False async def set_connection( self, url: Optional[str] = None, driver: Optional[AsyncDriver] = None @@ -490,23 +489,27 @@ async def cypher_query( ) else: # Otherwise create a new session in a with to dispose of it after it has been run - async with self.driver.session( - database=self._database_name, 
impersonated_user=self.impersonated_user - ) as session: - results, meta = await self._run_cypher_query( - session, - query, - params, - handle_unique, - retry_on_session_expire, - resolve_objects, - ) + if self.driver: + async with self.driver.session( + database=self._database_name, + impersonated_user=self.impersonated_user, + ) as session: + results, meta = await self._run_cypher_query( + session, + query, + params, + handle_unique, + retry_on_session_expire, + resolve_objects, + ) + else: + raise ValueError("No driver has been set") return results, meta async def _run_cypher_query( self, - session: AsyncSession, + session: Union[AsyncSession, AsyncTransaction], query, params, handle_unique, @@ -1212,11 +1215,14 @@ class AsyncTransactionProxy: bookmarks: Optional[Bookmarks] = None def __init__( - self, db: AsyncDatabase, access_mode: str = None, parallel_runtime: bool = False + self, + db: AsyncDatabase, + access_mode: Optional[str] = None, + parallel_runtime: Optional[bool] = False, ): - self.db = db - self.access_mode = access_mode - self.parallel_runtime = parallel_runtime + self.db: AsyncDatabase = db + self.access_mode: Optional[str] = access_mode + self.parallel_runtime: Optional[bool] = parallel_runtime @ensure_connection async def __aenter__(self): @@ -1304,6 +1310,14 @@ def wrapper(*args, **kwargs): class NodeMeta(type): + DoesNotExist: Type[DoesNotExist] + __required_properties__: Tuple[str, ...] 
+ __all_properties__: Tuple[str, Any] + __all_aliases__: Tuple[str, Any] + __all_relationships__: Tuple[str, Any] + __label__: str + __optional_labels__: list[str] + def __new__(mcs, name, bases, namespace): namespace["DoesNotExist"] = type(name + "DoesNotExist", (DoesNotExist,), {}) cls = super().__new__(mcs, name, bases, namespace) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index d14d1bff..29f039ec 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -1,10 +1,8 @@ import inspect import re import string -import warnings from dataclasses import dataclass from typing import Any, Dict, List -from typing import Optional from typing import Optional as TOptional from typing import Tuple, Union @@ -381,7 +379,7 @@ class QueryAST: limit: TOptional[int] result_class: TOptional[type] lookup: TOptional[str] - additional_return: List[str] + additional_return: TOptional[List[str]] is_count: TOptional[bool] def __init__( @@ -409,7 +407,9 @@ def __init__( self.limit = limit self.result_class = result_class self.lookup = lookup - self.additional_return = additional_return if additional_return else [] + self.additional_return: List[str] = ( + additional_return if additional_return else [] + ) self.is_count = is_count self.subgraph: Dict = {} @@ -528,7 +528,11 @@ async def build_traversal(self, traversal) -> str: return traversal_ident def _additional_return(self, name: str): - if name not in self._ast.additional_return and name != self._ast.return_clause: + if ( + not self._ast.additional_return or name not in self._ast.additional_return + ) and name != self._ast.return_clause: + if not self._ast.additional_return: + self._ast.additional_return = [] self._ast.additional_return.append(name) def build_traversal_from_path( @@ -953,8 +957,10 @@ async def _count(self): async def _contains(self, node_element_id): # inject id = into ast - if not self._ast.return_clause: + if not self._ast.return_clause and self._ast.additional_return: 
self._ast.return_clause = self._ast.additional_return[0] + if not self._ast.return_clause: + raise ValueError("Cannot use contains without a return clause") ident = self._ast.return_clause place_holder = self._register_place_holder(ident + "_contains") self._ast.where.append( @@ -1006,7 +1012,7 @@ class AsyncBaseSet: """ query_cls = AsyncQueryBuilder - source_class: AsyncStructuredNode + source_class: type[AsyncStructuredNode] async def all(self, lazy=False): """ @@ -1541,7 +1547,10 @@ async def subquery( for var in return_set: if ( var != qbuilder._ast.return_clause - and var not in qbuilder._ast.additional_return + and ( + not qbuilder._ast.additional_return + or var not in qbuilder._ast.additional_return + ) and var not in [res["alias"] for res in nodeset._extra_results if res["alias"]] ): diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index c77ef81b..880de6ca 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -9,7 +9,7 @@ from functools import wraps from itertools import combinations from threading import local -from typing import Any, Optional, Sequence, Type +from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union from urllib.parse import quote, unquote, urlparse from neo4j import ( @@ -95,16 +95,15 @@ class Database(local): def __init__(self): self._active_transaction: Optional[Transaction] = None - self.url = None - self.driver = None + self.url: Optional[str] = None + self.driver: Optional[Driver] = None self._session: Optional[Session] = None - self._pid = None - self._database_name = DEFAULT_DATABASE - self.protocol_version = None - self._database_version = None - self._database_edition = None - self.impersonated_user = None - self._parallel_runtime = False + self._pid: Optional[int] = None + self._database_name: Optional[str] = DEFAULT_DATABASE + self._database_version: Optional[str] = None + self._database_edition: Optional[str] = None + self.impersonated_user: Optional[str] = None + self._parallel_runtime: 
Optional[bool] = False def set_connection( self, url: Optional[str] = None, driver: Optional[Driver] = None @@ -490,23 +489,27 @@ def cypher_query( ) else: # Otherwise create a new session in a with to dispose of it after it has been run - with self.driver.session( - database=self._database_name, impersonated_user=self.impersonated_user - ) as session: - results, meta = self._run_cypher_query( - session, - query, - params, - handle_unique, - retry_on_session_expire, - resolve_objects, - ) + if self.driver: + with self.driver.session( + database=self._database_name, + impersonated_user=self.impersonated_user, + ) as session: + results, meta = self._run_cypher_query( + session, + query, + params, + handle_unique, + retry_on_session_expire, + resolve_objects, + ) + else: + raise ValueError("No driver has been set") return results, meta def _run_cypher_query( self, - session: Session, + session: Union[Session, Transaction], query, params, handle_unique, @@ -1205,11 +1208,14 @@ class TransactionProxy: bookmarks: Optional[Bookmarks] = None def __init__( - self, db: Database, access_mode: str = None, parallel_runtime: bool = False + self, + db: Database, + access_mode: Optional[str] = None, + parallel_runtime: Optional[bool] = False, ): - self.db = db - self.access_mode = access_mode - self.parallel_runtime = parallel_runtime + self.db: Database = db + self.access_mode: Optional[str] = access_mode + self.parallel_runtime: Optional[bool] = parallel_runtime @ensure_connection def __enter__(self): @@ -1297,6 +1303,14 @@ def wrapper(*args, **kwargs): class NodeMeta(type): + DoesNotExist: Type[DoesNotExist] + __required_properties__: Tuple[str, ...] 
+ __all_properties__: Tuple[str, Any] + __all_aliases__: Tuple[str, Any] + __all_relationships__: Tuple[str, Any] + __label__: str + __optional_labels__: list[str] + def __new__(mcs, name, bases, namespace): namespace["DoesNotExist"] = type(name + "DoesNotExist", (DoesNotExist,), {}) cls = super().__new__(mcs, name, bases, namespace) diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index cdec9d31..102a5c59 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -1,10 +1,8 @@ import inspect import re import string -import warnings from dataclasses import dataclass from typing import Any, Dict, List -from typing import Optional from typing import Optional as TOptional from typing import Tuple, Union @@ -381,7 +379,7 @@ class QueryAST: limit: TOptional[int] result_class: TOptional[type] lookup: TOptional[str] - additional_return: List[str] + additional_return: TOptional[List[str]] is_count: TOptional[bool] def __init__( @@ -409,7 +407,9 @@ def __init__( self.limit = limit self.result_class = result_class self.lookup = lookup - self.additional_return = additional_return if additional_return else [] + self.additional_return: List[str] = ( + additional_return if additional_return else [] + ) self.is_count = is_count self.subgraph: Dict = {} @@ -528,7 +528,11 @@ def build_traversal(self, traversal) -> str: return traversal_ident def _additional_return(self, name: str): - if name not in self._ast.additional_return and name != self._ast.return_clause: + if ( + not self._ast.additional_return or name not in self._ast.additional_return + ) and name != self._ast.return_clause: + if not self._ast.additional_return: + self._ast.additional_return = [] self._ast.additional_return.append(name) def build_traversal_from_path( @@ -953,8 +957,10 @@ def _count(self): def _contains(self, node_element_id): # inject id = into ast - if not self._ast.return_clause: + if not self._ast.return_clause and self._ast.additional_return: self._ast.return_clause = 
self._ast.additional_return[0] + if not self._ast.return_clause: + raise ValueError("Cannot use contains without a return clause") ident = self._ast.return_clause place_holder = self._register_place_holder(ident + "_contains") self._ast.where.append(f"{db.get_id_method()}({ident}) = ${place_holder}") @@ -1004,7 +1010,7 @@ class BaseSet: """ query_cls = QueryBuilder - source_class: StructuredNode + source_class: type[StructuredNode] def all(self, lazy=False): """ @@ -1537,7 +1543,10 @@ def subquery(self, nodeset: "NodeSet", return_set: List[str]) -> "NodeSet": for var in return_set: if ( var != qbuilder._ast.return_clause - and var not in qbuilder._ast.additional_return + and ( + not qbuilder._ast.additional_return + or var not in qbuilder._ast.additional_return + ) and var not in [res["alias"] for res in nodeset._extra_results if res["alias"]] ): From a7bc24837ef2beacaa846b880c8426cc51e4f8bc Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Thu, 5 Dec 2024 13:47:03 +0100 Subject: [PATCH 04/20] More mypy fixes --- neomodel/async_/core.py | 16 +++--- neomodel/async_/match.py | 54 +++++++++---------- neomodel/async_/path.py | 8 +-- neomodel/async_/property_manager.py | 5 +- neomodel/scripts/neomodel_generate_diagram.py | 30 ++++++----- neomodel/scripts/neomodel_inspect_database.py | 7 +-- neomodel/sync_/core.py | 16 +++--- neomodel/sync_/match.py | 54 +++++++++---------- neomodel/sync_/path.py | 8 +-- neomodel/sync_/property_manager.py | 5 +- pyproject.toml | 11 +++- requirements-dev.txt | 3 +- 12 files changed, 119 insertions(+), 98 deletions(-) diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py index 6fd7fe64..6c2def38 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/core.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import logging import os import sys @@ -9,7 +7,7 @@ from functools import wraps from itertools import combinations from threading import local -from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union 
+from typing import Any, Callable, Optional, Sequence, Type, Union from urllib.parse import quote, unquote, urlparse from neo4j import ( @@ -1311,13 +1309,15 @@ def wrapper(*args, **kwargs): class NodeMeta(type): DoesNotExist: Type[DoesNotExist] - __required_properties__: Tuple[str, ...] - __all_properties__: Tuple[str, Any] - __all_aliases__: Tuple[str, Any] - __all_relationships__: Tuple[str, Any] + __required_properties__: tuple[str, ...] + __all_properties__: tuple[tuple[str, Any], ...] + __all_aliases__: tuple[tuple[str, Any], ...] + __all_relationships__: tuple[tuple[str, Any], ...] __label__: str __optional_labels__: list[str] + defined_properties: Callable[..., dict[str, Any]] + def __new__(mcs, name, bases, namespace): namespace["DoesNotExist"] = type(name + "DoesNotExist", (DoesNotExist,), {}) cls = super().__new__(mcs, name, bases, namespace) @@ -1437,7 +1437,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def __eq__(self, other: AsyncStructuredNode | Any) -> bool: + def __eq__(self, other: Any) -> bool: """ Compare two node objects. If both nodes were saved to the database, compare them by their element_id. 
diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 29f039ec..884fe65c 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -2,7 +2,7 @@ import re import string from dataclasses import dataclass -from typing import Any, Dict, List +from typing import Any from typing import Optional as TOptional from typing import Tuple, Union @@ -321,7 +321,7 @@ def _process_filter_key(cls, key: str) -> Tuple[Property, str, str]: return property_obj, operator, prop -def process_filter_args(cls, kwargs) -> Dict: +def process_filter_args(cls, kwargs) -> dict: """ loop through properties in filter parameters check they match class definition deflate them and convert into something easy to generate cypher from @@ -369,32 +369,32 @@ def process_has_args(cls, kwargs): class QueryAST: - match: List[str] - optional_match: List[str] - where: List[str] + match: list[str] + optional_match: list[str] + where: list[str] with_clause: TOptional[str] return_clause: TOptional[str] - order_by: TOptional[List[str]] + order_by: TOptional[list[str]] skip: TOptional[int] limit: TOptional[int] result_class: TOptional[type] lookup: TOptional[str] - additional_return: TOptional[List[str]] + additional_return: TOptional[list[str]] is_count: TOptional[bool] def __init__( self, - match: TOptional[List[str]] = None, - optional_match: TOptional[List[str]] = None, - where: TOptional[List[str]] = None, + match: TOptional[list[str]] = None, + optional_match: TOptional[list[str]] = None, + where: TOptional[list[str]] = None, with_clause: TOptional[str] = None, return_clause: TOptional[str] = None, - order_by: TOptional[List[str]] = None, + order_by: TOptional[list[str]] = None, skip: TOptional[int] = None, limit: TOptional[int] = None, result_class: TOptional[type] = None, lookup: TOptional[str] = None, - additional_return: TOptional[List[str]] = None, + additional_return: TOptional[list[str]] = None, is_count: TOptional[bool] = False, ) -> None: self.match = match if match 
else [] @@ -407,19 +407,19 @@ def __init__( self.limit = limit self.result_class = result_class self.lookup = lookup - self.additional_return: List[str] = ( + self.additional_return: list[str] = ( additional_return if additional_return else [] ) self.is_count = is_count - self.subgraph: Dict = {} + self.subgraph: dict = {} class AsyncQueryBuilder: def __init__(self, node_set, subquery_context: bool = False) -> None: self.node_set = node_set self._ast = QueryAST() - self._query_params: Dict = {} - self._place_holder_registry: Dict = {} + self._query_params: dict = {} + self._place_holder_registry: dict = {} self._ident_count: int = 0 self._subquery_context: bool = subquery_context @@ -720,7 +720,7 @@ def _finalize_filter_statement( return statement def _build_filter_statements( - self, ident: str, filters, target: List[str], source_class + self, ident: str, filters, target: list[str], source_class ) -> None: for prop, op_and_val in filters.items(): path = None @@ -1213,7 +1213,7 @@ def __post_init__(self): "RawCypher: Do not include any action that has side effect" ) - def render(self, context: Dict) -> str: + def render(self, context: dict) -> str: return string.Template(self.statement).substitute(context) @@ -1236,16 +1236,16 @@ def __init__(self, source) -> None: # setup Traversal objects using relationship definitions install_traversals(self.source_class, self) - self.filters: List = [] + self.filters: list = [] self.q_filters = Q() - self.order_by_elements: List = [] + self.order_by_elements: list = [] # used by has() - self.must_match: Dict = {} - self.dont_match: Dict = {} + self.must_match: dict = {} + self.dont_match: dict = {} - self.relations_to_fetch: List = [] - self._extra_results: List = [] + self.relations_to_fetch: list = [] + self._extra_results: list = [] self._subqueries: list[Tuple[str, list[str]]] = [] self._intermediate_transforms: list = [] @@ -1534,7 +1534,7 @@ async def resolve_subgraph(self) -> list: return results async def subquery( - 
self, nodeset: "AsyncNodeSet", return_set: List[str] + self, nodeset: "AsyncNodeSet", return_set: list[str] ) -> "AsyncNodeSet": """Add a subquery to this node set. @@ -1560,7 +1560,7 @@ async def subquery( def intermediate_transform( self, - vars: Dict[str, Transformation], + vars: dict[str, Transformation], distinct: bool = False, ordering: TOptional[list] = None, ) -> "AsyncNodeSet": @@ -1637,7 +1637,7 @@ def __init__(self, source: Any, name: str, definition: dict) -> None: self.definition = definition self.target_class = definition["node_class"] self.name = name - self.filters: List = [] + self.filters: list = [] def match(self, **kwargs): """ diff --git a/neomodel/async_/path.py b/neomodel/async_/path.py index e04cf202..e1d5eb3e 100644 --- a/neomodel/async_/path.py +++ b/neomodel/async_/path.py @@ -1,4 +1,4 @@ -import typing as t +from collections.abc import Iterator from neo4j.graph import Path @@ -24,8 +24,8 @@ class AsyncNeomodelPath(object): :param nodes: Neomodel nodes appearing in the path in order of appearance. :param relationships: Neomodel relationships appearing in the path in order of appearance. 
- :type nodes: List[StructuredNode] - :type relationships: List[StructuredRel] + :type nodes: list[StructuredNode] + :type relationships: list[StructuredRel] """ def __init__(self, a_neopath: Path): @@ -56,7 +56,7 @@ def __repr__(self) -> str: def __len__(self) -> int: return len(self._relationships) - def __iter__(self) -> t.Iterator[AsyncStructuredRel]: + def __iter__(self) -> Iterator[AsyncStructuredRel]: return iter(self._relationships) @property diff --git a/neomodel/async_/property_manager.py b/neomodel/async_/property_manager.py index 6eb18b3f..c17bd864 100644 --- a/neomodel/async_/property_manager.py +++ b/neomodel/async_/property_manager.py @@ -1,4 +1,5 @@ import types +from typing import Any from neomodel.exceptions import RequiredProperty from neomodel.properties import AliasProperty, Property @@ -117,7 +118,9 @@ def inflate(cls, graph_entity): return cls(**inflated) @classmethod - def defined_properties(cls, aliases=True, properties=True, rels=True): + def defined_properties( + cls, aliases=True, properties=True, rels=True + ) -> dict[str, Any]: from neomodel.async_.relationship_manager import AsyncRelationshipDefinition props = {} diff --git a/neomodel/scripts/neomodel_generate_diagram.py b/neomodel/scripts/neomodel_generate_diagram.py index 5a81221b..c40f9967 100644 --- a/neomodel/scripts/neomodel_generate_diagram.py +++ b/neomodel/scripts/neomodel_generate_diagram.py @@ -70,7 +70,7 @@ def generate_plantuml(classes): f"{prop}: {parse_property_key(cls.defined_properties(aliases=False, rels=False)[prop])}" for prop in cls.defined_properties(aliases=False, rels=False) ] - label += " \l ".join(properties) + label += r" \l ".join(properties) label += "}}" # Node definition @@ -196,18 +196,22 @@ def generate_arrows_json(classes): "type": rel.definition["relation_type"], "style": {}, "properties": {}, - "fromId": node_id - if ( - isinstance(rel, RelationshipTo) - or isinstance(rel, AsyncRelationshipTo) - ) - else target_id, - "toId": target_id - if ( - 
isinstance(rel, RelationshipTo) - or isinstance(rel, AsyncRelationshipTo) - ) - else node_id, + "fromId": ( + node_id + if ( + isinstance(rel, RelationshipTo) + or isinstance(rel, AsyncRelationshipTo) + ) + else target_id + ), + "toId": ( + target_id + if ( + isinstance(rel, RelationshipTo) + or isinstance(rel, AsyncRelationshipTo) + ) + else node_id + ), } ) diff --git a/neomodel/scripts/neomodel_inspect_database.py b/neomodel/scripts/neomodel_inspect_database.py index 37ea2918..cb3a7884 100644 --- a/neomodel/scripts/neomodel_inspect_database.py +++ b/neomodel/scripts/neomodel_inspect_database.py @@ -34,6 +34,7 @@ import string import textwrap from os import environ +from typing import Any from neomodel.sync_.core import db @@ -293,9 +294,9 @@ def inspect_database( print(f"Connecting to {bolt_url}") db.set_connection(bolt_url) - node_labels = get_node_labels() - defined_rel_types = [] - class_definitions = "" + node_labels: list[Any] = get_node_labels() + defined_rel_types: list[str] = [] + class_definitions: str = "" if node_labels: IMPORTS.append("StructuredNode") diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index 880de6ca..e0bd4029 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import logging import os import sys @@ -9,7 +7,7 @@ from functools import wraps from itertools import combinations from threading import local -from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Optional, Sequence, Type, Union from urllib.parse import quote, unquote, urlparse from neo4j import ( @@ -1304,13 +1302,15 @@ def wrapper(*args, **kwargs): class NodeMeta(type): DoesNotExist: Type[DoesNotExist] - __required_properties__: Tuple[str, ...] - __all_properties__: Tuple[str, Any] - __all_aliases__: Tuple[str, Any] - __all_relationships__: Tuple[str, Any] + __required_properties__: tuple[str, ...] + __all_properties__: tuple[tuple[str, Any], ...] 
+ __all_aliases__: tuple[tuple[str, Any], ...] + __all_relationships__: tuple[tuple[str, Any], ...] __label__: str __optional_labels__: list[str] + defined_properties: Callable[..., dict[str, Any]] + def __new__(mcs, name, bases, namespace): namespace["DoesNotExist"] = type(name + "DoesNotExist", (DoesNotExist,), {}) cls = super().__new__(mcs, name, bases, namespace) @@ -1428,7 +1428,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def __eq__(self, other: StructuredNode | Any) -> bool: + def __eq__(self, other: Any) -> bool: """ Compare two node objects. If both nodes were saved to the database, compare them by their element_id. diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 102a5c59..0ae98b5f 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -2,7 +2,7 @@ import re import string from dataclasses import dataclass -from typing import Any, Dict, List +from typing import Any from typing import Optional as TOptional from typing import Tuple, Union @@ -321,7 +321,7 @@ def _process_filter_key(cls, key: str) -> Tuple[Property, str, str]: return property_obj, operator, prop -def process_filter_args(cls, kwargs) -> Dict: +def process_filter_args(cls, kwargs) -> dict: """ loop through properties in filter parameters check they match class definition deflate them and convert into something easy to generate cypher from @@ -369,32 +369,32 @@ def process_has_args(cls, kwargs): class QueryAST: - match: List[str] - optional_match: List[str] - where: List[str] + match: list[str] + optional_match: list[str] + where: list[str] with_clause: TOptional[str] return_clause: TOptional[str] - order_by: TOptional[List[str]] + order_by: TOptional[list[str]] skip: TOptional[int] limit: TOptional[int] result_class: TOptional[type] lookup: TOptional[str] - additional_return: TOptional[List[str]] + additional_return: TOptional[list[str]] is_count: TOptional[bool] def __init__( self, - match: TOptional[List[str]] = None, - 
optional_match: TOptional[List[str]] = None, - where: TOptional[List[str]] = None, + match: TOptional[list[str]] = None, + optional_match: TOptional[list[str]] = None, + where: TOptional[list[str]] = None, with_clause: TOptional[str] = None, return_clause: TOptional[str] = None, - order_by: TOptional[List[str]] = None, + order_by: TOptional[list[str]] = None, skip: TOptional[int] = None, limit: TOptional[int] = None, result_class: TOptional[type] = None, lookup: TOptional[str] = None, - additional_return: TOptional[List[str]] = None, + additional_return: TOptional[list[str]] = None, is_count: TOptional[bool] = False, ) -> None: self.match = match if match else [] @@ -407,19 +407,19 @@ def __init__( self.limit = limit self.result_class = result_class self.lookup = lookup - self.additional_return: List[str] = ( + self.additional_return: list[str] = ( additional_return if additional_return else [] ) self.is_count = is_count - self.subgraph: Dict = {} + self.subgraph: dict = {} class QueryBuilder: def __init__(self, node_set, subquery_context: bool = False) -> None: self.node_set = node_set self._ast = QueryAST() - self._query_params: Dict = {} - self._place_holder_registry: Dict = {} + self._query_params: dict = {} + self._place_holder_registry: dict = {} self._ident_count: int = 0 self._subquery_context: bool = subquery_context @@ -720,7 +720,7 @@ def _finalize_filter_statement( return statement def _build_filter_statements( - self, ident: str, filters, target: List[str], source_class + self, ident: str, filters, target: list[str], source_class ) -> None: for prop, op_and_val in filters.items(): path = None @@ -1211,7 +1211,7 @@ def __post_init__(self): "RawCypher: Do not include any action that has side effect" ) - def render(self, context: Dict) -> str: + def render(self, context: dict) -> str: return string.Template(self.statement).substitute(context) @@ -1234,16 +1234,16 @@ def __init__(self, source) -> None: # setup Traversal objects using relationship 
definitions install_traversals(self.source_class, self) - self.filters: List = [] + self.filters: list = [] self.q_filters = Q() - self.order_by_elements: List = [] + self.order_by_elements: list = [] # used by has() - self.must_match: Dict = {} - self.dont_match: Dict = {} + self.must_match: dict = {} + self.dont_match: dict = {} - self.relations_to_fetch: List = [] - self._extra_results: List = [] + self.relations_to_fetch: list = [] + self._extra_results: list = [] self._subqueries: list[Tuple[str, list[str]]] = [] self._intermediate_transforms: list = [] @@ -1531,7 +1531,7 @@ def resolve_subgraph(self) -> list: ) return results - def subquery(self, nodeset: "NodeSet", return_set: List[str]) -> "NodeSet": + def subquery(self, nodeset: "NodeSet", return_set: list[str]) -> "NodeSet": """Add a subquery to this node set. A subquery is a regular cypher query but executed within the context of a CALL @@ -1556,7 +1556,7 @@ def subquery(self, nodeset: "NodeSet", return_set: List[str]) -> "NodeSet": def intermediate_transform( self, - vars: Dict[str, Transformation], + vars: dict[str, Transformation], distinct: bool = False, ordering: TOptional[list] = None, ) -> "NodeSet": @@ -1633,7 +1633,7 @@ def __init__(self, source: Any, name: str, definition: dict) -> None: self.definition = definition self.target_class = definition["node_class"] self.name = name - self.filters: List = [] + self.filters: list = [] def match(self, **kwargs): """ diff --git a/neomodel/sync_/path.py b/neomodel/sync_/path.py index 90058dd2..b5c45931 100644 --- a/neomodel/sync_/path.py +++ b/neomodel/sync_/path.py @@ -1,4 +1,4 @@ -import typing as t +from collections.abc import Iterator from neo4j.graph import Path @@ -24,8 +24,8 @@ class NeomodelPath(object): :param nodes: Neomodel nodes appearing in the path in order of appearance. :param relationships: Neomodel relationships appearing in the path in order of appearance. 
- :type nodes: List[StructuredNode] - :type relationships: List[StructuredRel] + :type nodes: list[StructuredNode] + :type relationships: list[StructuredRel] """ def __init__(self, a_neopath: Path): @@ -56,7 +56,7 @@ def __repr__(self) -> str: def __len__(self) -> int: return len(self._relationships) - def __iter__(self) -> t.Iterator[StructuredRel]: + def __iter__(self) -> Iterator[StructuredRel]: return iter(self._relationships) @property diff --git a/neomodel/sync_/property_manager.py b/neomodel/sync_/property_manager.py index 2b4eeaaa..08f1e900 100644 --- a/neomodel/sync_/property_manager.py +++ b/neomodel/sync_/property_manager.py @@ -1,4 +1,5 @@ import types +from typing import Any from neomodel.exceptions import RequiredProperty from neomodel.properties import AliasProperty, Property @@ -117,7 +118,9 @@ def inflate(cls, graph_entity): return cls(**inflated) @classmethod - def defined_properties(cls, aliases=True, properties=True, rels=True): + def defined_properties( + cls, aliases=True, properties=True, rels=True + ) -> dict[str, Any]: from neomodel.sync_.relationship_manager import RelationshipDefinition props = {} diff --git a/pyproject.toml b/pyproject.toml index 99335e40..1038b969 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,11 @@ dev = [ "pre-commit", "black", "isort", - "Shapely>=2.0.0" + "Shapely>=2.0.0", + "mypy", + "pandas-stubs", + "types-pytz", + "types-shapely" ] pandas = ["pandas"] numpy = ["numpy"] @@ -76,6 +80,11 @@ good-names = 'i,j,k,ex,_,e,fn,x,y,z,id,db,q' max-attributes = 10 max-args = 8 +[tool.mypy] +[[tool.mypy.overrides]] +module = ["neomodel.scripts.*", "neomodel.contrib.spatial_properties"] +ignore_errors = true + [project.scripts] neomodel_install_labels = "neomodel.scripts.neomodel_install_labels:main" neomodel_remove_labels = "neomodel.scripts.neomodel_remove_labels:main" diff --git a/requirements-dev.txt b/requirements-dev.txt index 27a5b002..6b5ec05e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt 
@@ -12,4 +12,5 @@ isort Shapely>=2.0.0 mypy>=1.11 pandas-stubs -types-pytz \ No newline at end of file +types-pytz +types-shapely \ No newline at end of file From bbae94dada778b58345157b0aae198c798e35dfc Mon Sep 17 00:00:00 2001 From: MariusC Date: Thu, 5 Dec 2024 14:15:29 +0100 Subject: [PATCH 05/20] Update pyproject.toml Add tonioo in the maintainers --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 99335e40..35dc288a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,7 @@ authors = [ ] maintainers = [ {name = "Marius Conjeaud", email = "marius.conjeaud@outlook.com"}, + {name = "Antoine Nguyen", email = "tonio@ngyn.org"}, {name = "Athanasios Anastasiou", email = "athanastasiou@gmail.com"}, ] description = "An object mapper for the neo4j graph database." From d88bf35847e7595baf8dddf6a48a9f0b3ed686cb Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Thu, 5 Dec 2024 15:42:31 +0100 Subject: [PATCH 06/20] Add rust driver extension and prepare 5.4.2 rc --- .github/workflows/integration-tests.yml | 6 +++++- Changelog | 3 +++ README.md | 8 ++++++-- doc/source/configuration.rst | 2 +- pyproject.toml | 14 ++++++++++---- requirements-dev.txt | 5 ++--- requirements.txt | 2 +- 7 files changed, 28 insertions(+), 12 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 7e6f23d0..867a7311 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -31,7 +31,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -e '.[dev,pandas,numpy]' + pip install -e '.[dev,extras]' - name: Test with pytest env: AURA_TEST_DB_USER: ${{ secrets.AURA_TEST_DB_USER }} @@ -39,6 +39,10 @@ jobs: AURA_TEST_DB_HOSTNAME: ${{ secrets.AURA_TEST_DB_HOSTNAME }} run: | pytest --cov=neomodel --cov-report=html:coverage_report + - name: Install neo4j-rust-ext and verify it is installed + run: | + pip 
install -e '.[rust-driver-ext]' + pip list | grep neo4j-rust-ext || exit 1 - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v3 with: diff --git a/Changelog b/Changelog index 91170523..8bb5c08b 100644 --- a/Changelog +++ b/Changelog @@ -1,3 +1,6 @@ +Vesion 5.4.2 2024-12 +* Add support for Neo4j Rust driver extension : pip install neomodel['rust-driver-ext'] + Version 5.4.1 2024-11 * Add support for Cypher parallel runtime * Add options for intermediate_transform : distinct, include_in_return, use a prop as source diff --git a/README.md b/README.md index 101e2cba..cf323c73 100644 --- a/README.md +++ b/README.md @@ -65,9 +65,13 @@ Install from pypi (recommended): $ pip install neomodel ($ source dev # To install all things needed in a Python3 venv) - # Neomodel has some optional dependencies (including Shapely), to install these use: + # neomodel can use the Rust extension to the Neo4j driver for faster transport, to install use: - $ pip install neomodel['extras'] + $ pip install neomodel['rust-driver-ext'] + + # neomodel has some optional dependencies (Shapely, pandas, numpy), to install these use: + + $ pip install neomodel['extras, rust-driver-ext'] To install from github: diff --git a/doc/source/configuration.rst b/doc/source/configuration.rst index 7c178c29..69f0f2db 100644 --- a/doc/source/configuration.rst +++ b/doc/source/configuration.rst @@ -32,7 +32,7 @@ Adjust driver configuration - these options are only available for this connecti config.MAX_TRANSACTION_RETRY_TIME = 30.0 # default config.RESOLVER = None # default config.TRUST = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES # default - config.USER_AGENT = neomodel/v5.4.1 # default + config.USER_AGENT = neomodel/v5.4.2 # default Setting the database name, if different from the default one:: diff --git a/pyproject.toml b/pyproject.toml index 99335e40..a527036b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,9 +22,9 @@ classifiers = [ "Topic :: Database", ] dependencies = [ - 
"neo4j~=5.26.0", + "neo4j~=5.27.0", ] -requires-python = ">=3.8" +requires-python = ">=3.9" dynamic = ["version"] [project.urls] @@ -33,6 +33,11 @@ repository = "http://github.com/neo4j-contrib/neomodel" changelog = "https://github.com/neo4j-contrib/neomodel/releases" [project.optional-dependencies] +extras = [ + "shapely", + "pandas", + "numpy" +] dev = [ "unasync", "pytest>=7.1", @@ -41,11 +46,12 @@ dev = [ "pytest-mock", "pre-commit", "black", - "isort", - "Shapely>=2.0.0" + "isort" ] +shapely = ["Shapely>=2.0.0"] pandas = ["pandas"] numpy = ["numpy"] +rust-driver-ext = ["neo4j-rust-ext==5.27.0.0"] [build-system] requires = ["setuptools>=68"] diff --git a/requirements-dev.txt b/requirements-dev.txt index ad82ba50..e5a19c87 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,5 @@ # neomodel --e .[pandas,numpy] +-e .[extras] unasync>=0.5.0 pytest>=7.1 @@ -8,5 +8,4 @@ pytest-cov>=4.0 pytest-mock pre-commit black -isort -Shapely>=2.0.0 \ No newline at end of file +isort \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index e7a3f522..e068f61c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ -neo4j~=5.26.0 +neo4j~=5.27.0 From ce7f71848cbbb75837361b6bef539e749a67cd2a Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Thu, 5 Dec 2024 15:50:10 +0100 Subject: [PATCH 07/20] Simply pip install in README --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index cf323c73..c7dc01d0 100644 --- a/README.md +++ b/README.md @@ -67,11 +67,11 @@ Install from pypi (recommended): # neomodel can use the Rust extension to the Neo4j driver for faster transport, to install use: - $ pip install neomodel['rust-driver-ext'] + $ pip install neomodel[rust-driver-ext] # neomodel has some optional dependencies (Shapely, pandas, numpy), to install these use: - $ pip install neomodel['extras, rust-driver-ext'] + $ pip install neomodel[extras, rust-driver-ext] To install from 
github: From 4941d1d357c29efc4ed5f0255dc6aa5aa86c984c Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Thu, 5 Dec 2024 15:50:44 +0100 Subject: [PATCH 08/20] Update changelog --- Changelog | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Changelog b/Changelog index 8bb5c08b..4f78f667 100644 --- a/Changelog +++ b/Changelog @@ -1,5 +1,5 @@ Vesion 5.4.2 2024-12 -* Add support for Neo4j Rust driver extension : pip install neomodel['rust-driver-ext'] +* Add support for Neo4j Rust driver extension : pip install neomodel[rust-driver-ext] Version 5.4.1 2024-11 * Add support for Cypher parallel runtime From 820556f633bf6dadfa9bd5a554ffe51eadd3d4d3 Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Tue, 10 Dec 2024 15:33:55 +0100 Subject: [PATCH 09/20] Various improvements about subqueries. --- doc/source/advanced_query_operations.rst | 14 +++-- neomodel/async_/match.py | 75 +++++++++++++++++++----- neomodel/sync_/match.py | 75 +++++++++++++++++++----- neomodel/typing.py | 13 +++- test/async_/test_match_api.py | 29 +++++++++ test/sync_/test_match_api.py | 29 +++++++++ 6 files changed, 198 insertions(+), 37 deletions(-) diff --git a/doc/source/advanced_query_operations.rst b/doc/source/advanced_query_operations.rst index 73c5bbd6..de1c8c61 100644 --- a/doc/source/advanced_query_operations.rst +++ b/doc/source/advanced_query_operations.rst @@ -60,7 +60,7 @@ As discussed in the note above, this is for example useful when you need to orde Options for `intermediate_transform` *variables* are: -- `source`: `string`or `Resolver` - the variable to use as source for the transformation. Works with resolvers (see below). +- `source`: `string` or `Resolver` - the variable to use as source for the transformation. Works with resolvers (see below). - `source_prop`: `string` - optionally, a property of the source variable to use as source for the transformation. - `include_in_return`: `bool` - whether to include the variable in the return statement. Defaults to False. 
@@ -95,7 +95,7 @@ Subqueries The `subquery` method allows you to perform a `Cypher subquery `_ inside your query. This allows you to perform operations in isolation to the rest of your query:: from neomodel.sync_match import Collect, Last - + # This will create a CALL{} subquery # And return a variable named supps usable in the rest of your query Coffee.nodes.filter(name="Espresso") @@ -106,12 +106,18 @@ The `subquery` method allows you to perform a `Cypher subquery None: + def __init__(self, node_set, subquery_namespace: TOptional[str] = None) -> None: self.node_set = node_set self._ast = QueryAST() self._query_params: Dict = {} self._place_holder_registry: Dict = {} self._ident_count: int = 0 - self._subquery_context: bool = subquery_context + self._subquery_namespace: TOptional[str] = subquery_namespace async def build_ast(self) -> "AsyncQueryBuilder": if hasattr(self.node_set, "relations_to_fetch"): @@ -558,7 +557,7 @@ def build_traversal_from_path( # contains the primary node so _contains() works # as usual self._ast.return_clause = lhs_name - if self._subquery_context: + if self._subquery_namespace: # Don't include label in identifier if we are in a subquery lhs_ident = lhs_name elif relation["include_in_return"]: @@ -672,7 +671,10 @@ def _register_place_holder(self, key: str) -> str: self._place_holder_registry[key] += 1 else: self._place_holder_registry[key] = 1 - return key + "_" + str(self._place_holder_registry[key]) + place_holder = f"{key}_{self._place_holder_registry[key]}" + if self._subquery_namespace: + place_holder = f"{self._subquery_namespace}_{place_holder}" + return place_holder def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str, Any]: is_rel_filter = "|" in prop @@ -879,10 +881,21 @@ def build_query(self) -> str: query += ",".join(ordering) if hasattr(self.node_set, "_subqueries"): - for subquery, return_set in self.node_set._subqueries: - outer_primary_var = self._ast.return_clause - query += f" CALL {{ WITH 
{outer_primary_var} {subquery} }} " - for varname in return_set: + for subquery in self.node_set._subqueries: + query += " CALL {" + if subquery["initial_context"]: + query += " WITH " + context: List[str] = [] + for var in subquery["initial_context"]: + if isinstance(var, (NodeNameResolver, RelationNameResolver)): + context.append(var.resolve(self)) + else: + context.append(var) + query += ",".join(context) + + query += f"{subquery['query']} }} " + self._query_params.update(subquery["query_params"]) + for varname in subquery["return_set"]: # We declare the returned variables as "virtual" relations of the # root node class to make sure they will be translated by a call to # resolve_subgraph() (otherwise, they will be lost). @@ -893,10 +906,10 @@ def build_query(self) -> str: "variable_name": varname, "rel_variable_name": varname, } - returned_items += return_set + returned_items += subquery["return_set"] query += " RETURN " - if self._ast.return_clause and not self._subquery_context: + if self._ast.return_clause and not self._subquery_namespace: returned_items.append(self._ast.return_clause) if self._ast.additional_return: returned_items += self._ast.additional_return @@ -1120,6 +1133,8 @@ class NodeNameResolver: node: str def resolve(self, qbuilder: AsyncQueryBuilder) -> str: + if self.node == "self" and qbuilder._ast.return_clause: + return qbuilder._ast.return_clause result = qbuilder.lookup_query_variable(self.node) if result is None: raise ValueError(f"Unable to resolve variable name for node {self.node}") @@ -1238,7 +1253,7 @@ def __init__(self, source) -> None: self.relations_to_fetch: List = [] self._extra_results: List = [] - self._subqueries: list[Tuple[str, list[str]]] = [] + self._subqueries: list[Subquery] = [] self._intermediate_transforms: list = [] def __await__(self): @@ -1525,7 +1540,10 @@ async def resolve_subgraph(self) -> list: return results async def subquery( - self, nodeset: "AsyncNodeSet", return_set: List[str] + self, + nodeset: 
"AsyncNodeSet", + return_set: List[str], + initial_context: TOptional[List[str]] = None, ) -> "AsyncNodeSet": """Add a subquery to this node set. @@ -1534,16 +1552,41 @@ async def subquery( declared inside return_set variable in order to be included in the final RETURN statement. """ - qbuilder = await nodeset.query_cls(nodeset, subquery_context=True).build_ast() + namespace = f"sq{len(self._subqueries) + 1}" + qbuilder = await nodeset.query_cls( + nodeset, subquery_namespace=namespace + ).build_ast() for var in return_set: if ( var != qbuilder._ast.return_clause and var not in qbuilder._ast.additional_return and var not in [res["alias"] for res in nodeset._extra_results if res["alias"]] + and var + not in [ + varname + for tr in nodeset._intermediate_transforms + for varname, vardef in tr["vars"].items() + if vardef.get("include_in_return") + ] ): raise RuntimeError(f"Variable '{var}' is not returned by subquery.") - self._subqueries.append((qbuilder.build_query(), return_set)) + if initial_context: + for var in initial_context: + if type(var) is not str and not isinstance( + var, (NodeNameResolver, RelationNameResolver, RawCypher) + ): + raise ValueError( + f"Wrong variable specified in initial context, should be a string or an instance of NodeNameResolver or RelationNameResolver" + ) + self._subqueries.append( + { + "query": qbuilder.build_query(), + "query_params": qbuilder._query_params, + "return_set": return_set, + "initial_context": initial_context, + } + ) return self def intermediate_transform( diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 15a49cfb..b26714ee 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -1,7 +1,6 @@ import inspect import re import string -import warnings from dataclasses import dataclass from typing import Any, Dict, List from typing import Optional as TOptional @@ -13,7 +12,7 @@ from neomodel.sync_ import relationship_manager from neomodel.sync_.core import StructuredNode, db from 
neomodel.sync_.relationship import StructuredRel -from neomodel.typing import Transformation +from neomodel.typing import Subquery, Transformation from neomodel.util import INCOMING, OUTGOING CYPHER_ACTIONS_WITH_SIDE_EFFECT_EXPR = re.compile(r"(?i:MERGE|CREATE|DELETE|DETACH)") @@ -414,13 +413,13 @@ def __init__( class QueryBuilder: - def __init__(self, node_set, subquery_context: bool = False) -> None: + def __init__(self, node_set, subquery_namespace: TOptional[str] = None) -> None: self.node_set = node_set self._ast = QueryAST() self._query_params: Dict = {} self._place_holder_registry: Dict = {} self._ident_count: int = 0 - self._subquery_context: bool = subquery_context + self._subquery_namespace: TOptional[str] = subquery_namespace def build_ast(self) -> "QueryBuilder": if hasattr(self.node_set, "relations_to_fetch"): @@ -558,7 +557,7 @@ def build_traversal_from_path( # contains the primary node so _contains() works # as usual self._ast.return_clause = lhs_name - if self._subquery_context: + if self._subquery_namespace: # Don't include label in identifier if we are in a subquery lhs_ident = lhs_name elif relation["include_in_return"]: @@ -672,7 +671,10 @@ def _register_place_holder(self, key: str) -> str: self._place_holder_registry[key] += 1 else: self._place_holder_registry[key] = 1 - return key + "_" + str(self._place_holder_registry[key]) + place_holder = f"{key}_{self._place_holder_registry[key]}" + if self._subquery_namespace: + place_holder = f"{self._subquery_namespace}_{place_holder}" + return place_holder def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str, Any]: is_rel_filter = "|" in prop @@ -879,10 +881,21 @@ def build_query(self) -> str: query += ",".join(ordering) if hasattr(self.node_set, "_subqueries"): - for subquery, return_set in self.node_set._subqueries: - outer_primary_var = self._ast.return_clause - query += f" CALL {{ WITH {outer_primary_var} {subquery} }} " - for varname in return_set: + for subquery in 
self.node_set._subqueries: + query += " CALL {" + if subquery["initial_context"]: + query += " WITH " + context: List[str] = [] + for var in subquery["initial_context"]: + if isinstance(var, (NodeNameResolver, RelationNameResolver)): + context.append(var.resolve(self)) + else: + context.append(var) + query += ",".join(context) + + query += f"{subquery['query']} }} " + self._query_params.update(subquery["query_params"]) + for varname in subquery["return_set"]: # We declare the returned variables as "virtual" relations of the # root node class to make sure they will be translated by a call to # resolve_subgraph() (otherwise, they will be lost). @@ -893,10 +906,10 @@ def build_query(self) -> str: "variable_name": varname, "rel_variable_name": varname, } - returned_items += return_set + returned_items += subquery["return_set"] query += " RETURN " - if self._ast.return_clause and not self._subquery_context: + if self._ast.return_clause and not self._subquery_namespace: returned_items.append(self._ast.return_clause) if self._ast.additional_return: returned_items += self._ast.additional_return @@ -1118,6 +1131,8 @@ class NodeNameResolver: node: str def resolve(self, qbuilder: QueryBuilder) -> str: + if self.node == "self" and qbuilder._ast.return_clause: + return qbuilder._ast.return_clause result = qbuilder.lookup_query_variable(self.node) if result is None: raise ValueError(f"Unable to resolve variable name for node {self.node}") @@ -1236,7 +1251,7 @@ def __init__(self, source) -> None: self.relations_to_fetch: List = [] self._extra_results: List = [] - self._subqueries: list[Tuple[str, list[str]]] = [] + self._subqueries: list[Subquery] = [] self._intermediate_transforms: list = [] def __await__(self): @@ -1522,7 +1537,12 @@ def resolve_subgraph(self) -> list: ) return results - def subquery(self, nodeset: "NodeSet", return_set: List[str]) -> "NodeSet": + def subquery( + self, + nodeset: "NodeSet", + return_set: List[str], + initial_context: TOptional[List[str]] = 
None, + ) -> "NodeSet": """Add a subquery to this node set. A subquery is a regular cypher query but executed within the context of a CALL @@ -1530,16 +1550,39 @@ def subquery(self, nodeset: "NodeSet", return_set: List[str]) -> "NodeSet": declared inside return_set variable in order to be included in the final RETURN statement. """ - qbuilder = nodeset.query_cls(nodeset, subquery_context=True).build_ast() + namespace = f"sq{len(self._subqueries) + 1}" + qbuilder = nodeset.query_cls(nodeset, subquery_namespace=namespace).build_ast() for var in return_set: if ( var != qbuilder._ast.return_clause and var not in qbuilder._ast.additional_return and var not in [res["alias"] for res in nodeset._extra_results if res["alias"]] + and var + not in [ + varname + for tr in nodeset._intermediate_transforms + for varname, vardef in tr["vars"].items() + if vardef.get("include_in_return") + ] ): raise RuntimeError(f"Variable '{var}' is not returned by subquery.") - self._subqueries.append((qbuilder.build_query(), return_set)) + if initial_context: + for var in initial_context: + if type(var) is not str and not isinstance( + var, (NodeNameResolver, RelationNameResolver, RawCypher) + ): + raise ValueError( + f"Wrong variable specified in initial context, should be a string or an instance of NodeNameResolver or RelationNameResolver" + ) + self._subqueries.append( + { + "query": qbuilder.build_query(), + "query_params": qbuilder._query_params, + "return_set": return_set, + "initial_context": initial_context, + } + ) return self def intermediate_transform( diff --git a/neomodel/typing.py b/neomodel/typing.py index 9438bd54..f0558096 100644 --- a/neomodel/typing.py +++ b/neomodel/typing.py @@ -1,6 +1,6 @@ """Custom types used for annotations.""" -from typing import Any, Optional, TypedDict +from typing import Any, Dict, List, Optional, TypedDict Transformation = TypedDict( "Transformation", @@ -10,3 +10,14 @@ "include_in_return": Optional[bool], }, ) + + +Subquery = TypedDict( + 
"Subquery", + { + "query": str, + "query_params": Dict, + "return_set": List[str], + "initial_context": Optional[List[Any]], + }, +) diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 2dff91c0..a494ae42 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -887,6 +887,7 @@ async def test_subquery(): ) .annotate(supps=Last(Collect("suppliers"))), ["supps"], + [NodeNameResolver("self")], ) result = await result.all() assert len(result) == 1 @@ -905,6 +906,34 @@ async def test_subquery(): ) +@mark_async_test +async def test_subquery_other_node(): + arabica = await Species(name="Arabica").save() + nescafe = await Coffee(name="Nescafe", price=99).save() + supplier1 = await Supplier(name="Supplier 1", delivery_cost=3).save() + supplier2 = await Supplier(name="Supplier 2", delivery_cost=20).save() + + await nescafe.suppliers.connect(supplier1) + await nescafe.suppliers.connect(supplier2) + await nescafe.species.connect(arabica) + + result = await Coffee.nodes.subquery( + Supplier.nodes.filter(name="Supplier 2").intermediate_transform( + { + "cost": { + "source": "supplier", + "source_prop": "delivery_cost", + "include_in_return": True, + } + } + ), + ["cost"], + ) + result = await result.all() + assert len(result) == 1 + assert result[0][0] == 20 + + @mark_async_test async def test_intermediate_transform(): arabica = await Species(name="Arabica").save() diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 4df51866..0bf69b7f 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -871,6 +871,7 @@ def test_subquery(): ) .annotate(supps=Last(Collect("suppliers"))), ["supps"], + [NodeNameResolver("self")], ) result = result.all() assert len(result) == 1 @@ -889,6 +890,34 @@ def test_subquery(): ) +@mark_sync_test +def test_subquery_other_node(): + arabica = Species(name="Arabica").save() + nescafe = Coffee(name="Nescafe", price=99).save() + supplier1 = 
Supplier(name="Supplier 1", delivery_cost=3).save() + supplier2 = Supplier(name="Supplier 2", delivery_cost=20).save() + + nescafe.suppliers.connect(supplier1) + nescafe.suppliers.connect(supplier2) + nescafe.species.connect(arabica) + + result = Coffee.nodes.subquery( + Supplier.nodes.filter(name="Supplier 2").intermediate_transform( + { + "cost": { + "source": "supplier", + "source_prop": "delivery_cost", + "include_in_return": True, + } + } + ), + ["cost"], + ) + result = result.all() + assert len(result) == 1 + assert result[0][0] == 20 + + @mark_sync_test def test_intermediate_transform(): arabica = Species(name="Arabica").save() From 058f5b8f9691af1b69b22cce28625f4210d800f5 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Tue, 10 Dec 2024 16:47:43 +0100 Subject: [PATCH 10/20] Increase test coverage --- doc/source/advanced_query_operations.rst | 2 +- neomodel/_version.py | 2 +- neomodel/typing.py | 8 +++---- test/async_/test_match_api.py | 29 ++++++++++++++++++++++-- test/sync_/test_match_api.py | 29 ++++++++++++++++++++++-- 5 files changed, 60 insertions(+), 10 deletions(-) diff --git a/doc/source/advanced_query_operations.rst b/doc/source/advanced_query_operations.rst index de1c8c61..74c15683 100644 --- a/doc/source/advanced_query_operations.rst +++ b/doc/source/advanced_query_operations.rst @@ -117,7 +117,7 @@ Options for `subquery` calls are: .. note:: In the example above, we reference `self` to be included in the initial context. It will actually inject the outer variable corresponding to `Coffee` node. - We know this is confusing to read, but have not found a better wat to do this yet. If you have any suggestions, please let us know. + We know this is confusing to read, but have not found a better way to do this yet. If you have any suggestions, please let us know. 
Helpers ------- diff --git a/neomodel/_version.py b/neomodel/_version.py index 1e41bf8f..cfda0f8e 100644 --- a/neomodel/_version.py +++ b/neomodel/_version.py @@ -1 +1 @@ -__version__ = "5.4.1" +__version__ = "5.4.2" diff --git a/neomodel/typing.py b/neomodel/typing.py index f0558096..a23f88eb 100644 --- a/neomodel/typing.py +++ b/neomodel/typing.py @@ -1,6 +1,6 @@ """Custom types used for annotations.""" -from typing import Any, Dict, List, Optional, TypedDict +from typing import Any, Optional, TypedDict Transformation = TypedDict( "Transformation", @@ -16,8 +16,8 @@ "Subquery", { "query": str, - "query_params": Dict, - "return_set": List[str], - "initial_context": Optional[List[Any]], + "query_params": dict, + "return_set": list[str], + "initial_context": Optional[list[Any]], }, ) diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index a494ae42..70c7f351 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -2,6 +2,7 @@ from datetime import datetime from test._async_compat import mark_async_test +import numpy as np from pytest import raises, skip, warns from neomodel import ( @@ -880,7 +881,7 @@ async def test_subquery(): await nescafe.suppliers.connect(supplier2) await nescafe.species.connect(arabica) - result = await Coffee.nodes.subquery( + subquery = await Coffee.nodes.subquery( Coffee.nodes.traverse_relations(suppliers="suppliers") .intermediate_transform( {"suppliers": {"source": "suppliers"}}, ordering=["suppliers.delivery_cost"] @@ -889,7 +890,7 @@ async def test_subquery(): ["supps"], [NodeNameResolver("self")], ) - result = await result.all() + result = await subquery.all() assert len(result) == 1 assert len(result[0]) == 2 assert result[0][0] == supplier2 @@ -905,6 +906,30 @@ async def test_subquery(): ["unknown"], ) + result_string_context = await subquery.subquery( + Coffee.nodes.traverse_relations(supps2="suppliers").annotate( + supps2=Collect("supps") + ), + ["supps2"], + ["supps"], + ) + 
result_string_context = await result_string_context.all() + assert len(result) == 1 + additional_elements = [ + item for item in result_string_context[0] if item not in result[0] + ] + assert len(additional_elements) == 1 + assert isinstance(additional_elements[0], list) + + with raises(ValueError, match=r"Wrong variable specified in initial context"): + result = await Coffee.nodes.subquery( + Coffee.nodes.traverse_relations(suppliers="suppliers").annotate( + supps=Collect("suppliers") + ), + ["supps"], + [2], + ) + @mark_async_test async def test_subquery_other_node(): diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 0bf69b7f..94465db2 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -2,6 +2,7 @@ from datetime import datetime from test._async_compat import mark_sync_test +import numpy as np from pytest import raises, skip, warns from neomodel import ( @@ -864,7 +865,7 @@ def test_subquery(): nescafe.suppliers.connect(supplier2) nescafe.species.connect(arabica) - result = Coffee.nodes.subquery( + subquery = Coffee.nodes.subquery( Coffee.nodes.traverse_relations(suppliers="suppliers") .intermediate_transform( {"suppliers": {"source": "suppliers"}}, ordering=["suppliers.delivery_cost"] @@ -873,7 +874,7 @@ def test_subquery(): ["supps"], [NodeNameResolver("self")], ) - result = result.all() + result = subquery.all() assert len(result) == 1 assert len(result[0]) == 2 assert result[0][0] == supplier2 @@ -889,6 +890,30 @@ def test_subquery(): ["unknown"], ) + result_string_context = subquery.subquery( + Coffee.nodes.traverse_relations(supps2="suppliers").annotate( + supps2=Collect("supps") + ), + ["supps2"], + ["supps"], + ) + result_string_context = result_string_context.all() + assert len(result) == 1 + additional_elements = [ + item for item in result_string_context[0] if item not in result[0] + ] + assert len(additional_elements) == 1 + assert isinstance(additional_elements[0], list) + + with 
raises(ValueError, match=r"Wrong variable specified in initial context"): + result = Coffee.nodes.subquery( + Coffee.nodes.traverse_relations(suppliers="suppliers").annotate( + supps=Collect("suppliers") + ), + ["supps"], + [2], + ) + @mark_sync_test def test_subquery_other_node(): From 6c8a1f16b7a2be9b0d5aab55134dda4cb8d0d689 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Tue, 10 Dec 2024 16:58:53 +0100 Subject: [PATCH 11/20] Update changelog --- Changelog | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Changelog b/Changelog index 4f78f667..a93e806e 100644 --- a/Changelog +++ b/Changelog @@ -1,5 +1,7 @@ Vesion 5.4.2 2024-12 * Add support for Neo4j Rust driver extension : pip install neomodel[rust-driver-ext] +* Add initial_context parameter to subqueries +* NodeNameResolver can call self to reference top-level node Version 5.4.1 2024-11 * Add support for Cypher parallel runtime From bd17f85906c6f66319036b70a84f7d07cb06dcb6 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Wed, 11 Dec 2024 14:20:00 +0100 Subject: [PATCH 12/20] Add more type hints and fix issues --- neomodel/exceptions.py | 68 +++++++++--------- neomodel/properties.py | 155 ++++++++++++++++++++++++----------------- neomodel/util.py | 26 ++++--- 3 files changed, 139 insertions(+), 110 deletions(-) diff --git a/neomodel/exceptions.py b/neomodel/exceptions.py index 3959c347..45b24e0e 100644 --- a/neomodel/exceptions.py +++ b/neomodel/exceptions.py @@ -1,4 +1,4 @@ -from typing import Optional, Type +from typing import Any, Optional, Type, Union class NeomodelException(Exception): @@ -26,11 +26,11 @@ class CardinalityViolation(NeomodelException): For example a relationship type `OneOrMore` returns no nodes. 
""" - def __init__(self, rel_manager, actual): + def __init__(self, rel_manager: Any, actual: Union[int, str]): self.rel_manager = str(rel_manager) self.actual = str(actual) - def __str__(self): + def __str__(self) -> str: return ( f"CardinalityViolation: Expected: {self.rel_manager}, got: {self.actual}." ) @@ -43,9 +43,9 @@ class ModelDefinitionException(NeomodelException): def __init__( self, - db_node_rel_class, - current_node_class_registry, - current_db_specific_node_class_registry, + db_node_rel_class: Any, + current_node_class_registry: dict[frozenset, Any], + current_db_specific_node_class_registry: dict[str, dict], ): """ Initialises the exception with the database node that caused the missmatch. @@ -63,7 +63,7 @@ def __init__( current_db_specific_node_class_registry ) - def _get_node_class_registry_formatted(self): + def _get_node_class_registry_formatted(self) -> str: """ Returns the current node class registry string formatted as a list of Labels --> entries. @@ -99,7 +99,7 @@ class NodeClassNotDefined(ModelDefinitionException): In either of these cases the mismatch must be reported """ - def __str__(self): + def __str__(self) -> str: node_labels = ",".join(self.db_node_rel_class.labels) return f"Node with labels {node_labels} does not resolve to any of the known objects\n{self._get_node_class_registry_formatted()}\n" @@ -111,7 +111,7 @@ class RelationshipClassNotDefined(ModelDefinitionException): a data model object. 
""" - def __str__(self): + def __str__(self) -> str: relationship_type = self.db_node_rel_class.type return f""" Relationship of type {relationship_type} does not resolve to any of the known objects @@ -128,10 +128,10 @@ class RelationshipClassRedefined(ModelDefinitionException): def __init__( self, - db_rel_class_type, - current_node_class_registry, - current_db_specific_node_class_registry, - remapping_to_class, + db_rel_class_type: Any, + current_node_class_registry: dict[frozenset, Any], + current_db_specific_node_class_registry: dict[str, dict], + remapping_to_class: Any, ): """ Initialises a relationship redefinition exception with the required data as follows: @@ -151,7 +151,7 @@ def __init__( ) self.remapping_to_class = remapping_to_class - def __str__(self): + def __str__(self) -> str: relationship_type = self.db_node_rel_class return f"Relationship of type {relationship_type} redefined as {self.remapping_to_class}.\n{self._get_node_class_registry_formatted()}\n" @@ -162,25 +162,25 @@ class NodeClassAlreadyDefined(ModelDefinitionException): that already has a mapping within the node-to-class registry. 
""" - def __str__(self): + def __str__(self) -> str: node_class_labels = ",".join(self.db_node_rel_class.inherited_labels()) return f"Class {self.db_node_rel_class.__module__}.{self.db_node_rel_class.__name__} with labels {node_class_labels} already defined:\n{self._get_node_class_registry_formatted()}\n" class ConstraintValidationFailed(ValueError, NeomodelException): - def __init__(self, msg): + def __init__(self, msg: str): self.message = msg class DeflateError(ValueError, NeomodelException): - def __init__(self, key, cls, msg, obj): + def __init__(self, key: str, cls: Any, msg: str, obj: Any): self.property_name = key self.node_class = cls self.msg = msg self.obj = repr(obj) - def __str__(self): + def __str__(self) -> str: return f"Attempting to deflate property '{self.property_name}' on {self.obj} of class '{self.node_class.__name__}': {self.msg}" @@ -191,84 +191,84 @@ class DoesNotExist(NeomodelException): belongs to. It is set by :class:`~neomodel.core.NodeMeta`. """ - def __init__(self, msg): + def __init__(self, msg: str): if self._model_class is None: raise RuntimeError("This class hasn't been setup properly.") self.message: str = msg super().__init__(self, msg) - def __reduce__(self): + def __reduce__(self) -> tuple: return _unpickle_does_not_exist, (self._model_class, self.message) -def _unpickle_does_not_exist(_model_class, message): +def _unpickle_does_not_exist(_model_class: Any, message: str) -> DoesNotExist: return _model_class.DoesNotExist(message) class InflateConflict(NeomodelException): - def __init__(self, cls, key, value, nid): + def __init__(self, cls: Any, key: str, value: Any, nid: str): self.cls_name = cls.__name__ self.property_name = key self.value = value self.nid = nid - def __str__(self): + def __str__(self) -> str: return f"Found conflict with node {self.nid}, has property '{self.property_name}' with value '{self.value}' although class {self.cls_name} already has a property '{self.property_name}'" class InflateError(ValueError, 
NeomodelException): - def __init__(self, key, cls, msg, obj=None): + def __init__(self, key: str, cls: Any, msg: str, obj: Optional[Any] = None): self.property_name = key self.node_class = cls self.msg = msg self.obj = repr(obj) - def __str__(self): + def __str__(self) -> str: return f"Attempting to inflate property '{self.property_name}' on {self.obj} of class '{self.node_class.__name__}': {self.msg}" class DeflateConflict(InflateConflict): - def __init__(self, cls, key, value, nid): + def __init__(self, cls: Any, key: str, value: Any, nid: str): self.cls_name = cls.__name__ self.property_name = key self.value = value self.nid = nid if nid else "(unsaved)" - def __str__(self): + def __str__(self) -> str: return f"Found trying to set property '{self.property_name}' with value '{self.value}' on node {self.nid} although class {self.cls_name} already has a property '{self.property_name}'" class MultipleNodesReturned(ValueError, NeomodelException): - def __init__(self, msg): + def __init__(self, msg: str): self.message = msg class NotConnected(NeomodelException): - def __init__(self, action, node1, node2): + def __init__(self, action: str, node1: Any, node2: Any): self.action = action self.node1 = node1 self.node2 = node2 - def __str__(self): + def __str__(self) -> str: return f"Error performing '{self.action}' - Node {self.node1.element_id} of type '{self.node1.__class__.__name__}' is not connected to {self.node2.element_id} of type '{self.node2.__class__.__name__}'." 
class RequiredProperty(NeomodelException): - def __init__(self, key, cls): + def __init__(self, key: str, cls: Any): self.property_name = key self.node_class = cls - def __str__(self): + def __str__(self) -> str: return f"property '{self.property_name}' on objects of class {self.node_class.__name__}" class UniqueProperty(ConstraintValidationFailed): - def __init__(self, msg): + def __init__(self, msg: str): self.message = msg class FeatureNotSupported(NeomodelException): - def __init__(self, msg): + def __init__(self, msg: str): self.message = msg diff --git a/neomodel/properties.py b/neomodel/properties.py index a91b29cb..27da64da 100644 --- a/neomodel/properties.py +++ b/neomodel/properties.py @@ -4,7 +4,7 @@ import uuid from abc import ABCMeta, abstractmethod from datetime import date, datetime -from typing import Any, Optional +from typing import Any, Callable, Optional import neo4j.time import pytz @@ -15,13 +15,15 @@ TOO_MANY_DEFAULTS = "too many defaults" -def validator(fn): +def validator(fn: Callable) -> Callable: fn_name = fn.func_name if hasattr(fn, "func_name") else fn.__name__ if fn_name not in ["inflate", "deflate"]: raise ValueError("Unknown Property method " + fn_name) @functools.wraps(fn) - def _validator(self, value, obj=None, rethrow=True): + def _validator( # type: ignore + self, value: Any, obj: Optional[Any] = None, rethrow: Optional[bool] = True + ) -> Any: if rethrow: try: return fn(self, value) @@ -46,8 +48,8 @@ class FulltextIndex: def __init__( self, - analyzer="standard-no-stop-words", - eventually_consistent=False, + analyzer: Optional[str] = "standard-no-stop-words", + eventually_consistent: Optional[bool] = False, ): """ Initializes new fulltext index definition with analyzer and eventually consistent @@ -64,7 +66,11 @@ class VectorIndex: Vector index definition """ - def __init__(self, dimensions=1536, similarity_function="cosine"): + def __init__( + self, + dimensions: Optional[int] = 1536, + similarity_function: Optional[str] = 
"cosine", + ): """ Initializes new vector index definition with dimensions and similarity @@ -119,16 +125,16 @@ def __init__( self, name: Optional[str] = None, owner: Optional[Any] = None, - unique_index=False, - index=False, + unique_index: bool = False, + index: bool = False, fulltext_index: Optional[FulltextIndex] = None, vector_index: Optional[VectorIndex] = None, - required=False, - default=None, - db_property=None, - label=None, - help_text=None, - **kwargs, + required: bool = False, + default: Optional[Any] = None, + db_property: Optional[str] = None, + label: Optional[str] = None, + help_text: Optional[str] = None, + **kwargs: dict[str, Any], ): if default is not None and required: raise ValueError( @@ -150,7 +156,7 @@ def __init__( self.label = label self.help_text = help_text - def default_value(self): + def default_value(self) -> Any: """ Generate a default value @@ -162,7 +168,7 @@ def default_value(self): return self.default raise ValueError("No default value specified") - def get_db_property_name(self, attribute_name): + def get_db_property_name(self, attribute_name: str) -> str: """ Returns the name that should be used for the property in the database. This is db_property if supplied upon construction, otherwise the given attribute_name from the model is used. 
@@ -170,11 +176,15 @@ def get_db_property_name(self, attribute_name): return self.db_property or attribute_name @property - def is_indexed(self): + def is_indexed(self) -> bool: return self.unique_index or self.index @abstractmethod - def deflate(self, value: Any) -> Any: + def inflate(self, value: Any, rethrow: bool) -> Any: + pass + + @abstractmethod + def deflate(self, value: Any, rethrow: bool) -> Any: pass @@ -185,18 +195,18 @@ class NormalizedProperty(Property): """ @validator - def inflate(self, value): + def inflate(self, value: Any) -> Any: return self.normalize(value) @validator - def deflate(self, value): + def deflate(self, value: Any) -> Any: return self.normalize(value) - def default_value(self): + def default_value(self) -> Any: default = super().default_value() return self.normalize(default) - def normalize(self, value): + def normalize(self, value: Any) -> Any: raise NotImplementedError("Specialize normalize method") @@ -213,7 +223,7 @@ class RegexProperty(NormalizedProperty): expression: str - def __init__(self, expression=None, **kwargs): + def __init__(self, expression: Optional[str] = None, **kwargs: Any): """ Initializes new property with an expression. 
@@ -222,7 +232,7 @@ def __init__(self, expression=None, **kwargs): super().__init__(**kwargs) self.expression = expression or self.expression - def normalize(self, value): + def normalize(self, value: Any) -> str: normal = str(value) if not re.match(self.expression, normal): raise ValueError(f"{value!r} does not match {self.expression!r}") @@ -250,7 +260,12 @@ class StringProperty(NormalizedProperty): :type max_length: int """ - def __init__(self, choices=None, max_length=None, **kwargs): + def __init__( + self, + choices: Optional[Any] = None, + max_length: Optional[int] = None, + **kwargs: Any, + ): if max_length is not None: if choices is not None: raise ValueError( @@ -273,7 +288,7 @@ def __init__(self, choices=None, max_length=None, **kwargs): ) from exc self.form_field_class = "TypedChoiceField" - def normalize(self, value): + def normalize(self, value: str) -> str: # One thing to note here is that the following two checks can remain uncoupled # as long as it is guaranteed (by the constructor) that `choices` and `max_length` # are mutually exclusive. If that check in the constructor ever has to be removed, @@ -287,7 +302,7 @@ def normalize(self, value): ) return str(value) - def default_value(self): + def default_value(self) -> str: return self.normalize(super().default_value()) @@ -299,14 +314,14 @@ class IntegerProperty(Property): form_field_class = "IntegerField" @validator - def inflate(self, value): + def inflate(self, value: Any) -> int: return int(value) @validator - def deflate(self, value): + def deflate(self, value: Any) -> int: return int(value) - def default_value(self): + def default_value(self) -> int: return int(super().default_value()) @@ -315,7 +330,7 @@ class ArrayProperty(Property): Stores a list of items """ - def __init__(self, base_property=None, **kwargs): + def __init__(self, base_property: Optional[Property] = None, **kwargs: Any): """ Store a list of values, optionally of a specific type. 
@@ -347,20 +362,20 @@ def __init__(self, base_property=None, **kwargs): super().__init__(**kwargs) @validator - def inflate(self, value): + def inflate(self, value: Any) -> list: if self.base_property: return [self.base_property.inflate(item, rethrow=False) for item in value] return list(value) @validator - def deflate(self, value): + def deflate(self, value: Any) -> list: if self.base_property: return [self.base_property.deflate(item, rethrow=False) for item in value] return list(value) - def default_value(self): + def default_value(self) -> list: return list(super().default_value()) @@ -372,14 +387,14 @@ class FloatProperty(Property): form_field_class = "FloatField" @validator - def inflate(self, value): + def inflate(self, value: Any) -> float: return float(value) @validator - def deflate(self, value): + def deflate(self, value: Any) -> float: return float(value) - def default_value(self): + def default_value(self) -> float: return float(super().default_value()) @@ -391,14 +406,14 @@ class BooleanProperty(Property): form_field_class = "BooleanField" @validator - def inflate(self, value): + def inflate(self, value: Any) -> bool: return bool(value) @validator - def deflate(self, value): + def deflate(self, value: Any) -> bool: return bool(value) - def default_value(self): + def default_value(self) -> bool: return bool(super().default_value()) @@ -410,7 +425,7 @@ class DateProperty(Property): form_field_class = "DateField" @validator - def inflate(self, value): + def inflate(self, value: Any) -> date: if isinstance(value, neo4j.time.DateTime): value = date(value.year, value.month, value.day) elif isinstance(value, str) and "T" in value: @@ -418,7 +433,7 @@ def inflate(self, value): return datetime.strptime(str(value), "%Y-%m-%d").date() @validator - def deflate(self, value): + def deflate(self, value: date) -> str: if not isinstance(value, date): msg = f"datetime.date object expected, got {repr(value)}" raise ValueError(msg) @@ -438,7 +453,9 @@ class 
DateTimeFormatProperty(Property): form_field_class = "DateTimeFormatField" - def __init__(self, default_now=False, format="%Y-%m-%d", **kwargs): + def __init__( + self, default_now: bool = False, format: str = "%Y-%m-%d", **kwargs: Any + ): if default_now: if "default" in kwargs: raise ValueError(TOO_MANY_DEFAULTS) @@ -448,11 +465,11 @@ def __init__(self, default_now=False, format="%Y-%m-%d", **kwargs): super().__init__(**kwargs) @validator - def inflate(self, value): + def inflate(self, value: Any) -> datetime: return datetime.strptime(str(value), self.format) @validator - def deflate(self, value): + def deflate(self, value: datetime) -> str: if not isinstance(value, datetime): raise ValueError(f"datetime object expected, got {type(value)}.") return datetime.strftime(value, self.format) @@ -469,16 +486,16 @@ class DateTimeProperty(Property): form_field_class = "DateTimeField" - def __init__(self, default_now=False, **kwargs): + def __init__(self, default_now: bool = False, **kwargs: Any): if default_now: if "default" in kwargs: raise ValueError(TOO_MANY_DEFAULTS) - kwargs["default"] = lambda: datetime.utcnow().replace(tzinfo=pytz.utc) + kwargs["default"] = lambda: datetime.now(pytz.utc) super().__init__(**kwargs) @validator - def inflate(self, value): + def inflate(self, value: Any) -> datetime: try: epoch = float(value) except ValueError as exc: @@ -489,10 +506,10 @@ def inflate(self, value): raise TypeError( f"Float or integer expected. Can't inflate {type(value)} to datetime." 
) from exc - return datetime.utcfromtimestamp(epoch).replace(tzinfo=pytz.utc) + return datetime.fromtimestamp(epoch, tz=pytz.utc) @validator - def deflate(self, value): + def deflate(self, value: datetime) -> float: if not isinstance(value, datetime): raise ValueError(f"datetime object expected, got {type(value)}.") if value.tzinfo: @@ -518,7 +535,7 @@ class DateTimeNeo4jFormatProperty(Property): form_field_class = "DateTimeNeo4jFormatField" - def __init__(self, default_now=False, **kwargs): + def __init__(self, default_now: bool = False, **kwargs: Any): if default_now: if "default" in kwargs: raise ValueError(TOO_MANY_DEFAULTS) @@ -528,11 +545,11 @@ def __init__(self, default_now=False, **kwargs): super(DateTimeNeo4jFormatProperty, self).__init__(**kwargs) @validator - def inflate(self, value): + def inflate(self, value: Any) -> datetime: return value.to_native() @validator - def deflate(self, value): + def deflate(self, value: datetime) -> neo4j.time.DateTime: if not isinstance(value, datetime): raise ValueError("datetime object expected, got {0}.".format(type(value))) return neo4j.time.DateTime.from_native(value) @@ -545,16 +562,16 @@ class JSONProperty(Property): The structure will be inflated when a node is retrieved. 
""" - def __init__(self, ensure_ascii=True, *args, **kwargs): + def __init__(self, ensure_ascii: bool = True, *args: Any, **kwargs: Any): self.ensure_ascii = ensure_ascii super(JSONProperty, self).__init__(*args, **kwargs) @validator - def inflate(self, value): + def inflate(self, value: Any) -> Any: return json.loads(value) @validator - def deflate(self, value): + def deflate(self, value: Any) -> str: return json.dumps(value, ensure_ascii=self.ensure_ascii) @@ -563,7 +580,7 @@ class AliasProperty(property, Property): Alias another existing property """ - def __init__(self, to=None): + def __init__(self, to: str): """ Create new alias @@ -575,30 +592,38 @@ def __init__(self, to=None): self.required = False self.has_default = False - def aliased_to(self): + def aliased_to(self) -> str: return self.target - def __get__(self, obj: Any, _type: Optional[Any] = None): + def __get__(self, obj: Any, _type: Optional[Any] = None) -> Property: return getattr(obj, self.aliased_to()) if obj else self - def __set__(self, obj, value): + def __set__(self, obj: Any, value: Property) -> None: setattr(obj, self.aliased_to(), value) @property - def index(self): + def index(self) -> bool: return getattr(self.owner, self.aliased_to()).index + @index.setter + def index(self, value: bool) -> None: + raise AttributeError("Cannot set read-only property 'index'") + @property - def unique_index(self): + def unique_index(self) -> bool: return getattr(self.owner, self.aliased_to()).unique_index + @unique_index.setter + def unique_index(self, value: bool) -> None: + raise AttributeError("Cannot set read-only property 'unique_index'") + class UniqueIdProperty(Property): """ A unique identifier, a randomly generated uid (uuid4) with a unique index """ - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): for item in ["required", "unique_index", "index", "default"]: if item in kwargs: raise ValueError( @@ -610,9 +635,9 @@ def __init__(self, **kwargs): super().__init__(**kwargs) 
@validator - def inflate(self, value): + def inflate(self, value: Any) -> str: return str(value) @validator - def deflate(self, value): + def deflate(self, value: Any) -> str: return str(value) diff --git a/neomodel/util.py b/neomodel/util.py index a62e988e..36fffdd2 100644 --- a/neomodel/util.py +++ b/neomodel/util.py @@ -1,12 +1,16 @@ import warnings +from types import FrameType +from typing import Any, Callable, Optional + +from neo4j.graph import Entity OUTGOING, INCOMING, EITHER = 1, -1, 0 -def deprecated(message): +def deprecated(message: str) -> Callable: # pylint:disable=invalid-name - def f__(f): - def f_(*args, **kwargs): + def f__(f: Callable) -> Callable: + def f_(*args, **kwargs) -> Any: # type: ignore warnings.warn(message, category=DeprecationWarning, stacklevel=2) return f(*args, **kwargs) @@ -18,12 +22,12 @@ def f_(*args, **kwargs): return f__ -def classproperty(f): +def classproperty(f: Callable) -> Any: class cpf: - def __init__(self, getter): + def __init__(self, getter: Callable) -> None: self.getter = getter - def __get__(self, obj, type=None): + def __get__(self, obj: Any, type: Optional[Any] = None) -> Any: return self.getter(type) return cpf(f) @@ -31,21 +35,21 @@ def __get__(self, obj, type=None): # Just used for error messages class _UnsavedNode: - def __repr__(self): + def __repr__(self) -> str: return "" - def __str__(self): + def __str__(self) -> str: return self.__repr__() -def get_graph_entity_properties(entity): +def get_graph_entity_properties(entity: Entity) -> dict: """ Get the properties from a neo4j.graph.Entity (neo4j.graph.Node or neo4j.graph.Relationship) object. 
""" return entity._properties -def enumerate_traceback(initial_frame): +def enumerate_traceback(initial_frame: Optional[FrameType] = None) -> Any: depth, frame = 0, initial_frame while frame is not None: yield depth, frame @@ -53,7 +57,7 @@ def enumerate_traceback(initial_frame): depth += 1 -def version_tag_to_integer(version_tag): +def version_tag_to_integer(version_tag: str) -> int: """ Converts a version string to an integer representation to allow for quick comparisons between versions. From 23f5c03916d351a3d5483b39cc782ce54ccfa694 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Wed, 11 Dec 2024 17:49:52 +0100 Subject: [PATCH 13/20] Add more type hints --- neomodel/async_/core.py | 304 ++++++++++++++++++++++++---------------- neomodel/hooks.py | 7 +- neomodel/sync_/core.py | 298 +++++++++++++++++++++++---------------- 3 files changed, 362 insertions(+), 247 deletions(-) diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py index 6c2def38..0ccabe11 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/core.py @@ -7,7 +7,7 @@ from functools import wraps from itertools import combinations from threading import local -from typing import Any, Callable, Optional, Sequence, Type, Union +from typing import Any, Callable, Optional, TextIO, Type, Union from urllib.parse import quote, unquote, urlparse from neo4j import ( @@ -54,7 +54,7 @@ # make sure the connection url has been set prior to executing the wrapped function -def ensure_connection(func): +def ensure_connection(func: Callable) -> Callable: """Decorator that ensures a connection is established before executing the decorated function. 
Args: @@ -65,7 +65,7 @@ def ensure_connection(func): """ - async def wrapper(self, *args, **kwargs): + async def wrapper(self, *args: Any, **kwargs: Any) -> Callable: # Sort out where to find url if hasattr(self, "db"): _db = self.db @@ -89,7 +89,7 @@ class AsyncDatabase(local): """ _NODE_CLASS_REGISTRY: dict[frozenset, Any] = {} - _DB_SPECIFIC_CLASS_REGISTRY: dict[str, dict] = {} + _DB_SPECIFIC_CLASS_REGISTRY: dict[str, dict[frozenset, Any]] = {} def __init__(self): self._active_transaction: Optional[AsyncTransaction] = None @@ -105,7 +105,7 @@ def __init__(self): async def set_connection( self, url: Optional[str] = None, driver: Optional[AsyncDriver] = None - ): + ) -> None: """ Sets the connection up and relevant internal. This can be done using a Neo4j URL or a driver instance. @@ -200,7 +200,7 @@ def _parse_driver_from_url(self, url: str) -> None: else: self._database_name = database_name - async def close_connection(self): + async def close_connection(self) -> None: """ Closes the currently open driver. The driver should always be closed at the end of the application's lifecyle. 
@@ -213,36 +213,36 @@ async def close_connection(self): self.driver = None @property - async def database_version(self): + async def database_version(self) -> Optional[str]: if self._database_version is None: await self._update_database_version() return self._database_version @property - async def database_edition(self): + async def database_edition(self) -> Optional[str]: if self._database_edition is None: await self._update_database_version() return self._database_edition @property - def transaction(self): + def transaction(self) -> "AsyncTransactionProxy": """ Returns the current transaction object """ return AsyncTransactionProxy(self) @property - def write_transaction(self): + def write_transaction(self) -> "AsyncTransactionProxy": return AsyncTransactionProxy(self, access_mode="WRITE") @property - def read_transaction(self): + def read_transaction(self) -> "AsyncTransactionProxy": return AsyncTransactionProxy(self, access_mode="READ") @property - def parallel_read_transaction(self): + def parallel_read_transaction(self) -> "AsyncTransactionProxy": return AsyncTransactionProxy(self, access_mode="READ", parallel_runtime=True) async def impersonate(self, user: str) -> "ImpersonationHandler": @@ -262,7 +262,7 @@ async def impersonate(self, user: str) -> "ImpersonationHandler": return ImpersonationHandler(self, impersonated_user=user) @ensure_connection - async def begin(self, access_mode=None, **parameters): + async def begin(self, access_mode: str = "WRITE", **parameters: Any) -> None: """ Begins a new transaction. Raises SystemError if a transaction is already active. 
""" @@ -285,7 +285,7 @@ async def begin(self, access_mode=None, **parameters): self._active_transaction = await self._session.begin_transaction() @ensure_connection - async def commit(self): + async def commit(self) -> Bookmarks: """ Commits the current transaction and closes its session @@ -313,7 +313,7 @@ async def commit(self): return last_bookmarks @ensure_connection - async def rollback(self): + async def rollback(self) -> None: """ Rolls back the current transaction and closes its session """ @@ -332,7 +332,7 @@ async def rollback(self): self._active_transaction = None self._session = None - async def _update_database_version(self): + async def _update_database_version(self) -> None: """ Updates the database server information when it is required """ @@ -346,7 +346,7 @@ async def _update_database_version(self): # The database server is not running yet pass - def _object_resolution(self, object_to_resolve): + def _object_resolution(self, object_to_resolve: Any) -> Any: """ Performs in place automatic object resolution on a result returned by cypher_query. @@ -421,7 +421,7 @@ def _object_resolution(self, object_to_resolve): return object_to_resolve - def _result_resolution(self, result_list): + def _result_resolution(self, result_list: list) -> list: """ Performs in place automatic object resolution on a set of results returned by cypher_query. @@ -452,12 +452,12 @@ def _result_resolution(self, result_list): @ensure_connection async def cypher_query( self, - query, - params=None, - handle_unique=True, - retry_on_session_expire=False, - resolve_objects=False, - ): + query: str, + params: Optional[dict[str, Any]] = None, + handle_unique: bool = True, + retry_on_session_expire: bool = False, + resolve_objects: bool = False, + ) -> tuple[Optional[list], Optional[tuple[str, ...]]]: """ Runs a query on the database and returns a list of results and their headers. 
@@ -475,6 +475,8 @@ async def cypher_query( :return: A tuple containing a list of results and a tuple of headers. """ + if params is None: + params = {} if self._active_transaction: # Use current transaction if a transaction is currently active results, meta = await self._run_cypher_query( @@ -508,18 +510,18 @@ async def cypher_query( async def _run_cypher_query( self, session: Union[AsyncSession, AsyncTransaction], - query, - params, - handle_unique, - retry_on_session_expire, - resolve_objects, - ): + query: str, + params: dict[str, Any], + handle_unique: bool, + retry_on_session_expire: bool, + resolve_objects: bool, + ) -> tuple[Optional[list], Optional[tuple[str, ...]]]: try: # Retrieve the data start = time.time() if self._parallel_runtime: query = "CYPHER runtime=parallel " + query - response: AsyncResult = await session.run(query, params) + response: AsyncResult = await session.run(query=query, parameters=params) results, meta = [list(r.values()) async for r in response], response.keys() end = time.time() @@ -529,15 +531,14 @@ async def _run_cypher_query( except ClientError as e: if e.code == "Neo.ClientError.Schema.ConstraintValidationFailed": - if ( - hasattr(e, "message") - and e.message is not None - and "already exists with label" in e.message - and handle_unique - ): - raise UniqueProperty(e.message) from e + if hasattr(e, "message") and e.message is not None: + if "already exists with label" in e.message and handle_unique: + raise UniqueProperty(e.message) from e + raise ConstraintValidationFailed(e.message) from e + raise ConstraintValidationFailed( + "A constraint validation failed" + ) from e - raise ConstraintValidationFailed(e.message) from e exc_info = sys.exc_info() if exc_info[1] is not None and exc_info[2] is not None: raise exc_info[1].with_traceback(exc_info[2]) @@ -568,16 +569,30 @@ async def _run_cypher_query( async def get_id_method(self) -> str: db_version = await self.database_version + if db_version is None: + raise RuntimeError( + 
""" + Unable to perform this operation because the database server version is not known. + This might mean that the database server is offline. + """ + ) if db_version.startswith("4"): return "id" else: return "elementId" - async def parse_element_id(self, element_id: str): + async def parse_element_id(self, element_id: str) -> Union[str, int]: db_version = await self.database_version + if db_version is None: + raise RuntimeError( + """ + Unable to perform this operation because the database server version is not known. + This might mean that the database server is offline. + """ + ) return int(element_id) if db_version.startswith("4") else element_id - async def list_indexes(self, exclude_token_lookup=False) -> Sequence[dict]: + async def list_indexes(self, exclude_token_lookup: bool = False) -> list[dict]: """Returns all indexes existing in the database Arguments: @@ -596,7 +611,7 @@ async def list_indexes(self, exclude_token_lookup=False) -> Sequence[dict]: return indexes_as_dict - async def list_constraints(self) -> Sequence[dict]: + async def list_constraints(self) -> list[dict]: """Returns all constraints existing in the database Returns: @@ -618,6 +633,13 @@ async def version_is_higher_than(self, version_tag: str) -> bool: bool: True if the database version is higher or equal to the given version """ db_version = await self.database_version + if db_version is None: + raise RuntimeError( + """ + Unable to perform this operation because the database server version is not known. + This might mean that the database server is offline. + """ + ) return version_tag_to_integer(db_version) >= version_tag_to_integer(version_tag) @ensure_connection @@ -628,6 +650,13 @@ async def edition_is_enterprise(self) -> bool: bool: True if the database edition is enterprise """ edition = await self.database_edition + if edition is None: + raise RuntimeError( + """ + Unable to perform this operation because the database server edition is not known. 
+ This might mean that the database server is offline. + """ + ) return edition == "enterprise" @ensure_connection @@ -642,10 +671,12 @@ async def parallel_runtime_available(self) -> bool: and await self.edition_is_enterprise() ) - async def change_neo4j_password(self, user, new_password): + async def change_neo4j_password(self, user: str, new_password: str) -> None: await self.cypher_query(f"ALTER USER {user} SET PASSWORD '{new_password}'") - async def clear_neo4j_database(self, clear_constraints=False, clear_indexes=False): + async def clear_neo4j_database( + self, clear_constraints: bool = False, clear_indexes: bool = False + ) -> None: await self.cypher_query( """ MATCH (a) @@ -658,7 +689,9 @@ async def clear_neo4j_database(self, clear_constraints=False, clear_indexes=Fals if clear_indexes: await drop_indexes() - async def drop_constraints(self, quiet=True, stdout=None): + async def drop_constraints( + self, quiet: bool = True, stdout: Optional[TextIO] = None + ) -> None: """ Discover and drop all constraints. @@ -684,7 +717,9 @@ async def drop_constraints(self, quiet=True, stdout=None): if not quiet: stdout.write("\n") - async def drop_indexes(self, quiet=True, stdout=None): + async def drop_indexes( + self, quiet: bool = True, stdout: Optional[TextIO] = None + ) -> None: """ Discover and drop all indexes, except the automatically created token lookup indexes. @@ -704,7 +739,7 @@ async def drop_indexes(self, quiet=True, stdout=None): if not quiet: stdout.write("\n") - async def remove_all_labels(self, stdout=None): + async def remove_all_labels(self, stdout: Optional[TextIO] = None) -> None: """ Calls functions for dropping constraints and indexes. 
@@ -721,7 +756,7 @@ async def remove_all_labels(self, stdout=None): stdout.write("Dropping indexes...\n") await self.drop_indexes(quiet=False, stdout=stdout) - async def install_all_labels(self, stdout=None): + async def install_all_labels(self, stdout: Optional[TextIO] = None) -> None: """ Discover all subclasses of StructuredNode in your application and execute install_labels on each. Note: code must be loaded (imported) in order for a class to be discovered. @@ -752,7 +787,9 @@ def subsub(cls): # recursively return all subclasses stdout.write(f"Finished {i} classes.\n") - async def install_labels(self, cls, quiet=True, stdout=None): + async def install_labels( + self, cls: Any, quiet: bool = True, stdout: Optional[TextIO] = None + ) -> None: """ Setup labels with indexes and constraints for a given class @@ -763,27 +800,26 @@ async def install_labels(self, cls, quiet=True, stdout=None): :type: bool :return: None """ - if not stdout or stdout is None: - stdout = sys.stdout + _stdout = stdout if stdout else sys.stdout if not hasattr(cls, "__label__"): if not quiet: - stdout.write( + _stdout.write( f" ! 
Skipping class {cls.__module__}.{cls.__name__} is abstract\n" ) return for name, property in cls.defined_properties(aliases=False, rels=False).items(): - await self._install_node(cls, name, property, quiet, stdout) + await self._install_node(cls, name, property, quiet, _stdout) for _, relationship in cls.defined_properties( aliases=False, rels=True, properties=False ).items(): - await self._install_relationship(cls, relationship, quiet, stdout) + await self._install_relationship(cls, relationship, quiet, _stdout) async def _create_node_index( - self, target_cls, property_name: str, stdout, quiet: bool - ): + self, target_cls: Any, property_name: str, stdout: TextIO, quiet: bool + ) -> None: label = target_cls.__label__ index_name = f"index_{label}_{property_name}" if not quiet: @@ -805,12 +841,12 @@ async def _create_node_index( async def _create_node_fulltext_index( self, - target_cls, + target_cls: Any, property_name: str, - stdout, + stdout: TextIO, fulltext_index: FulltextIndex, quiet: bool, - ): + ) -> None: if await self.version_is_higher_than("5.16"): label = target_cls.__label__ index_name = f"fulltext_index_{label}_{property_name}" @@ -844,12 +880,12 @@ async def _create_node_fulltext_index( async def _create_node_vector_index( self, - target_cls, + target_cls: Any, property_name: str, - stdout, + stdout: TextIO, vector_index: VectorIndex, quiet: bool, - ): + ) -> None: if await self.version_is_higher_than("5.15"): label = target_cls.__label__ index_name = f"vector_index_{label}_{property_name}" @@ -882,7 +918,7 @@ async def _create_node_vector_index( ) async def _create_node_constraint( - self, target_cls, property_name: str, stdout, quiet: bool + self, target_cls: Any, property_name: str, stdout: TextIO, quiet: bool ): label = target_cls.__label__ constraint_name = f"constraint_unique_{label}_{property_name}" @@ -907,12 +943,12 @@ async def _create_node_constraint( async def _create_relationship_index( self, relationship_type: str, - target_cls, - 
relationship_cls, + target_cls: Any, + relationship_cls: Any, property_name: str, - stdout, + stdout: TextIO, quiet: bool, - ): + ) -> None: index_name = f"index_{relationship_type}_{property_name}" if not quiet: stdout.write( @@ -934,13 +970,13 @@ async def _create_relationship_index( async def _create_relationship_fulltext_index( self, relationship_type: str, - target_cls, - relationship_cls, + target_cls: Any, + relationship_cls: Any, property_name: str, - stdout, + stdout: TextIO, fulltext_index: FulltextIndex, quiet: bool, - ): + ) -> None: if await self.version_is_higher_than("5.16"): index_name = f"fulltext_index_{relationship_type}_{property_name}" if not quiet: @@ -974,13 +1010,13 @@ async def _create_relationship_fulltext_index( async def _create_relationship_vector_index( self, relationship_type: str, - target_cls, - relationship_cls, + target_cls: Any, + relationship_cls: Any, property_name: str, - stdout, + stdout: TextIO, vector_index: VectorIndex, quiet: bool, - ): + ) -> None: if await self.version_is_higher_than("5.18"): index_name = f"vector_index_{relationship_type}_{property_name}" if not quiet: @@ -1014,12 +1050,12 @@ async def _create_relationship_vector_index( async def _create_relationship_constraint( self, relationship_type: str, - target_cls, - relationship_cls, + target_cls: Any, + relationship_cls: Any, property_name: str, - stdout, + stdout: TextIO, quiet: bool, - ): + ) -> None: if await self.version_is_higher_than("5.7"): constraint_name = f"constraint_unique_{relationship_type}_{property_name}" if not quiet: @@ -1044,7 +1080,9 @@ async def _create_relationship_constraint( f"Unique indexes on relationships are not supported in Neo4j version {await self.database_version}. Please upgrade to Neo4j 5.7 or higher." 
) - async def _install_node(self, cls, name, property, quiet, stdout): + async def _install_node( + self, cls: Any, name: str, property: Property, quiet: bool, stdout: TextIO + ) -> None: # Create indexes and constraints for node property db_property = property.get_db_property_name(name) if property.index: @@ -1074,7 +1112,9 @@ async def _install_node(self, cls, name, property, quiet, stdout): quiet=quiet, ) - async def _install_relationship(self, cls, relationship, quiet, stdout): + async def _install_relationship( + self, cls: Any, relationship: Any, quiet: bool, stdout: TextIO + ) -> None: # Create indexes and constraints for relationship property relationship_cls = relationship.definition["model"] if relationship_cls is not None: @@ -1130,7 +1170,9 @@ async def _install_relationship(self, cls, relationship, quiet, stdout): # Deprecated methods -async def change_neo4j_password(db: AsyncDatabase, user, new_password): +async def change_neo4j_password( + db: AsyncDatabase, user: str, new_password: str +) -> None: deprecated( """ This method has been moved to the Database singleton (db for sync, adb for async). @@ -1142,8 +1184,8 @@ async def change_neo4j_password(db: AsyncDatabase, user, new_password): async def clear_neo4j_database( - db: AsyncDatabase, clear_constraints=False, clear_indexes=False -): + db: AsyncDatabase, clear_constraints: bool = False, clear_indexes: bool = False +) -> None: deprecated( """ This method has been moved to the Database singleton (db for sync, adb for async). @@ -1154,7 +1196,7 @@ async def clear_neo4j_database( await db.clear_neo4j_database(clear_constraints, clear_indexes) -async def drop_constraints(quiet=True, stdout=None): +async def drop_constraints(quiet: bool = True, stdout: Optional[TextIO] = None): deprecated( """ This method has been moved to the Database singleton (db for sync, adb for async). 
@@ -1165,7 +1207,7 @@ async def drop_constraints(quiet=True, stdout=None): await adb.drop_constraints(quiet, stdout) -async def drop_indexes(quiet=True, stdout=None): +async def drop_indexes(quiet: bool = True, stdout: Optional[TextIO] = None) -> None: deprecated( """ This method has been moved to the Database singleton (db for sync, adb for async). @@ -1176,7 +1218,7 @@ async def drop_indexes(quiet=True, stdout=None): await adb.drop_indexes(quiet, stdout) -async def remove_all_labels(stdout=None): +async def remove_all_labels(stdout: Optional[TextIO] = None) -> None: deprecated( """ This method has been moved to the Database singleton (db for sync, adb for async). @@ -1187,7 +1229,9 @@ async def remove_all_labels(stdout=None): await adb.remove_all_labels(stdout) -async def install_labels(cls, quiet=True, stdout=None): +async def install_labels( + cls, quiet: bool = True, stdout: Optional[TextIO] = None +) -> None: deprecated( """ This method has been moved to the Database singleton (db for sync, adb for async). @@ -1198,7 +1242,7 @@ async def install_labels(cls, quiet=True, stdout=None): await adb.install_labels(cls, quiet, stdout) -async def install_all_labels(stdout=None): +async def install_all_labels(stdout: Optional[TextIO] = None) -> None: deprecated( """ This method has been moved to the Database singleton (db for sync, adb for async). @@ -1223,7 +1267,7 @@ def __init__( self.parallel_runtime: Optional[bool] = parallel_runtime @ensure_connection - async def __aenter__(self): + async def __aenter__(self) -> "AsyncTransactionProxy": if self.parallel_runtime and not await self.db.parallel_runtime_available(): warnings.warn( "Parallel runtime is only available in Neo4j Enterprise Edition 5.13 and above. 
" @@ -1236,7 +1280,7 @@ async def __aenter__(self): self.bookmarks = None return self - async def __aexit__(self, exc_type, exc_value, traceback): + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: self.db._parallel_runtime = False if exc_value: await self.db.rollback() @@ -1250,28 +1294,28 @@ async def __aexit__(self, exc_type, exc_value, traceback): if not exc_value: self.last_bookmark = await self.db.commit() - def __call__(self, func): + def __call__(self, func: Callable) -> Callable: if AsyncUtil.is_async_code and not iscoroutinefunction(func): raise TypeError(NOT_COROUTINE_ERROR) @wraps(func) - async def wrapper(*args, **kwargs): + async def wrapper(*args: Any, **kwargs: Any) -> Callable: async with self: return await func(*args, **kwargs) return wrapper @property - def with_bookmark(self): + def with_bookmark(self) -> "BookmarkingAsyncTransactionProxy": return BookmarkingAsyncTransactionProxy(self.db, self.access_mode) class BookmarkingAsyncTransactionProxy(AsyncTransactionProxy): - def __call__(self, func): + def __call__(self, func: Callable) -> Callable: if AsyncUtil.is_async_code and not iscoroutinefunction(func): raise TypeError(NOT_COROUTINE_ERROR) - async def wrapper(*args, **kwargs): + async def wrapper(*args: Any, **kwargs: Any) -> tuple[Any, None]: self.bookmarks = kwargs.pop("bookmarks", None) async with self: @@ -1288,19 +1332,21 @@ def __init__(self, db: AsyncDatabase, impersonated_user: str): self.db = db self.impersonated_user = impersonated_user - def __enter__(self): + def __enter__(self) -> "ImpersonationHandler": self.db.impersonated_user = self.impersonated_user return self - def __exit__(self, exception_type, exception_value, exception_traceback): + def __exit__( + self, exception_type: Any, exception_value: Any, exception_traceback: Any + ) -> None: self.db.impersonated_user = None print("\nException type:", exception_type) print("\nException value:", exception_value) print("\nTraceback:", 
exception_traceback) - def __call__(self, func): - def wrapper(*args, **kwargs): + def __call__(self, func: Callable) -> Callable: + def wrapper(*args: Any, **kwargs: Any) -> Callable: with self: return func(*args, **kwargs) @@ -1318,7 +1364,9 @@ class NodeMeta(type): defined_properties: Callable[..., dict[str, Any]] - def __new__(mcs, name, bases, namespace): + def __new__( + mcs: "NodeMeta", name: str, bases: tuple[type, ...], namespace: dict[str, Any] + ) -> Any: namespace["DoesNotExist"] = type(name + "DoesNotExist", (DoesNotExist,), {}) cls = super().__new__(mcs, name, bases, namespace) cls.DoesNotExist._model_class = cls @@ -1377,7 +1425,7 @@ def __new__(mcs, name, bases, namespace): return cls -def build_class_registry(cls): +def build_class_registry(cls) -> None: base_label_set = frozenset(cls.inherited_labels()) optional_label_set = set(cls.inherited_optional_labels()) @@ -1428,7 +1476,7 @@ class AsyncStructuredNode(NodeBase): # magic methods - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): if "deleted" in kwargs: raise ValueError("deleted property is reserved for neomodel") @@ -1450,19 +1498,19 @@ def __eq__(self, other: Any) -> bool: return self.element_id == other.element_id return id(self) == id(other) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self.__eq__(other) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__}: {self}>" - def __str__(self): + def __str__(self) -> str: return repr(self.__properties__) # dynamic properties @classproperty - def nodes(cls): + def nodes(self) -> Any: """ Returns a NodeSet object representing all nodes of the classes label :return: NodeSet @@ -1470,17 +1518,17 @@ def nodes(cls): """ from neomodel.async_.match import AsyncNodeSet - return AsyncNodeSet(cls) + return AsyncNodeSet(self) @property - def element_id(self): + def element_id(self) -> Optional[Any]: if hasattr(self, "element_id_property"): return 
self.element_id_property return None # Version 4.4 support - id is deprecated in version 5.x @property - def id(self): + def id(self) -> int: try: return int(self.element_id_property) except (TypeError, ValueError): @@ -1499,8 +1547,12 @@ def was_saved(self) -> bool: @classmethod async def _build_merge_query( - cls, merge_params, update_existing=False, lazy=False, relationship=None - ): + cls, + merge_params: tuple[dict[str, Any], ...], + update_existing: bool = False, + lazy: bool = False, + relationship: Optional[Any] = None, + ) -> tuple[str, dict[str, Any]]: """ Get a tuple of a CYPHER query and a params dict for the specified MERGE query. @@ -1510,7 +1562,7 @@ async def _build_merge_query( :type update_existing: bool :rtype: tuple """ - query_params = dict(merge_params=merge_params) + query_params: dict[str, Any] = {"merge_params": merge_params} n_merge_labels = ":".join(cls.inherited_labels()) n_merge_prm = ", ".join( ( @@ -1536,6 +1588,10 @@ async def _build_merge_query( from neomodel.async_.match import _rel_helper + if relationship.source.element_id is None: + raise RuntimeError( + "Could not identify the relationship source, its element id was None." + ) query_params["source_id"] = await adb.parse_element_id( relationship.source.element_id ) @@ -1564,7 +1620,7 @@ async def _build_merge_query( return query, query_params @classmethod - async def create(cls, *props, **kwargs): + async def create(cls, *props: tuple, **kwargs: dict[str, Any]) -> list: """ Call to CREATE with parameters map. A new instance will be created and saved. @@ -1608,7 +1664,7 @@ async def create(cls, *props, **kwargs): return nodes @classmethod - async def create_or_update(cls, *props, **kwargs): + async def create_or_update(cls, *props: tuple, **kwargs: dict[str, Any]) -> list: """ Call to MERGE with parameters map. A new instance will be created and saved if does not already exists, this is an atomic operation. 
If an instance already exists all optional properties specified will be updated. @@ -1621,7 +1677,7 @@ async def create_or_update(cls, *props, **kwargs): :param lazy: False by default, specify True to get nodes with id only without the parameters. :rtype: list """ - lazy = kwargs.get("lazy", False) + lazy: bool = kwargs.get("lazy", False) relationship = kwargs.get("relationship") # build merge query, make sure to update only explicitly specified properties @@ -1638,7 +1694,7 @@ async def create_or_update(cls, *props, **kwargs): } ) query, params = await cls._build_merge_query( - create_or_update_params, + tuple(create_or_update_params), update_existing=True, relationship=relationship, lazy=lazy, @@ -1655,7 +1711,9 @@ async def create_or_update(cls, *props, **kwargs): results = await adb.cypher_query(query, params) return [cls.inflate(r[0]) for r in results[0]] - async def cypher(self, query, params=None): + async def cypher( + self, query: str, params: Optional[dict[str, Any]] = None + ) -> tuple[Optional[list], Optional[tuple[str, ...]]]: """ Execute a cypher query with the param 'self' pre-populated with the nodes neo4j id. 
@@ -1663,14 +1721,14 @@ async def cypher(self, query, params=None): :type: string :param params: query parameters :type: dict - :return: list containing query results - :rtype: list + :return: tuple containing a list of query results, and the meta information as a tuple + :rtype: tuple """ self._pre_action_check("cypher") - params = params or {} + _params = params or {} element_id = await adb.parse_element_id(self.element_id) - params.update({"self": element_id}) - return await adb.cypher_query(query, params) + _params.update({"self": element_id}) + return await adb.cypher_query(query, _params) @hooks async def delete(self): diff --git a/neomodel/hooks.py b/neomodel/hooks.py index dffaa73a..2481ca34 100644 --- a/neomodel/hooks.py +++ b/neomodel/hooks.py @@ -1,14 +1,15 @@ from functools import wraps +from typing import Any, Callable -def _exec_hook(hook_name, self): +def _exec_hook(hook_name: str, self: Any) -> None: if hasattr(self, hook_name): getattr(self, hook_name)() -def hooks(fn): +def hooks(fn: Callable) -> Callable: @wraps(fn) - def hooked(self): + def hooked(self: Any) -> Callable: fn_name = getattr(fn, "func_name", fn.__name__) _exec_hook("pre_" + fn_name, self) val = fn(self) diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index e0bd4029..31987106 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -7,7 +7,7 @@ from functools import wraps from itertools import combinations from threading import local -from typing import Any, Callable, Optional, Sequence, Type, Union +from typing import Any, Callable, Optional, TextIO, Type, Union from urllib.parse import quote, unquote, urlparse from neo4j import ( @@ -54,7 +54,7 @@ # make sure the connection url has been set prior to executing the wrapped function -def ensure_connection(func): +def ensure_connection(func: Callable) -> Callable: """Decorator that ensures a connection is established before executing the decorated function. 
Args: @@ -65,7 +65,7 @@ def ensure_connection(func): """ - def wrapper(self, *args, **kwargs): + def wrapper(self, *args: Any, **kwargs: Any) -> Callable: # Sort out where to find url if hasattr(self, "db"): _db = self.db @@ -89,7 +89,7 @@ class Database(local): """ _NODE_CLASS_REGISTRY: dict[frozenset, Any] = {} - _DB_SPECIFIC_CLASS_REGISTRY: dict[str, dict] = {} + _DB_SPECIFIC_CLASS_REGISTRY: dict[str, dict[frozenset, Any]] = {} def __init__(self): self._active_transaction: Optional[Transaction] = None @@ -105,7 +105,7 @@ def __init__(self): def set_connection( self, url: Optional[str] = None, driver: Optional[Driver] = None - ): + ) -> None: """ Sets the connection up and relevant internal. This can be done using a Neo4j URL or a driver instance. @@ -200,7 +200,7 @@ def _parse_driver_from_url(self, url: str) -> None: else: self._database_name = database_name - def close_connection(self): + def close_connection(self) -> None: """ Closes the currently open driver. The driver should always be closed at the end of the application's lifecyle. 
@@ -213,36 +213,36 @@ def close_connection(self): self.driver = None @property - def database_version(self): + def database_version(self) -> Optional[str]: if self._database_version is None: self._update_database_version() return self._database_version @property - def database_edition(self): + def database_edition(self) -> Optional[str]: if self._database_edition is None: self._update_database_version() return self._database_edition @property - def transaction(self): + def transaction(self) -> "TransactionProxy": """ Returns the current transaction object """ return TransactionProxy(self) @property - def write_transaction(self): + def write_transaction(self) -> "TransactionProxy": return TransactionProxy(self, access_mode="WRITE") @property - def read_transaction(self): + def read_transaction(self) -> "TransactionProxy": return TransactionProxy(self, access_mode="READ") @property - def parallel_read_transaction(self): + def parallel_read_transaction(self) -> "TransactionProxy": return TransactionProxy(self, access_mode="READ", parallel_runtime=True) def impersonate(self, user: str) -> "ImpersonationHandler": @@ -262,7 +262,7 @@ def impersonate(self, user: str) -> "ImpersonationHandler": return ImpersonationHandler(self, impersonated_user=user) @ensure_connection - def begin(self, access_mode=None, **parameters): + def begin(self, access_mode: str = "WRITE", **parameters: Any) -> None: """ Begins a new transaction. Raises SystemError if a transaction is already active. 
""" @@ -285,7 +285,7 @@ def begin(self, access_mode=None, **parameters): self._active_transaction = self._session.begin_transaction() @ensure_connection - def commit(self): + def commit(self) -> Bookmarks: """ Commits the current transaction and closes its session @@ -313,7 +313,7 @@ def commit(self): return last_bookmarks @ensure_connection - def rollback(self): + def rollback(self) -> None: """ Rolls back the current transaction and closes its session """ @@ -332,7 +332,7 @@ def rollback(self): self._active_transaction = None self._session = None - def _update_database_version(self): + def _update_database_version(self) -> None: """ Updates the database server information when it is required """ @@ -346,7 +346,7 @@ def _update_database_version(self): # The database server is not running yet pass - def _object_resolution(self, object_to_resolve): + def _object_resolution(self, object_to_resolve: Any) -> Any: """ Performs in place automatic object resolution on a result returned by cypher_query. @@ -421,7 +421,7 @@ def _object_resolution(self, object_to_resolve): return object_to_resolve - def _result_resolution(self, result_list): + def _result_resolution(self, result_list: list) -> list: """ Performs in place automatic object resolution on a set of results returned by cypher_query. @@ -452,12 +452,12 @@ def _result_resolution(self, result_list): @ensure_connection def cypher_query( self, - query, - params=None, - handle_unique=True, - retry_on_session_expire=False, - resolve_objects=False, - ): + query: str, + params: Optional[dict[str, Any]] = None, + handle_unique: bool = True, + retry_on_session_expire: bool = False, + resolve_objects: bool = False, + ) -> tuple[Optional[list], Optional[tuple[str, ...]]]: """ Runs a query on the database and returns a list of results and their headers. @@ -475,6 +475,8 @@ def cypher_query( :return: A tuple containing a list of results and a tuple of headers. 
""" + if params is None: + params = {} if self._active_transaction: # Use current transaction if a transaction is currently active results, meta = self._run_cypher_query( @@ -508,18 +510,18 @@ def cypher_query( def _run_cypher_query( self, session: Union[Session, Transaction], - query, - params, - handle_unique, - retry_on_session_expire, - resolve_objects, - ): + query: str, + params: dict[str, Any], + handle_unique: bool, + retry_on_session_expire: bool, + resolve_objects: bool, + ) -> tuple[Optional[list], Optional[tuple[str, ...]]]: try: # Retrieve the data start = time.time() if self._parallel_runtime: query = "CYPHER runtime=parallel " + query - response: Result = session.run(query, params) + response: Result = session.run(query=query, parameters=params) results, meta = [list(r.values()) for r in response], response.keys() end = time.time() @@ -529,15 +531,14 @@ def _run_cypher_query( except ClientError as e: if e.code == "Neo.ClientError.Schema.ConstraintValidationFailed": - if ( - hasattr(e, "message") - and e.message is not None - and "already exists with label" in e.message - and handle_unique - ): - raise UniqueProperty(e.message) from e + if hasattr(e, "message") and e.message is not None: + if "already exists with label" in e.message and handle_unique: + raise UniqueProperty(e.message) from e + raise ConstraintValidationFailed(e.message) from e + raise ConstraintValidationFailed( + "A constraint validation failed" + ) from e - raise ConstraintValidationFailed(e.message) from e exc_info = sys.exc_info() if exc_info[1] is not None and exc_info[2] is not None: raise exc_info[1].with_traceback(exc_info[2]) @@ -568,16 +569,30 @@ def _run_cypher_query( def get_id_method(self) -> str: db_version = self.database_version + if db_version is None: + raise RuntimeError( + """ + Unable to perform this operation because the database server version is not known. + This might mean that the database server is offline. 
+ """ + ) if db_version.startswith("4"): return "id" else: return "elementId" - def parse_element_id(self, element_id: str): + def parse_element_id(self, element_id: str) -> Union[str, int]: db_version = self.database_version + if db_version is None: + raise RuntimeError( + """ + Unable to perform this operation because the database server version is not known. + This might mean that the database server is offline. + """ + ) return int(element_id) if db_version.startswith("4") else element_id - def list_indexes(self, exclude_token_lookup=False) -> Sequence[dict]: + def list_indexes(self, exclude_token_lookup: bool = False) -> list[dict]: """Returns all indexes existing in the database Arguments: @@ -596,7 +611,7 @@ def list_indexes(self, exclude_token_lookup=False) -> Sequence[dict]: return indexes_as_dict - def list_constraints(self) -> Sequence[dict]: + def list_constraints(self) -> list[dict]: """Returns all constraints existing in the database Returns: @@ -618,6 +633,13 @@ def version_is_higher_than(self, version_tag: str) -> bool: bool: True if the database version is higher or equal to the given version """ db_version = self.database_version + if db_version is None: + raise RuntimeError( + """ + Unable to perform this operation because the database server version is not known. + This might mean that the database server is offline. + """ + ) return version_tag_to_integer(db_version) >= version_tag_to_integer(version_tag) @ensure_connection @@ -628,6 +650,13 @@ def edition_is_enterprise(self) -> bool: bool: True if the database edition is enterprise """ edition = self.database_edition + if edition is None: + raise RuntimeError( + """ + Unable to perform this operation because the database server edition is not known. + This might mean that the database server is offline. 
+ """ + ) return edition == "enterprise" @ensure_connection @@ -639,10 +668,12 @@ def parallel_runtime_available(self) -> bool: """ return self.version_is_higher_than("5.13") and self.edition_is_enterprise() - def change_neo4j_password(self, user, new_password): + def change_neo4j_password(self, user: str, new_password: str) -> None: self.cypher_query(f"ALTER USER {user} SET PASSWORD '{new_password}'") - def clear_neo4j_database(self, clear_constraints=False, clear_indexes=False): + def clear_neo4j_database( + self, clear_constraints: bool = False, clear_indexes: bool = False + ) -> None: self.cypher_query( """ MATCH (a) @@ -655,7 +686,9 @@ def clear_neo4j_database(self, clear_constraints=False, clear_indexes=False): if clear_indexes: drop_indexes() - def drop_constraints(self, quiet=True, stdout=None): + def drop_constraints( + self, quiet: bool = True, stdout: Optional[TextIO] = None + ) -> None: """ Discover and drop all constraints. @@ -681,7 +714,7 @@ def drop_constraints(self, quiet=True, stdout=None): if not quiet: stdout.write("\n") - def drop_indexes(self, quiet=True, stdout=None): + def drop_indexes(self, quiet: bool = True, stdout: Optional[TextIO] = None) -> None: """ Discover and drop all indexes, except the automatically created token lookup indexes. @@ -701,7 +734,7 @@ def drop_indexes(self, quiet=True, stdout=None): if not quiet: stdout.write("\n") - def remove_all_labels(self, stdout=None): + def remove_all_labels(self, stdout: Optional[TextIO] = None) -> None: """ Calls functions for dropping constraints and indexes. @@ -718,7 +751,7 @@ def remove_all_labels(self, stdout=None): stdout.write("Dropping indexes...\n") self.drop_indexes(quiet=False, stdout=stdout) - def install_all_labels(self, stdout=None): + def install_all_labels(self, stdout: Optional[TextIO] = None) -> None: """ Discover all subclasses of StructuredNode in your application and execute install_labels on each. 
Note: code must be loaded (imported) in order for a class to be discovered. @@ -749,7 +782,9 @@ def subsub(cls): # recursively return all subclasses stdout.write(f"Finished {i} classes.\n") - def install_labels(self, cls, quiet=True, stdout=None): + def install_labels( + self, cls: Any, quiet: bool = True, stdout: Optional[TextIO] = None + ) -> None: """ Setup labels with indexes and constraints for a given class @@ -760,25 +795,26 @@ def install_labels(self, cls, quiet=True, stdout=None): :type: bool :return: None """ - if not stdout or stdout is None: - stdout = sys.stdout + _stdout = stdout if stdout else sys.stdout if not hasattr(cls, "__label__"): if not quiet: - stdout.write( + _stdout.write( f" ! Skipping class {cls.__module__}.{cls.__name__} is abstract\n" ) return for name, property in cls.defined_properties(aliases=False, rels=False).items(): - self._install_node(cls, name, property, quiet, stdout) + self._install_node(cls, name, property, quiet, _stdout) for _, relationship in cls.defined_properties( aliases=False, rels=True, properties=False ).items(): - self._install_relationship(cls, relationship, quiet, stdout) + self._install_relationship(cls, relationship, quiet, _stdout) - def _create_node_index(self, target_cls, property_name: str, stdout, quiet: bool): + def _create_node_index( + self, target_cls: Any, property_name: str, stdout: TextIO, quiet: bool + ) -> None: label = target_cls.__label__ index_name = f"index_{label}_{property_name}" if not quiet: @@ -800,12 +836,12 @@ def _create_node_index(self, target_cls, property_name: str, stdout, quiet: bool def _create_node_fulltext_index( self, - target_cls, + target_cls: Any, property_name: str, - stdout, + stdout: TextIO, fulltext_index: FulltextIndex, quiet: bool, - ): + ) -> None: if self.version_is_higher_than("5.16"): label = target_cls.__label__ index_name = f"fulltext_index_{label}_{property_name}" @@ -839,12 +875,12 @@ def _create_node_fulltext_index( def _create_node_vector_index( self, - 
target_cls, + target_cls: Any, property_name: str, - stdout, + stdout: TextIO, vector_index: VectorIndex, quiet: bool, - ): + ) -> None: if self.version_is_higher_than("5.15"): label = target_cls.__label__ index_name = f"vector_index_{label}_{property_name}" @@ -877,7 +913,7 @@ def _create_node_vector_index( ) def _create_node_constraint( - self, target_cls, property_name: str, stdout, quiet: bool + self, target_cls: Any, property_name: str, stdout: TextIO, quiet: bool ): label = target_cls.__label__ constraint_name = f"constraint_unique_{label}_{property_name}" @@ -902,12 +938,12 @@ def _create_node_constraint( def _create_relationship_index( self, relationship_type: str, - target_cls, - relationship_cls, + target_cls: Any, + relationship_cls: Any, property_name: str, - stdout, + stdout: TextIO, quiet: bool, - ): + ) -> None: index_name = f"index_{relationship_type}_{property_name}" if not quiet: stdout.write( @@ -929,13 +965,13 @@ def _create_relationship_index( def _create_relationship_fulltext_index( self, relationship_type: str, - target_cls, - relationship_cls, + target_cls: Any, + relationship_cls: Any, property_name: str, - stdout, + stdout: TextIO, fulltext_index: FulltextIndex, quiet: bool, - ): + ) -> None: if self.version_is_higher_than("5.16"): index_name = f"fulltext_index_{relationship_type}_{property_name}" if not quiet: @@ -969,13 +1005,13 @@ def _create_relationship_fulltext_index( def _create_relationship_vector_index( self, relationship_type: str, - target_cls, - relationship_cls, + target_cls: Any, + relationship_cls: Any, property_name: str, - stdout, + stdout: TextIO, vector_index: VectorIndex, quiet: bool, - ): + ) -> None: if self.version_is_higher_than("5.18"): index_name = f"vector_index_{relationship_type}_{property_name}" if not quiet: @@ -1009,12 +1045,12 @@ def _create_relationship_vector_index( def _create_relationship_constraint( self, relationship_type: str, - target_cls, - relationship_cls, + target_cls: Any, + relationship_cls: 
Any, property_name: str, - stdout, + stdout: TextIO, quiet: bool, - ): + ) -> None: if self.version_is_higher_than("5.7"): constraint_name = f"constraint_unique_{relationship_type}_{property_name}" if not quiet: @@ -1039,7 +1075,9 @@ def _create_relationship_constraint( f"Unique indexes on relationships are not supported in Neo4j version {self.database_version}. Please upgrade to Neo4j 5.7 or higher." ) - def _install_node(self, cls, name, property, quiet, stdout): + def _install_node( + self, cls: Any, name: str, property: Property, quiet: bool, stdout: TextIO + ) -> None: # Create indexes and constraints for node property db_property = property.get_db_property_name(name) if property.index: @@ -1069,7 +1107,9 @@ def _install_node(self, cls, name, property, quiet, stdout): quiet=quiet, ) - def _install_relationship(self, cls, relationship, quiet, stdout): + def _install_relationship( + self, cls: Any, relationship: Any, quiet: bool, stdout: TextIO + ) -> None: # Create indexes and constraints for relationship property relationship_cls = relationship.definition["model"] if relationship_cls is not None: @@ -1125,7 +1165,7 @@ def _install_relationship(self, cls, relationship, quiet, stdout): # Deprecated methods -def change_neo4j_password(db: Database, user, new_password): +def change_neo4j_password(db: Database, user: str, new_password: str) -> None: deprecated( """ This method has been moved to the Database singleton (db for sync, db for async). @@ -1136,7 +1176,9 @@ def change_neo4j_password(db: Database, user, new_password): db.change_neo4j_password(user, new_password) -def clear_neo4j_database(db: Database, clear_constraints=False, clear_indexes=False): +def clear_neo4j_database( + db: Database, clear_constraints: bool = False, clear_indexes: bool = False +) -> None: deprecated( """ This method has been moved to the Database singleton (db for sync, db for async). 
@@ -1147,7 +1189,7 @@ def clear_neo4j_database(db: Database, clear_constraints=False, clear_indexes=Fa db.clear_neo4j_database(clear_constraints, clear_indexes) -def drop_constraints(quiet=True, stdout=None): +def drop_constraints(quiet: bool = True, stdout: Optional[TextIO] = None): deprecated( """ This method has been moved to the Database singleton (db for sync, db for async). @@ -1158,7 +1200,7 @@ def drop_constraints(quiet=True, stdout=None): db.drop_constraints(quiet, stdout) -def drop_indexes(quiet=True, stdout=None): +def drop_indexes(quiet: bool = True, stdout: Optional[TextIO] = None) -> None: deprecated( """ This method has been moved to the Database singleton (db for sync, db for async). @@ -1169,7 +1211,7 @@ def drop_indexes(quiet=True, stdout=None): db.drop_indexes(quiet, stdout) -def remove_all_labels(stdout=None): +def remove_all_labels(stdout: Optional[TextIO] = None) -> None: deprecated( """ This method has been moved to the Database singleton (db for sync, db for async). @@ -1180,7 +1222,7 @@ def remove_all_labels(stdout=None): db.remove_all_labels(stdout) -def install_labels(cls, quiet=True, stdout=None): +def install_labels(cls, quiet: bool = True, stdout: Optional[TextIO] = None) -> None: deprecated( """ This method has been moved to the Database singleton (db for sync, db for async). @@ -1191,7 +1233,7 @@ def install_labels(cls, quiet=True, stdout=None): db.install_labels(cls, quiet, stdout) -def install_all_labels(stdout=None): +def install_all_labels(stdout: Optional[TextIO] = None) -> None: deprecated( """ This method has been moved to the Database singleton (db for sync, db for async). @@ -1216,7 +1258,7 @@ def __init__( self.parallel_runtime: Optional[bool] = parallel_runtime @ensure_connection - def __enter__(self): + def __enter__(self) -> "TransactionProxy": if self.parallel_runtime and not self.db.parallel_runtime_available(): warnings.warn( "Parallel runtime is only available in Neo4j Enterprise Edition 5.13 and above. 
" @@ -1229,7 +1271,7 @@ def __enter__(self): self.bookmarks = None return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: self.db._parallel_runtime = False if exc_value: self.db.rollback() @@ -1243,28 +1285,28 @@ def __exit__(self, exc_type, exc_value, traceback): if not exc_value: self.last_bookmark = self.db.commit() - def __call__(self, func): + def __call__(self, func: Callable) -> Callable: if Util.is_async_code and not iscoroutinefunction(func): raise TypeError(NOT_COROUTINE_ERROR) @wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Callable: with self: return func(*args, **kwargs) return wrapper @property - def with_bookmark(self): + def with_bookmark(self) -> "BookmarkingAsyncTransactionProxy": return BookmarkingAsyncTransactionProxy(self.db, self.access_mode) class BookmarkingAsyncTransactionProxy(TransactionProxy): - def __call__(self, func): + def __call__(self, func: Callable) -> Callable: if Util.is_async_code and not iscoroutinefunction(func): raise TypeError(NOT_COROUTINE_ERROR) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> tuple[Any, None]: self.bookmarks = kwargs.pop("bookmarks", None) with self: @@ -1281,19 +1323,21 @@ def __init__(self, db: Database, impersonated_user: str): self.db = db self.impersonated_user = impersonated_user - def __enter__(self): + def __enter__(self) -> "ImpersonationHandler": self.db.impersonated_user = self.impersonated_user return self - def __exit__(self, exception_type, exception_value, exception_traceback): + def __exit__( + self, exception_type: Any, exception_value: Any, exception_traceback: Any + ) -> None: self.db.impersonated_user = None print("\nException type:", exception_type) print("\nException value:", exception_value) print("\nTraceback:", exception_traceback) - def __call__(self, func): - def wrapper(*args, **kwargs): + def __call__(self, func: 
Callable) -> Callable: + def wrapper(*args: Any, **kwargs: Any) -> Callable: with self: return func(*args, **kwargs) @@ -1311,7 +1355,9 @@ class NodeMeta(type): defined_properties: Callable[..., dict[str, Any]] - def __new__(mcs, name, bases, namespace): + def __new__( + mcs: "NodeMeta", name: str, bases: tuple[type, ...], namespace: dict[str, Any] + ) -> Any: namespace["DoesNotExist"] = type(name + "DoesNotExist", (DoesNotExist,), {}) cls = super().__new__(mcs, name, bases, namespace) cls.DoesNotExist._model_class = cls @@ -1370,7 +1416,7 @@ def __new__(mcs, name, bases, namespace): return cls -def build_class_registry(cls): +def build_class_registry(cls) -> None: base_label_set = frozenset(cls.inherited_labels()) optional_label_set = set(cls.inherited_optional_labels()) @@ -1419,7 +1465,7 @@ class StructuredNode(NodeBase): # magic methods - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): if "deleted" in kwargs: raise ValueError("deleted property is reserved for neomodel") @@ -1441,19 +1487,19 @@ def __eq__(self, other: Any) -> bool: return self.element_id == other.element_id return id(self) == id(other) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self.__eq__(other) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__}: {self}>" - def __str__(self): + def __str__(self) -> str: return repr(self.__properties__) # dynamic properties @classproperty - def nodes(cls): + def nodes(self) -> Any: """ Returns a NodeSet object representing all nodes of the classes label :return: NodeSet @@ -1461,17 +1507,17 @@ def nodes(cls): """ from neomodel.sync_.match import NodeSet - return NodeSet(cls) + return NodeSet(self) @property - def element_id(self): + def element_id(self) -> Optional[Any]: if hasattr(self, "element_id_property"): return self.element_id_property return None # Version 4.4 support - id is deprecated in version 5.x @property - def id(self): + def id(self) -> 
int: try: return int(self.element_id_property) except (TypeError, ValueError): @@ -1490,8 +1536,12 @@ def was_saved(self) -> bool: @classmethod def _build_merge_query( - cls, merge_params, update_existing=False, lazy=False, relationship=None - ): + cls, + merge_params: tuple[dict[str, Any], ...], + update_existing: bool = False, + lazy: bool = False, + relationship: Optional[Any] = None, + ) -> tuple[str, dict[str, Any]]: """ Get a tuple of a CYPHER query and a params dict for the specified MERGE query. @@ -1501,7 +1551,7 @@ def _build_merge_query( :type update_existing: bool :rtype: tuple """ - query_params = dict(merge_params=merge_params) + query_params: dict[str, Any] = {"merge_params": merge_params} n_merge_labels = ":".join(cls.inherited_labels()) n_merge_prm = ", ".join( ( @@ -1527,6 +1577,10 @@ def _build_merge_query( from neomodel.sync_.match import _rel_helper + if relationship.source.element_id is None: + raise RuntimeError( + "Could not identify the relationship source, its element id was None." + ) query_params["source_id"] = db.parse_element_id( relationship.source.element_id ) @@ -1555,7 +1609,7 @@ def _build_merge_query( return query, query_params @classmethod - def create(cls, *props, **kwargs): + def create(cls, *props: tuple, **kwargs: dict[str, Any]) -> list: """ Call to CREATE with parameters map. A new instance will be created and saved. @@ -1599,7 +1653,7 @@ def create(cls, *props, **kwargs): return nodes @classmethod - def create_or_update(cls, *props, **kwargs): + def create_or_update(cls, *props: tuple, **kwargs: dict[str, Any]) -> list: """ Call to MERGE with parameters map. A new instance will be created and saved if does not already exists, this is an atomic operation. If an instance already exists all optional properties specified will be updated. @@ -1612,7 +1666,7 @@ def create_or_update(cls, *props, **kwargs): :param lazy: False by default, specify True to get nodes with id only without the parameters. 
:rtype: list """ - lazy = kwargs.get("lazy", False) + lazy: bool = kwargs.get("lazy", False) relationship = kwargs.get("relationship") # build merge query, make sure to update only explicitly specified properties @@ -1629,7 +1683,7 @@ def create_or_update(cls, *props, **kwargs): } ) query, params = cls._build_merge_query( - create_or_update_params, + tuple(create_or_update_params), update_existing=True, relationship=relationship, lazy=lazy, @@ -1646,7 +1700,9 @@ def create_or_update(cls, *props, **kwargs): results = db.cypher_query(query, params) return [cls.inflate(r[0]) for r in results[0]] - def cypher(self, query, params=None): + def cypher( + self, query: str, params: Optional[dict[str, Any]] = None + ) -> tuple[Optional[list], Optional[tuple[str, ...]]]: """ Execute a cypher query with the param 'self' pre-populated with the nodes neo4j id. @@ -1654,14 +1710,14 @@ def cypher(self, query, params=None): :type: string :param params: query parameters :type: dict - :return: list containing query results - :rtype: list + :return: tuple containing a list of query results, and the meta information as a tuple + :rtype: tuple """ self._pre_action_check("cypher") - params = params or {} + _params = params or {} element_id = db.parse_element_id(self.element_id) - params.update({"self": element_id}) - return db.cypher_query(query, params) + _params.update({"self": element_id}) + return db.cypher_query(query, _params) @hooks def delete(self): From b918a238755600a3bb85588357cedc8ddc2cf104 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Thu, 12 Dec 2024 11:55:18 +0100 Subject: [PATCH 14/20] More mypy type hints --- neomodel/async_/cardinality.py | 31 ++++++++++++------- neomodel/async_/core.py | 44 ++++++++++++++------------- neomodel/async_/property_manager.py | 18 ++++++----- neomodel/sync_/cardinality.py | 31 ++++++++++++------- neomodel/sync_/core.py | 46 ++++++++++++++++------------- neomodel/sync_/property_manager.py | 18 ++++++----- 6 files changed, 112 
insertions(+), 76 deletions(-) diff --git a/neomodel/async_/cardinality.py b/neomodel/async_/cardinality.py index 17101cec..3fa7d585 100644 --- a/neomodel/async_/cardinality.py +++ b/neomodel/async_/cardinality.py @@ -1,16 +1,21 @@ +from typing import TYPE_CHECKING, Any, Optional + from neomodel.async_.relationship_manager import ( # pylint:disable=unused-import AsyncRelationshipManager, AsyncZeroOrMore, ) from neomodel.exceptions import AttemptedCardinalityViolation, CardinalityViolation +if TYPE_CHECKING: + from neomodel import AsyncStructuredNode, AsyncStructuredRel + class AsyncZeroOrOne(AsyncRelationshipManager): """A relationship to zero or one node.""" description = "zero or one relationship" - async def single(self): + async def single(self) -> Optional["AsyncStructuredNode"]: """ Return the associated node. @@ -23,11 +28,13 @@ async def single(self): raise CardinalityViolation(self, len(nodes)) return None - async def all(self): + async def all(self) -> list["AsyncStructuredNode"]: node = await self.single() return [node] if node else [] - async def connect(self, node, properties=None): + async def connect( + self, node: "AsyncStructuredNode", properties: Optional[dict[str, Any]] = None + ) -> "AsyncStructuredRel": """ Connect to a node. @@ -49,7 +56,7 @@ class AsyncOneOrMore(AsyncRelationshipManager): description = "one or more relationships" - async def single(self): + async def single(self) -> "AsyncStructuredNode": """ Fetch one of the related nodes @@ -60,7 +67,7 @@ async def single(self): return nodes[0] raise CardinalityViolation(self, "none") - async def all(self): + async def all(self) -> list["AsyncStructuredNode"]: """ Returns all related nodes. 
@@ -71,7 +78,7 @@ async def all(self): return nodes raise CardinalityViolation(self, "none") - async def disconnect(self, node): + async def disconnect(self, node: "AsyncStructuredNode") -> None: """ Disconnect node :param node: @@ -89,7 +96,7 @@ class AsyncOne(AsyncRelationshipManager): description = "one relationship" - async def single(self): + async def single(self) -> "AsyncStructuredNode": """ Return the associated node. @@ -102,7 +109,7 @@ async def single(self): raise CardinalityViolation(self, len(nodes)) raise CardinalityViolation(self, "none") - async def all(self): + async def all(self) -> list["AsyncStructuredNode"]: """ Return single node in an array @@ -110,17 +117,19 @@ async def all(self): """ return [await self.single()] - async def disconnect(self, node): + async def disconnect(self, node: "AsyncStructuredNode") -> None: raise AttemptedCardinalityViolation( "Cardinality one, cannot disconnect use reconnect." ) - async def disconnect_all(self): + async def disconnect_all(self) -> None: raise AttemptedCardinalityViolation( "Cardinality one, cannot disconnect_all use reconnect." 
) - async def connect(self, node, properties=None): + async def connect( + self, node: "AsyncStructuredNode", properties: Optional[dict[str, Any]] = None + ) -> "AsyncStructuredRel": """ Connect a node diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py index 0ccabe11..f7e7ac66 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/core.py @@ -65,7 +65,7 @@ def ensure_connection(func: Callable) -> Callable: """ - async def wrapper(self, *args: Any, **kwargs: Any) -> Callable: + async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Callable: # Sort out where to find url if hasattr(self, "db"): _db = self.db @@ -91,7 +91,7 @@ class AsyncDatabase(local): _NODE_CLASS_REGISTRY: dict[frozenset, Any] = {} _DB_SPECIFIC_CLASS_REGISTRY: dict[str, dict[frozenset, Any]] = {} - def __init__(self): + def __init__(self) -> None: self._active_transaction: Optional[AsyncTransaction] = None self.url: Optional[str] = None self.driver: Optional[AsyncDriver] = None @@ -768,7 +768,7 @@ async def install_all_labels(self, stdout: Optional[TextIO] = None) -> None: if not stdout or stdout is None: stdout = sys.stdout - def subsub(cls): # recursively return all subclasses + def subsub(cls: Any) -> list: # recursively return all subclasses subclasses = cls.__subclasses__() if not subclasses: # base case: no more subclasses return [] @@ -919,7 +919,7 @@ async def _create_node_vector_index( async def _create_node_constraint( self, target_cls: Any, property_name: str, stdout: TextIO, quiet: bool - ): + ) -> None: label = target_cls.__label__ constraint_name = f"constraint_unique_{label}_{property_name}" if not quiet: @@ -1196,7 +1196,7 @@ async def clear_neo4j_database( await db.clear_neo4j_database(clear_constraints, clear_indexes) -async def drop_constraints(quiet: bool = True, stdout: Optional[TextIO] = None): +async def drop_constraints(quiet: bool = True, stdout: Optional[TextIO] = None) -> None: deprecated( """ This method has been moved to the Database singleton (db 
for sync, adb for async). @@ -1230,7 +1230,7 @@ async def remove_all_labels(stdout: Optional[TextIO] = None) -> None: async def install_labels( - cls, quiet: bool = True, stdout: Optional[TextIO] = None + cls: Any, quiet: bool = True, stdout: Optional[TextIO] = None ) -> None: deprecated( """ @@ -1365,10 +1365,10 @@ class NodeMeta(type): defined_properties: Callable[..., dict[str, Any]] def __new__( - mcs: "NodeMeta", name: str, bases: tuple[type, ...], namespace: dict[str, Any] + mcs: type, name: str, bases: tuple[type, ...], namespace: dict[str, Any] ) -> Any: namespace["DoesNotExist"] = type(name + "DoesNotExist", (DoesNotExist,), {}) - cls = super().__new__(mcs, name, bases, namespace) + cls: NodeMeta = type.__new__(mcs, name, bases, namespace) cls.DoesNotExist._model_class = cls if hasattr(cls, "__abstract_node__"): @@ -1425,7 +1425,7 @@ def __new__( return cls -def build_class_registry(cls) -> None: +def build_class_registry(cls: Any) -> None: base_label_set = frozenset(cls.inherited_labels()) optional_label_set = set(cls.inherited_optional_labels()) @@ -1677,7 +1677,7 @@ async def create_or_update(cls, *props: tuple, **kwargs: dict[str, Any]) -> list :param lazy: False by default, specify True to get nodes with id only without the parameters. 
:rtype: list """ - lazy: bool = kwargs.get("lazy", False) + lazy: bool = bool(kwargs.get("lazy", False)) relationship = kwargs.get("relationship") # build merge query, make sure to update only explicitly specified properties @@ -1726,12 +1726,14 @@ async def cypher( """ self._pre_action_check("cypher") _params = params or {} + if self.element_id is None: + raise ValueError("Can't run cypher operation on unsaved node") element_id = await adb.parse_element_id(self.element_id) _params.update({"self": element_id}) return await adb.cypher_query(query, _params) @hooks - async def delete(self): + async def delete(self) -> bool: """ Delete a node and its relationships @@ -1746,7 +1748,7 @@ async def delete(self): return True @classmethod - async def get_or_create(cls, *props, **kwargs): + async def get_or_create(cls: Any, *props: tuple, **kwargs: dict[str, Any]) -> list: """ Call to MERGE with parameters map. A new instance will be created and saved if does not already exist, this is an atomic operation. @@ -1769,7 +1771,7 @@ async def get_or_create(cls, *props, **kwargs): {"create": cls.deflate(p, skip_empty=True)} for p in props ] query, params = await cls._build_merge_query( - get_or_create_params, relationship=relationship, lazy=lazy + tuple(get_or_create_params), relationship=relationship, lazy=lazy ) if "streaming" in kwargs: @@ -1784,7 +1786,7 @@ async def get_or_create(cls, *props, **kwargs): return [cls.inflate(r[0]) for r in results[0]] @classmethod - def inflate(cls, node): + def inflate(cls: Any, node: Any) -> Any: """ Inflate a raw neo4j_driver node to a neomodel node :param node: @@ -1801,7 +1803,7 @@ def inflate(cls, node): return snode @classmethod - def inherited_labels(cls): + def inherited_labels(cls: Any) -> list[str]: """ Return list of labels from nodes class hierarchy. 
@@ -1814,7 +1816,7 @@ def inherited_labels(cls): ] @classmethod - def inherited_optional_labels(cls): + def inherited_optional_labels(cls: Any) -> list[str]: """ Return list of optional labels from nodes class hierarchy. @@ -1828,7 +1830,7 @@ def inherited_optional_labels(cls): if not hasattr(scls, "__abstract_node__") ] - async def labels(self): + async def labels(self) -> list[str]: """ Returns list of labels tied to the node from neo4j. @@ -1839,9 +1841,11 @@ async def labels(self): result = await self.cypher( f"MATCH (n) WHERE {await adb.get_id_method()}(n)=$self " "RETURN labels(n)" ) + if result is None or result[0] is None: + raise ValueError("Could not get labels, node may not exist") return result[0][0][0] - def _pre_action_check(self, action): + def _pre_action_check(self, action: str) -> None: if hasattr(self, "deleted") and self.deleted: raise ValueError( f"{self.__class__.__name__}.{action}() attempted on deleted node" @@ -1851,7 +1855,7 @@ def _pre_action_check(self, action): f"{self.__class__.__name__}.{action}() attempted on unsaved node" ) - async def refresh(self): + async def refresh(self) -> None: """ Reload the node from neo4j """ @@ -1870,7 +1874,7 @@ async def refresh(self): raise ValueError("Can't refresh unsaved node") @hooks - async def save(self): + async def save(self) -> "AsyncStructuredNode": """ Save the node to neo4j or raise an exception diff --git a/neomodel/async_/property_manager.py b/neomodel/async_/property_manager.py index c17bd864..deeea5ed 100644 --- a/neomodel/async_/property_manager.py +++ b/neomodel/async_/property_manager.py @@ -1,12 +1,14 @@ import types from typing import Any +from neo4j.graph import Entity + from neomodel.exceptions import RequiredProperty from neomodel.properties import AliasProperty, Property -def display_for(key): - def display_choice(self): +def display_for(key: str) -> Any: + def display_choice(self: Any) -> Any: return getattr(self.__class__, key).choices[getattr(self, key)] return 
display_choice @@ -17,7 +19,7 @@ class AsyncPropertyManager: Common methods for handling properties on node and relationship objects. """ - def __init__(self, **kwargs): + def __init__(self, **kwargs: dict[str, Any]) -> None: properties = getattr(self, "__all_properties__", None) if properties is None: properties = self.defined_properties(rels=False, aliases=False).items() @@ -55,7 +57,7 @@ def __init__(self, **kwargs): setattr(self, name, property) @property - def __properties__(self): + def __properties__(self) -> dict[str, Any]: from neomodel.async_.relationship_manager import AsyncRelationshipManager return dict( @@ -73,7 +75,9 @@ def __properties__(self): ) @classmethod - def deflate(cls, properties, obj=None, skip_empty=False): + def deflate( + cls, properties: Any, obj: Any = None, skip_empty: bool = False + ) -> dict[str, Any]: """ Deflate the properties of a PropertyManager subclass (a user-defined StructuredNode or StructuredRel) so that it can be put into a neo4j.graph.Entity (a neo4j.graph.Node or neo4j.graph.Relationship) for storage. properties @@ -97,7 +101,7 @@ def deflate(cls, properties, obj=None, skip_empty=False): return deflated @classmethod - def inflate(cls, graph_entity): + def inflate(cls: Any, graph_entity: Entity) -> Any: """ Inflate the properties of a neo4j.graph.Entity (a neo4j.graph.Node or neo4j.graph.Relationship) into an instance of cls. 
@@ -119,7 +123,7 @@ def inflate(cls, graph_entity): @classmethod def defined_properties( - cls, aliases=True, properties=True, rels=True + cls: Any, aliases: bool = True, properties: bool = True, rels: bool = True ) -> dict[str, Any]: from neomodel.async_.relationship_manager import AsyncRelationshipDefinition diff --git a/neomodel/sync_/cardinality.py b/neomodel/sync_/cardinality.py index 716d173f..051968d2 100644 --- a/neomodel/sync_/cardinality.py +++ b/neomodel/sync_/cardinality.py @@ -1,16 +1,21 @@ +from typing import TYPE_CHECKING, Any, Optional + from neomodel.exceptions import AttemptedCardinalityViolation, CardinalityViolation from neomodel.sync_.relationship_manager import ( # pylint:disable=unused-import RelationshipManager, ZeroOrMore, ) +if TYPE_CHECKING: + from neomodel import StructuredNode, StructuredRel + class ZeroOrOne(RelationshipManager): """A relationship to zero or one node.""" description = "zero or one relationship" - def single(self): + def single(self) -> Optional["StructuredNode"]: """ Return the associated node. @@ -23,11 +28,13 @@ def single(self): raise CardinalityViolation(self, len(nodes)) return None - def all(self): + def all(self) -> list["StructuredNode"]: node = self.single() return [node] if node else [] - def connect(self, node, properties=None): + def connect( + self, node: "StructuredNode", properties: Optional[dict[str, Any]] = None + ) -> "StructuredRel": """ Connect to a node. @@ -49,7 +56,7 @@ class OneOrMore(RelationshipManager): description = "one or more relationships" - def single(self): + def single(self) -> "StructuredNode": """ Fetch one of the related nodes @@ -60,7 +67,7 @@ def single(self): return nodes[0] raise CardinalityViolation(self, "none") - def all(self): + def all(self) -> list["StructuredNode"]: """ Returns all related nodes. 
@@ -71,7 +78,7 @@ def all(self): return nodes raise CardinalityViolation(self, "none") - def disconnect(self, node): + def disconnect(self, node: "StructuredNode") -> None: """ Disconnect node :param node: @@ -89,7 +96,7 @@ class One(RelationshipManager): description = "one relationship" - def single(self): + def single(self) -> "StructuredNode": """ Return the associated node. @@ -102,7 +109,7 @@ def single(self): raise CardinalityViolation(self, len(nodes)) raise CardinalityViolation(self, "none") - def all(self): + def all(self) -> list["StructuredNode"]: """ Return single node in an array @@ -110,17 +117,19 @@ def all(self): """ return [self.single()] - def disconnect(self, node): + def disconnect(self, node: "StructuredNode") -> None: raise AttemptedCardinalityViolation( "Cardinality one, cannot disconnect use reconnect." ) - def disconnect_all(self): + def disconnect_all(self) -> None: raise AttemptedCardinalityViolation( "Cardinality one, cannot disconnect_all use reconnect." ) - def connect(self, node, properties=None): + def connect( + self, node: "StructuredNode", properties: Optional[dict[str, Any]] = None + ) -> "StructuredRel": """ Connect a node diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index 31987106..c6a7779d 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -65,7 +65,7 @@ def ensure_connection(func: Callable) -> Callable: """ - def wrapper(self, *args: Any, **kwargs: Any) -> Callable: + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Callable: # Sort out where to find url if hasattr(self, "db"): _db = self.db @@ -91,7 +91,7 @@ class Database(local): _NODE_CLASS_REGISTRY: dict[frozenset, Any] = {} _DB_SPECIFIC_CLASS_REGISTRY: dict[str, dict[frozenset, Any]] = {} - def __init__(self): + def __init__(self) -> None: self._active_transaction: Optional[Transaction] = None self.url: Optional[str] = None self.driver: Optional[Driver] = None @@ -763,7 +763,7 @@ def install_all_labels(self, stdout: Optional[TextIO] = 
None) -> None: if not stdout or stdout is None: stdout = sys.stdout - def subsub(cls): # recursively return all subclasses + def subsub(cls: Any) -> list: # recursively return all subclasses subclasses = cls.__subclasses__() if not subclasses: # base case: no more subclasses return [] @@ -914,7 +914,7 @@ def _create_node_vector_index( def _create_node_constraint( self, target_cls: Any, property_name: str, stdout: TextIO, quiet: bool - ): + ) -> None: label = target_cls.__label__ constraint_name = f"constraint_unique_{label}_{property_name}" if not quiet: @@ -1189,7 +1189,7 @@ def clear_neo4j_database( db.clear_neo4j_database(clear_constraints, clear_indexes) -def drop_constraints(quiet: bool = True, stdout: Optional[TextIO] = None): +def drop_constraints(quiet: bool = True, stdout: Optional[TextIO] = None) -> None: deprecated( """ This method has been moved to the Database singleton (db for sync, db for async). @@ -1222,7 +1222,9 @@ def remove_all_labels(stdout: Optional[TextIO] = None) -> None: db.remove_all_labels(stdout) -def install_labels(cls, quiet: bool = True, stdout: Optional[TextIO] = None) -> None: +def install_labels( + cls: Any, quiet: bool = True, stdout: Optional[TextIO] = None +) -> None: deprecated( """ This method has been moved to the Database singleton (db for sync, db for async). 
@@ -1356,10 +1358,10 @@ class NodeMeta(type): defined_properties: Callable[..., dict[str, Any]] def __new__( - mcs: "NodeMeta", name: str, bases: tuple[type, ...], namespace: dict[str, Any] + mcs: type, name: str, bases: tuple[type, ...], namespace: dict[str, Any] ) -> Any: namespace["DoesNotExist"] = type(name + "DoesNotExist", (DoesNotExist,), {}) - cls = super().__new__(mcs, name, bases, namespace) + cls: NodeMeta = type.__new__(mcs, name, bases, namespace) cls.DoesNotExist._model_class = cls if hasattr(cls, "__abstract_node__"): @@ -1416,7 +1418,7 @@ def __new__( return cls -def build_class_registry(cls) -> None: +def build_class_registry(cls: Any) -> None: base_label_set = frozenset(cls.inherited_labels()) optional_label_set = set(cls.inherited_optional_labels()) @@ -1666,7 +1668,7 @@ def create_or_update(cls, *props: tuple, **kwargs: dict[str, Any]) -> list: :param lazy: False by default, specify True to get nodes with id only without the parameters. :rtype: list """ - lazy: bool = kwargs.get("lazy", False) + lazy: bool = bool(kwargs.get("lazy", False)) relationship = kwargs.get("relationship") # build merge query, make sure to update only explicitly specified properties @@ -1715,12 +1717,14 @@ def cypher( """ self._pre_action_check("cypher") _params = params or {} + if self.element_id is None: + raise ValueError("Can't run cypher operation on unsaved node") element_id = db.parse_element_id(self.element_id) _params.update({"self": element_id}) return db.cypher_query(query, _params) @hooks - def delete(self): + def delete(self) -> bool: """ Delete a node and its relationships @@ -1735,7 +1739,7 @@ def delete(self): return True @classmethod - def get_or_create(cls, *props, **kwargs): + def get_or_create(cls: Any, *props: tuple, **kwargs: dict[str, Any]) -> list: """ Call to MERGE with parameters map. A new instance will be created and saved if does not already exist, this is an atomic operation. 
@@ -1758,7 +1762,7 @@ def get_or_create(cls, *props, **kwargs): {"create": cls.deflate(p, skip_empty=True)} for p in props ] query, params = cls._build_merge_query( - get_or_create_params, relationship=relationship, lazy=lazy + tuple(get_or_create_params), relationship=relationship, lazy=lazy ) if "streaming" in kwargs: @@ -1773,7 +1777,7 @@ def get_or_create(cls, *props, **kwargs): return [cls.inflate(r[0]) for r in results[0]] @classmethod - def inflate(cls, node): + def inflate(cls: Any, node: Any) -> Any: """ Inflate a raw neo4j_driver node to a neomodel node :param node: @@ -1790,7 +1794,7 @@ def inflate(cls, node): return snode @classmethod - def inherited_labels(cls): + def inherited_labels(cls: Any) -> list[str]: """ Return list of labels from nodes class hierarchy. @@ -1803,7 +1807,7 @@ def inherited_labels(cls): ] @classmethod - def inherited_optional_labels(cls): + def inherited_optional_labels(cls: Any) -> list[str]: """ Return list of optional labels from nodes class hierarchy. @@ -1817,7 +1821,7 @@ def inherited_optional_labels(cls): if not hasattr(scls, "__abstract_node__") ] - def labels(self): + def labels(self) -> list[str]: """ Returns list of labels tied to the node from neo4j. 
@@ -1828,9 +1832,11 @@ def labels(self): result = self.cypher( f"MATCH (n) WHERE {db.get_id_method()}(n)=$self " "RETURN labels(n)" ) + if result is None or result[0] is None: + raise ValueError("Could not get labels, node may not exist") return result[0][0][0] - def _pre_action_check(self, action): + def _pre_action_check(self, action: str) -> None: if hasattr(self, "deleted") and self.deleted: raise ValueError( f"{self.__class__.__name__}.{action}() attempted on deleted node" @@ -1840,7 +1846,7 @@ def _pre_action_check(self, action): f"{self.__class__.__name__}.{action}() attempted on unsaved node" ) - def refresh(self): + def refresh(self) -> None: """ Reload the node from neo4j """ @@ -1859,7 +1865,7 @@ def refresh(self): raise ValueError("Can't refresh unsaved node") @hooks - def save(self): + def save(self) -> "StructuredNode": """ Save the node to neo4j or raise an exception diff --git a/neomodel/sync_/property_manager.py b/neomodel/sync_/property_manager.py index 08f1e900..160475cb 100644 --- a/neomodel/sync_/property_manager.py +++ b/neomodel/sync_/property_manager.py @@ -1,12 +1,14 @@ import types from typing import Any +from neo4j.graph import Entity + from neomodel.exceptions import RequiredProperty from neomodel.properties import AliasProperty, Property -def display_for(key): - def display_choice(self): +def display_for(key: str) -> Any: + def display_choice(self: Any) -> Any: return getattr(self.__class__, key).choices[getattr(self, key)] return display_choice @@ -17,7 +19,7 @@ class PropertyManager: Common methods for handling properties on node and relationship objects. 
""" - def __init__(self, **kwargs): + def __init__(self, **kwargs: dict[str, Any]) -> None: properties = getattr(self, "__all_properties__", None) if properties is None: properties = self.defined_properties(rels=False, aliases=False).items() @@ -55,7 +57,7 @@ def __init__(self, **kwargs): setattr(self, name, property) @property - def __properties__(self): + def __properties__(self) -> dict[str, Any]: from neomodel.sync_.relationship_manager import RelationshipManager return dict( @@ -73,7 +75,9 @@ def __properties__(self): ) @classmethod - def deflate(cls, properties, obj=None, skip_empty=False): + def deflate( + cls, properties: Any, obj: Any = None, skip_empty: bool = False + ) -> dict[str, Any]: """ Deflate the properties of a PropertyManager subclass (a user-defined StructuredNode or StructuredRel) so that it can be put into a neo4j.graph.Entity (a neo4j.graph.Node or neo4j.graph.Relationship) for storage. properties @@ -97,7 +101,7 @@ def deflate(cls, properties, obj=None, skip_empty=False): return deflated @classmethod - def inflate(cls, graph_entity): + def inflate(cls: Any, graph_entity: Entity) -> Any: """ Inflate the properties of a neo4j.graph.Entity (a neo4j.graph.Node or neo4j.graph.Relationship) into an instance of cls. 
@@ -119,7 +123,7 @@ def inflate(cls, graph_entity): @classmethod def defined_properties( - cls, aliases=True, properties=True, rels=True + cls: Any, aliases: bool = True, properties: bool = True, rels: bool = True ) -> dict[str, Any]: from neomodel.sync_.relationship_manager import RelationshipDefinition From eff6e033c27a3c89fbd42d129bddefa88b9764e5 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Fri, 13 Dec 2024 14:48:19 +0100 Subject: [PATCH 15/20] Type hints in match file --- neomodel/async_/core.py | 6 +- neomodel/async_/match.py | 187 +++++++++++++++--------- neomodel/async_/relationship_manager.py | 10 +- neomodel/properties.py | 4 +- neomodel/sync_/core.py | 6 +- neomodel/sync_/match.py | 185 ++++++++++++++--------- neomodel/sync_/relationship_manager.py | 10 +- 7 files changed, 254 insertions(+), 154 deletions(-) diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py index f7e7ac66..a8ad8c62 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/core.py @@ -581,7 +581,11 @@ async def get_id_method(self) -> str: else: return "elementId" - async def parse_element_id(self, element_id: str) -> Union[str, int]: + async def parse_element_id(self, element_id: Optional[str]) -> Union[str, int]: + if element_id is None: + raise ValueError( + "Unable to parse element id, are you sure this element has been saved ?" 
+ ) db_version = await self.database_version if db_version is None: raise RuntimeError( diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 23ce40d0..570f63e6 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -2,7 +2,7 @@ import re import string from dataclasses import dataclass -from typing import Any +from typing import Any, AsyncIterator from typing import Optional as TOptional from typing import Tuple, Union @@ -19,14 +19,14 @@ def _rel_helper( - lhs, - rhs, - ident=None, - relation_type=None, - direction=None, - relation_properties=None, - **kwargs, # NOSONAR -): + lhs: str, + rhs: str, + ident: TOptional[str] = None, + relation_type: TOptional[str] = None, + direction: TOptional[int] = None, + relation_properties: TOptional[dict] = None, + **kwargs: dict[str, Any], # NOSONAR +) -> str: """ Generate a relationship matching string, with specified parameters. Examples: @@ -83,14 +83,14 @@ def _rel_helper( def _rel_merge_helper( - lhs, - rhs, - ident="neomodelident", - relation_type=None, - direction=None, - relation_properties=None, - **kwargs, # NOSONAR -): + lhs: str, + rhs: str, + ident: str = "neomodelident", + relation_type: TOptional[str] = None, + direction: TOptional[int] = None, + relation_properties: TOptional[dict] = None, + **kwargs: dict[str, Any], # NOSONAR +) -> str: """ Generate a relationship merging string, with specified parameters. 
Examples: @@ -204,7 +204,9 @@ def _rel_merge_helper( path_split_regex = re.compile(r"__(?!_)|\|") -def install_traversals(cls, node_set): +def install_traversals( + cls: type[AsyncStructuredNode], node_set: "AsyncNodeSet" +) -> None: """ For a StructuredNode class install Traversal objects for each relationship definition on a NodeSet instance @@ -255,7 +257,12 @@ def _handle_special_operators( def _deflate_value( - cls, property_obj: Property, key: str, value: str, operator: str, prop: str + cls: type[AsyncStructuredNode], + property_obj: Property, + key: str, + value: str, + operator: str, + prop: str, ) -> Tuple[str, str, str]: if isinstance(property_obj, AliasProperty): prop = property_obj.aliased_to() @@ -269,7 +276,9 @@ def _deflate_value( return deflated_value, operator, prop -def _initialize_filter_args_variables(cls, key: str): +def _initialize_filter_args_variables( + cls: type[AsyncStructuredNode], key: str +) -> Tuple[type[AsyncStructuredNode], None, None, str, bool, str]: current_class = cls current_rel_model = None leaf_prop = None @@ -277,10 +286,19 @@ def _initialize_filter_args_variables(cls, key: str): is_rel_property = "|" in key prop = key - return current_class, current_rel_model, leaf_prop, operator, is_rel_property, prop + return ( + current_class, + current_rel_model, + leaf_prop, + operator, + is_rel_property, + prop, + ) -def _process_filter_key(cls, key: str) -> Tuple[Property, str, str]: +def _process_filter_key( + cls: type[AsyncStructuredNode], key: str +) -> Tuple[Property, str, str]: ( current_class, current_rel_model, @@ -313,6 +331,8 @@ def _process_filter_key(cls, key: str) -> Tuple[Property, str, str]: ) leaf_prop = part + if leaf_prop is None: + raise ValueError(f"Badly formed filter, no property found in {key}") if is_rel_property and current_rel_model: property_obj = getattr(current_rel_model, leaf_prop) else: @@ -321,7 +341,7 @@ def _process_filter_key(cls, key: str) -> Tuple[Property, str, str]: return property_obj, 
operator, prop -def process_filter_args(cls, kwargs) -> dict: +def process_filter_args(cls: type[AsyncStructuredNode], kwargs: dict[str, Any]) -> dict: """ loop through properties in filter parameters check they match class definition deflate them and convert into something easy to generate cypher from @@ -340,7 +360,9 @@ def process_filter_args(cls, kwargs) -> dict: return output -def process_has_args(cls, kwargs): +def process_has_args( + cls: type[AsyncStructuredNode], kwargs: dict[str, Any] +) -> tuple[dict, dict]: """ loop through has parameters check they correspond to class rels defined """ @@ -415,7 +437,9 @@ def __init__( class AsyncQueryBuilder: - def __init__(self, node_set, subquery_namespace: TOptional[str] = None) -> None: + def __init__( + self, node_set: "AsyncBaseSet", subquery_namespace: TOptional[str] = None + ) -> None: self.node_set = node_set self._ast = QueryAST() self._query_params: dict = {} @@ -424,7 +448,9 @@ def __init__(self, node_set, subquery_namespace: TOptional[str] = None) -> None: self._subquery_namespace: TOptional[str] = subquery_namespace async def build_ast(self) -> "AsyncQueryBuilder": - if hasattr(self.node_set, "relations_to_fetch"): + if isinstance(self.node_set, AsyncNodeSet) and hasattr( + self.node_set, "relations_to_fetch" + ): for relation in self.node_set.relations_to_fetch: self.build_traversal_from_path(relation, self.node_set.source) @@ -437,7 +463,9 @@ async def build_ast(self) -> "AsyncQueryBuilder": return self - async def build_source(self, source) -> str: + async def build_source( + self, source: Union["AsyncTraversal", "AsyncNodeSet", AsyncStructuredNode, Any] + ) -> str: if isinstance(source, AsyncTraversal): return await self.build_traversal(source) if isinstance(source, AsyncNodeSet): @@ -455,10 +483,10 @@ async def build_source(self, source) -> str: if source.filters or source.q_filters: self.build_where_stmt( - ident, - source.filters, - source.q_filters, + ident=ident, + filters=source.filters, 
source_class=source.source_class, + q_filters=source.q_filters, ) return ident @@ -499,7 +527,7 @@ def build_order_by(self, ident: str, source: "AsyncNodeSet") -> None: order_by.append(f"{result[0]}.{prop}") self._ast.order_by = order_by - async def build_traversal(self, traversal) -> str: + async def build_traversal(self, traversal: "AsyncTraversal") -> str: """ traverse a relationship from a node to a set of nodes """ @@ -523,11 +551,11 @@ async def build_traversal(self, traversal) -> str: self._ast.match.append(stmt) if traversal.filters: - self.build_where_stmt(rel_ident, traversal.filters) + self.build_where_stmt(rel_ident, traversal.filters, traversal.source_class) return traversal_ident - def _additional_return(self, name: str): + def _additional_return(self, name: str) -> None: if ( not self._ast.additional_return or name not in self._ast.additional_return ) and name != self._ast.return_clause: @@ -536,7 +564,7 @@ def _additional_return(self, name: str): self._ast.additional_return.append(name) def build_traversal_from_path( - self, relation: dict, source_class + self, relation: dict, source_class: Any ) -> Tuple[str, Any]: path: str = relation["path"] stmt: str = "" @@ -622,7 +650,7 @@ def build_traversal_from_path( return existing_rhs_name, relationship.definition["node_class"] - async def build_node(self, node): + async def build_node(self, node: AsyncStructuredNode) -> str: ident = node.__class__.__name__.lower() place_holder = self._register_place_holder(ident) @@ -636,7 +664,7 @@ async def build_node(self, node): self._ast.result_class = node.__class__ return ident - def build_label(self, ident, cls) -> str: + def build_label(self, ident: str, cls: type[AsyncStructuredNode]) -> str: """ match nodes by a label """ @@ -650,7 +678,7 @@ def build_label(self, ident, cls) -> str: self._ast.result_class = cls return ident - def build_additional_match(self, ident, node_set): + def build_additional_match(self, ident: str, node_set: "AsyncNodeSet") -> None: """ 
handle additional matches supplied by 'has()' calls """ @@ -682,7 +710,9 @@ def _register_place_holder(self, key: str) -> str: place_holder = f"{self._subquery_namespace}_{place_holder}" return place_holder - def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str, Any]: + def _parse_path( + self, source_class: type[AsyncStructuredNode], prop: str + ) -> Tuple[str, str, str, Any]: is_rel_filter = "|" in prop if is_rel_filter: path, prop = prop.rsplit("|", 1) @@ -723,7 +753,11 @@ def _finalize_filter_statement( return statement def _build_filter_statements( - self, ident: str, filters, target: list[str], source_class + self, + ident: str, + filters: dict[str, tuple], + target: list[str], + source_class: type[AsyncStructuredNode], ) -> None: for prop, op_and_val in filters.items(): path = None @@ -739,7 +773,9 @@ def _build_filter_statements( statement = self._finalize_filter_statement(operator, ident, prop, val) target.append(statement) - def _parse_q_filters(self, ident, q, source_class) -> str: + def _parse_q_filters( + self, ident: str, q: Union[QBase, Any], source_class: type[AsyncStructuredNode] + ) -> str: target = [] for child in q.children: if isinstance(child, QBase): @@ -757,7 +793,11 @@ def _parse_q_filters(self, ident, q, source_class) -> str: return ret def build_where_stmt( - self, ident: str, filters, q_filters=None, source_class=None + self, + ident: str, + filters: list, + source_class: type[AsyncStructuredNode], + q_filters: Union[QBase, Any, None] = None, ) -> None: """ construct a where statement from some filters @@ -891,7 +931,7 @@ def build_query(self) -> str: query += " CALL {" if subquery["initial_context"]: query += " WITH " - context: List[str] = [] + context: list[str] = [] for var in subquery["initial_context"]: if isinstance(var, (NodeNameResolver, RelationNameResolver)): context.append(var.resolve(self)) @@ -949,7 +989,7 @@ def build_query(self) -> str: return query - async def _count(self): + async def _count(self) -> 
int: self._ast.is_count = True # If we return a count with pagination, pagination has to happen before RETURN # Like : WITH my_var SKIP 10 LIMIT 10 RETURN count(my_var) @@ -969,7 +1009,7 @@ async def _count(self): results, _ = await adb.cypher_query(query, self._query_params) return int(results[0][0]) - async def _contains(self, node_element_id): + async def _contains(self, node_element_id: TOptional[Union[str, int]]) -> bool: # inject id = into ast if not self._ast.return_clause and self._ast.additional_return: self._ast.return_clause = self._ast.additional_return[0] @@ -983,7 +1023,7 @@ async def _contains(self, node_element_id): self._query_params[place_holder] = node_element_id return await self._count() >= 1 - async def _execute(self, lazy: bool = False, dict_output: bool = False): + async def _execute(self, lazy: bool = False, dict_output: bool = False) -> Any: if lazy: # inject id() into return or return_set if self._ast.return_clause: @@ -1028,7 +1068,7 @@ class AsyncBaseSet: query_cls = AsyncQueryBuilder source_class: type[AsyncStructuredNode] - async def all(self, lazy=False): + async def all(self, lazy: bool = False) -> list: """ Return all nodes belonging to the set :param lazy: False by default, specify True to get nodes with id only without the parameters. 
@@ -1041,12 +1081,12 @@ async def all(self, lazy=False): ] # Collect all nodes asynchronously return results - async def __aiter__(self): + async def __aiter__(self) -> AsyncIterator: ast = await self.query_cls(self).build_ast() async for item in ast._execute(): yield item - async def get_len(self): + async def get_len(self) -> int: ast = await self.query_cls(self).build_ast() return await ast._count() @@ -1068,7 +1108,7 @@ async def check_nonzero(self) -> bool: """ return await self.check_bool() - async def check_contains(self, obj): + async def check_contains(self, obj: Union[AsyncStructuredNode, Any]) -> bool: if isinstance(obj, AsyncStructuredNode): if hasattr(obj, "element_id") and obj.element_id is not None: ast = await self.query_cls(self).build_ast() @@ -1078,7 +1118,7 @@ async def check_contains(self, obj): raise ValueError("Expecting StructuredNode instance") - async def get_item(self, key): + async def get_item(self, key: Union[int, slice]) -> TOptional["AsyncBaseSet"]: if isinstance(key, slice): if key.stop and key.start: self.limit = key.stop - key.start @@ -1223,7 +1263,7 @@ class RawCypher: statement: str - def __post_init__(self): + def __post_init__(self) -> None: if CYPHER_ACTIONS_WITH_SIDE_EFFECT_EXPR.search(self.statement): raise ValueError( "RawCypher: Do not include any action that has side effect" @@ -1238,7 +1278,7 @@ class AsyncNodeSet(AsyncBaseSet): A class representing as set of nodes matching common query parameters """ - def __init__(self, source) -> None: + def __init__(self, source: Any) -> None: self.source = source # could be a Traverse object or a node class if isinstance(source, AsyncTraversal): self.source_class = source.target_class @@ -1265,10 +1305,12 @@ def __init__(self, source) -> None: self._subqueries: list[Subquery] = [] self._intermediate_transforms: list = [] - def __await__(self): + def __await__(self) -> Any: return self.all().__await__() - async def _get(self, limit=None, lazy=False, **kwargs): + async def _get( + 
self, limit: TOptional[int] = None, lazy: bool = False, **kwargs: dict[str, Any] + ) -> list: self.filter(**kwargs) if limit: self.limit = limit @@ -1276,7 +1318,7 @@ async def _get(self, limit=None, lazy=False, **kwargs): results = [node async for node in ast._execute(lazy)] return results - async def get(self, lazy=False, **kwargs): + async def get(self, lazy: bool = False, **kwargs: Any) -> Any: """ Retrieve one node from the set matching supplied parameters :param lazy: False by default, specify True to get nodes with id only without the parameters. @@ -1290,7 +1332,7 @@ async def get(self, lazy=False, **kwargs): raise self.source_class.DoesNotExist(repr(kwargs)) return result[0] - async def get_or_none(self, **kwargs): + async def get_or_none(self, **kwargs: Any) -> Any: """ Retrieve a node from the set matching supplied parameters or return none @@ -1302,7 +1344,7 @@ async def get_or_none(self, **kwargs): except self.source_class.DoesNotExist: return None - async def first(self, **kwargs): + async def first(self, **kwargs: Any) -> Any: """ Retrieve the first node from the set matching supplied parameters @@ -1315,7 +1357,7 @@ async def first(self, **kwargs): else: raise self.source_class.DoesNotExist(repr(kwargs)) - async def first_or_none(self, **kwargs): + async def first_or_none(self, **kwargs: Any) -> Any: """ Retrieve the first node from the set matching supplied parameters or return none @@ -1328,7 +1370,7 @@ async def first_or_none(self, **kwargs): pass return None - def filter(self, *args, **kwargs) -> "AsyncBaseSet": + def filter(self, *args: Any, **kwargs: Any) -> "AsyncBaseSet": """ Apply filters to the existing nodes in the set. @@ -1367,7 +1409,7 @@ def filter(self, *args, **kwargs) -> "AsyncBaseSet": self.q_filters = Q(self.q_filters & Q(*args, **kwargs)) return self - def exclude(self, *args, **kwargs): + def exclude(self, *args: Any, **kwargs: Any) -> "AsyncBaseSet": """ Exclude nodes from the NodeSet via filters. 
@@ -1378,13 +1420,13 @@ def exclude(self, *args, **kwargs): self.q_filters = Q(self.q_filters & ~Q(*args, **kwargs)) return self - def has(self, **kwargs): + def has(self, **kwargs: Any) -> "AsyncBaseSet": must_match, dont_match = process_has_args(self.source_class, kwargs) self.must_match.update(must_match) self.dont_match.update(dont_match) return self - def order_by(self, *props): + def order_by(self, *props: Any) -> "AsyncBaseSet": """ Order by properties. Prepend with minus to do descending. Pass None to remove ordering. @@ -1422,7 +1464,7 @@ def _register_relation_to_fetch( relation_def: Any, alias: TOptional[str] = None, include_in_return: bool = True, - ): + ) -> dict: if isinstance(relation_def, Optional): item = {"path": relation_def.relation, "optional": True} else: @@ -1433,7 +1475,7 @@ def _register_relation_to_fetch( item["alias"] = alias return item - def fetch_relations(self, *relation_names): + def fetch_relations(self, *relation_names: tuple[str, ...]) -> "AsyncNodeSet": """Specify a set of relations to traverse and return.""" relations = [] for relation_name in relation_names: @@ -1441,7 +1483,9 @@ def fetch_relations(self, *relation_names): self.relations_to_fetch = relations return self - def traverse_relations(self, *relation_names, **aliased_relation_names): + def traverse_relations( + self, *relation_names: tuple[str, ...], **aliased_relation_names: dict + ) -> "AsyncNodeSet": """Specify a set of relations to traverse only.""" relations = [] for relation_name in relation_names: @@ -1458,10 +1502,13 @@ def traverse_relations(self, *relation_names, **aliased_relation_names): self.relations_to_fetch = relations return self - def annotate(self, *vars, **aliased_vars): + def annotate(self, *vars: tuple, **aliased_vars: tuple) -> "AsyncNodeSet": """Annotate node set results with extra variables.""" - def register_extra_var(vardef, varname: Union[str, None] = None): + def register_extra_var( + vardef: Union[AggregatingFunction, ScalarFunction, 
Any], + varname: Union[str, None] = None, + ) -> None: if isinstance(vardef, (AggregatingFunction, ScalarFunction)): self._extra_results.append( {"vardef": vardef, "alias": varname if varname else ""} @@ -1476,7 +1523,7 @@ def register_extra_var(vardef, varname: Union[str, None] = None): return self - def _to_subgraph(self, root_node, other_nodes, subgraph): + def _to_subgraph(self, root_node: Any, other_nodes: Any, subgraph: dict) -> Any: """Recursive method to build root_node's relation graph from subgraph.""" root_node._relations = {} for name, relation_def in subgraph.items(): @@ -1648,7 +1695,7 @@ class AsyncTraversal(AsyncBaseSet): name: str filters: list - def __await__(self): + def __await__(self) -> Any: return self.all().__await__() def __init__(self, source: Any, name: str, definition: dict) -> None: @@ -1683,7 +1730,7 @@ def __init__(self, source: Any, name: str, definition: dict) -> None: self.name = name self.filters: list = [] - def match(self, **kwargs): + def match(self, **kwargs: Any) -> "AsyncTraversal": """ Traverse relationships with properties matching the given parameters. diff --git a/neomodel/async_/relationship_manager.py b/neomodel/async_/relationship_manager.py index a3d9b27a..9bbc5eed 100644 --- a/neomodel/async_/relationship_manager.py +++ b/neomodel/async_/relationship_manager.py @@ -2,7 +2,7 @@ import inspect import sys from importlib import import_module -from typing import Any +from typing import Any, Callable from neomodel.async_.core import adb from neomodel.async_.match import ( @@ -23,11 +23,11 @@ # check source node is saved and not deleted -def check_source(fn): +def check_source(fn: Callable) -> Callable: fn_name = fn.func_name if hasattr(fn, "func_name") else fn.__name__ @functools.wraps(fn) - def checker(self, *args, **kwargs): + def checker(self: Any, *args: Any, **kwargs: Any) -> Callable: self.source._pre_action_check(self.name + "." 
+ fn_name) return fn(self, *args, **kwargs) @@ -35,7 +35,7 @@ def checker(self, *args, **kwargs): # checks if obj is a direct subclass, 1 level -def is_direct_subclass(obj, classinfo): +def is_direct_subclass(obj: Any, classinfo: Any) -> bool: for base in obj.__bases__: if base == classinfo: return True @@ -61,7 +61,7 @@ def __init__(self, source: Any, key: str, definition: dict): self.name = key self.definition = definition - def __str__(self): + def __str__(self) -> str: direction = "either" if self.definition["direction"] == OUTGOING: direction = "a outgoing" diff --git a/neomodel/properties.py b/neomodel/properties.py index 27da64da..59d46c62 100644 --- a/neomodel/properties.py +++ b/neomodel/properties.py @@ -180,11 +180,11 @@ def is_indexed(self) -> bool: return self.unique_index or self.index @abstractmethod - def inflate(self, value: Any, rethrow: bool) -> Any: + def inflate(self, value: Any, rethrow: bool = False) -> Any: pass @abstractmethod - def deflate(self, value: Any, rethrow: bool) -> Any: + def deflate(self, value: Any, rethrow: bool = False) -> Any: pass diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index c6a7779d..b66a9ab6 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -581,7 +581,11 @@ def get_id_method(self) -> str: else: return "elementId" - def parse_element_id(self, element_id: str) -> Union[str, int]: + def parse_element_id(self, element_id: Optional[str]) -> Union[str, int]: + if element_id is None: + raise ValueError( + "Unable to parse element id, are you sure this element has been saved ?" 
+ ) db_version = self.database_version if db_version is None: raise RuntimeError( diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 463dfa5a..fe7dc33b 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -2,7 +2,7 @@ import re import string from dataclasses import dataclass -from typing import Any +from typing import Any, Iterator from typing import Optional as TOptional from typing import Tuple, Union @@ -19,14 +19,14 @@ def _rel_helper( - lhs, - rhs, - ident=None, - relation_type=None, - direction=None, - relation_properties=None, - **kwargs, # NOSONAR -): + lhs: str, + rhs: str, + ident: TOptional[str] = None, + relation_type: TOptional[str] = None, + direction: TOptional[int] = None, + relation_properties: TOptional[dict] = None, + **kwargs: dict[str, Any], # NOSONAR +) -> str: """ Generate a relationship matching string, with specified parameters. Examples: @@ -83,14 +83,14 @@ def _rel_helper( def _rel_merge_helper( - lhs, - rhs, - ident="neomodelident", - relation_type=None, - direction=None, - relation_properties=None, - **kwargs, # NOSONAR -): + lhs: str, + rhs: str, + ident: str = "neomodelident", + relation_type: TOptional[str] = None, + direction: TOptional[int] = None, + relation_properties: TOptional[dict] = None, + **kwargs: dict[str, Any], # NOSONAR +) -> str: """ Generate a relationship merging string, with specified parameters. 
Examples: @@ -204,7 +204,7 @@ def _rel_merge_helper( path_split_regex = re.compile(r"__(?!_)|\|") -def install_traversals(cls, node_set): +def install_traversals(cls: type[StructuredNode], node_set: "NodeSet") -> None: """ For a StructuredNode class install Traversal objects for each relationship definition on a NodeSet instance @@ -255,7 +255,12 @@ def _handle_special_operators( def _deflate_value( - cls, property_obj: Property, key: str, value: str, operator: str, prop: str + cls: type[StructuredNode], + property_obj: Property, + key: str, + value: str, + operator: str, + prop: str, ) -> Tuple[str, str, str]: if isinstance(property_obj, AliasProperty): prop = property_obj.aliased_to() @@ -269,7 +274,9 @@ def _deflate_value( return deflated_value, operator, prop -def _initialize_filter_args_variables(cls, key: str): +def _initialize_filter_args_variables( + cls: type[StructuredNode], key: str +) -> Tuple[type[StructuredNode], None, None, str, bool, str]: current_class = cls current_rel_model = None leaf_prop = None @@ -277,10 +284,19 @@ def _initialize_filter_args_variables(cls, key: str): is_rel_property = "|" in key prop = key - return current_class, current_rel_model, leaf_prop, operator, is_rel_property, prop + return ( + current_class, + current_rel_model, + leaf_prop, + operator, + is_rel_property, + prop, + ) -def _process_filter_key(cls, key: str) -> Tuple[Property, str, str]: +def _process_filter_key( + cls: type[StructuredNode], key: str +) -> Tuple[Property, str, str]: ( current_class, current_rel_model, @@ -313,6 +329,8 @@ def _process_filter_key(cls, key: str) -> Tuple[Property, str, str]: ) leaf_prop = part + if leaf_prop is None: + raise ValueError(f"Badly formed filter, no property found in {key}") if is_rel_property and current_rel_model: property_obj = getattr(current_rel_model, leaf_prop) else: @@ -321,7 +339,7 @@ def _process_filter_key(cls, key: str) -> Tuple[Property, str, str]: return property_obj, operator, prop -def 
process_filter_args(cls, kwargs) -> dict: +def process_filter_args(cls: type[StructuredNode], kwargs: dict[str, Any]) -> dict: """ loop through properties in filter parameters check they match class definition deflate them and convert into something easy to generate cypher from @@ -340,7 +358,9 @@ def process_filter_args(cls, kwargs) -> dict: return output -def process_has_args(cls, kwargs): +def process_has_args( + cls: type[StructuredNode], kwargs: dict[str, Any] +) -> tuple[dict, dict]: """ loop through has parameters check they correspond to class rels defined """ @@ -415,7 +435,9 @@ def __init__( class QueryBuilder: - def __init__(self, node_set, subquery_namespace: TOptional[str] = None) -> None: + def __init__( + self, node_set: "BaseSet", subquery_namespace: TOptional[str] = None + ) -> None: self.node_set = node_set self._ast = QueryAST() self._query_params: dict = {} @@ -424,7 +446,9 @@ def __init__(self, node_set, subquery_namespace: TOptional[str] = None) -> None: self._subquery_namespace: TOptional[str] = subquery_namespace def build_ast(self) -> "QueryBuilder": - if hasattr(self.node_set, "relations_to_fetch"): + if isinstance(self.node_set, NodeSet) and hasattr( + self.node_set, "relations_to_fetch" + ): for relation in self.node_set.relations_to_fetch: self.build_traversal_from_path(relation, self.node_set.source) @@ -437,7 +461,9 @@ def build_ast(self) -> "QueryBuilder": return self - def build_source(self, source) -> str: + def build_source( + self, source: Union["Traversal", "NodeSet", StructuredNode, Any] + ) -> str: if isinstance(source, Traversal): return self.build_traversal(source) if isinstance(source, NodeSet): @@ -455,10 +481,10 @@ def build_source(self, source) -> str: if source.filters or source.q_filters: self.build_where_stmt( - ident, - source.filters, - source.q_filters, + ident=ident, + filters=source.filters, source_class=source.source_class, + q_filters=source.q_filters, ) return ident @@ -499,7 +525,7 @@ def build_order_by(self, 
ident: str, source: "NodeSet") -> None: order_by.append(f"{result[0]}.{prop}") self._ast.order_by = order_by - def build_traversal(self, traversal) -> str: + def build_traversal(self, traversal: "Traversal") -> str: """ traverse a relationship from a node to a set of nodes """ @@ -523,11 +549,11 @@ def build_traversal(self, traversal) -> str: self._ast.match.append(stmt) if traversal.filters: - self.build_where_stmt(rel_ident, traversal.filters) + self.build_where_stmt(rel_ident, traversal.filters, traversal.source_class) return traversal_ident - def _additional_return(self, name: str): + def _additional_return(self, name: str) -> None: if ( not self._ast.additional_return or name not in self._ast.additional_return ) and name != self._ast.return_clause: @@ -536,7 +562,7 @@ def _additional_return(self, name: str): self._ast.additional_return.append(name) def build_traversal_from_path( - self, relation: dict, source_class + self, relation: dict, source_class: Any ) -> Tuple[str, Any]: path: str = relation["path"] stmt: str = "" @@ -622,7 +648,7 @@ def build_traversal_from_path( return existing_rhs_name, relationship.definition["node_class"] - def build_node(self, node): + def build_node(self, node: StructuredNode) -> str: ident = node.__class__.__name__.lower() place_holder = self._register_place_holder(ident) @@ -636,7 +662,7 @@ def build_node(self, node): self._ast.result_class = node.__class__ return ident - def build_label(self, ident, cls) -> str: + def build_label(self, ident: str, cls: type[StructuredNode]) -> str: """ match nodes by a label """ @@ -650,7 +676,7 @@ def build_label(self, ident, cls) -> str: self._ast.result_class = cls return ident - def build_additional_match(self, ident, node_set): + def build_additional_match(self, ident: str, node_set: "NodeSet") -> None: """ handle additional matches supplied by 'has()' calls """ @@ -682,7 +708,9 @@ def _register_place_holder(self, key: str) -> str: place_holder = 
f"{self._subquery_namespace}_{place_holder}" return place_holder - def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str, Any]: + def _parse_path( + self, source_class: type[StructuredNode], prop: str + ) -> Tuple[str, str, str, Any]: is_rel_filter = "|" in prop if is_rel_filter: path, prop = prop.rsplit("|", 1) @@ -723,7 +751,11 @@ def _finalize_filter_statement( return statement def _build_filter_statements( - self, ident: str, filters, target: list[str], source_class + self, + ident: str, + filters: dict[str, tuple], + target: list[str], + source_class: type[StructuredNode], ) -> None: for prop, op_and_val in filters.items(): path = None @@ -739,7 +771,9 @@ def _build_filter_statements( statement = self._finalize_filter_statement(operator, ident, prop, val) target.append(statement) - def _parse_q_filters(self, ident, q, source_class) -> str: + def _parse_q_filters( + self, ident: str, q: Union[QBase, Any], source_class: type[StructuredNode] + ) -> str: target = [] for child in q.children: if isinstance(child, QBase): @@ -757,7 +791,11 @@ def _parse_q_filters(self, ident, q, source_class) -> str: return ret def build_where_stmt( - self, ident: str, filters, q_filters=None, source_class=None + self, + ident: str, + filters: list, + source_class: type[StructuredNode], + q_filters: Union[QBase, Any, None] = None, ) -> None: """ construct a where statement from some filters @@ -891,7 +929,7 @@ def build_query(self) -> str: query += " CALL {" if subquery["initial_context"]: query += " WITH " - context: List[str] = [] + context: list[str] = [] for var in subquery["initial_context"]: if isinstance(var, (NodeNameResolver, RelationNameResolver)): context.append(var.resolve(self)) @@ -949,7 +987,7 @@ def build_query(self) -> str: return query - def _count(self): + def _count(self) -> int: self._ast.is_count = True # If we return a count with pagination, pagination has to happen before RETURN # Like : WITH my_var SKIP 10 LIMIT 10 RETURN count(my_var) @@ 
-969,7 +1007,7 @@ def _count(self): results, _ = db.cypher_query(query, self._query_params) return int(results[0][0]) - def _contains(self, node_element_id): + def _contains(self, node_element_id: TOptional[Union[str, int]]) -> bool: # inject id = into ast if not self._ast.return_clause and self._ast.additional_return: self._ast.return_clause = self._ast.additional_return[0] @@ -981,7 +1019,7 @@ def _contains(self, node_element_id): self._query_params[place_holder] = node_element_id return self._count() >= 1 - def _execute(self, lazy: bool = False, dict_output: bool = False): + def _execute(self, lazy: bool = False, dict_output: bool = False) -> Any: if lazy: # inject id() into return or return_set if self._ast.return_clause: @@ -1026,7 +1064,7 @@ class BaseSet: query_cls = QueryBuilder source_class: type[StructuredNode] - def all(self, lazy=False): + def all(self, lazy: bool = False) -> list: """ Return all nodes belonging to the set :param lazy: False by default, specify True to get nodes with id only without the parameters. 
@@ -1039,12 +1077,12 @@ def all(self, lazy=False): ] # Collect all nodes asynchronously return results - def __iter__(self): + def __iter__(self) -> Iterator: ast = self.query_cls(self).build_ast() for item in ast._execute(): yield item - def __len__(self): + def __len__(self) -> int: ast = self.query_cls(self).build_ast() return ast._count() @@ -1066,7 +1104,7 @@ def __nonzero__(self) -> bool: """ return self.__bool__() - def __contains__(self, obj): + def __contains__(self, obj: Union[StructuredNode, Any]) -> bool: if isinstance(obj, StructuredNode): if hasattr(obj, "element_id") and obj.element_id is not None: ast = self.query_cls(self).build_ast() @@ -1076,7 +1114,7 @@ def __contains__(self, obj): raise ValueError("Expecting StructuredNode instance") - def __getitem__(self, key): + def __getitem__(self, key: Union[int, slice]) -> TOptional["BaseSet"]: if isinstance(key, slice): if key.stop and key.start: self.limit = key.stop - key.start @@ -1221,7 +1259,7 @@ class RawCypher: statement: str - def __post_init__(self): + def __post_init__(self) -> None: if CYPHER_ACTIONS_WITH_SIDE_EFFECT_EXPR.search(self.statement): raise ValueError( "RawCypher: Do not include any action that has side effect" @@ -1236,7 +1274,7 @@ class NodeSet(BaseSet): A class representing as set of nodes matching common query parameters """ - def __init__(self, source) -> None: + def __init__(self, source: Any) -> None: self.source = source # could be a Traverse object or a node class if isinstance(source, Traversal): self.source_class = source.target_class @@ -1263,10 +1301,12 @@ def __init__(self, source) -> None: self._subqueries: list[Subquery] = [] self._intermediate_transforms: list = [] - def __await__(self): + def __await__(self) -> Any: return self.all().__await__() - def _get(self, limit=None, lazy=False, **kwargs): + def _get( + self, limit: TOptional[int] = None, lazy: bool = False, **kwargs: dict[str, Any] + ) -> list: self.filter(**kwargs) if limit: self.limit = limit @@ -1274,7 
+1314,7 @@ def _get(self, limit=None, lazy=False, **kwargs): results = [node for node in ast._execute(lazy)] return results - def get(self, lazy=False, **kwargs): + def get(self, lazy: bool = False, **kwargs: Any) -> Any: """ Retrieve one node from the set matching supplied parameters :param lazy: False by default, specify True to get nodes with id only without the parameters. @@ -1288,7 +1328,7 @@ def get(self, lazy=False, **kwargs): raise self.source_class.DoesNotExist(repr(kwargs)) return result[0] - def get_or_none(self, **kwargs): + def get_or_none(self, **kwargs: Any) -> Any: """ Retrieve a node from the set matching supplied parameters or return none @@ -1300,7 +1340,7 @@ def get_or_none(self, **kwargs): except self.source_class.DoesNotExist: return None - def first(self, **kwargs): + def first(self, **kwargs: Any) -> Any: """ Retrieve the first node from the set matching supplied parameters @@ -1313,7 +1353,7 @@ def first(self, **kwargs): else: raise self.source_class.DoesNotExist(repr(kwargs)) - def first_or_none(self, **kwargs): + def first_or_none(self, **kwargs: Any) -> Any: """ Retrieve the first node from the set matching supplied parameters or return none @@ -1326,7 +1366,7 @@ def first_or_none(self, **kwargs): pass return None - def filter(self, *args, **kwargs) -> "BaseSet": + def filter(self, *args: Any, **kwargs: Any) -> "BaseSet": """ Apply filters to the existing nodes in the set. @@ -1365,7 +1405,7 @@ def filter(self, *args, **kwargs) -> "BaseSet": self.q_filters = Q(self.q_filters & Q(*args, **kwargs)) return self - def exclude(self, *args, **kwargs): + def exclude(self, *args: Any, **kwargs: Any) -> "BaseSet": """ Exclude nodes from the NodeSet via filters. 
@@ -1376,13 +1416,13 @@ def exclude(self, *args, **kwargs): self.q_filters = Q(self.q_filters & ~Q(*args, **kwargs)) return self - def has(self, **kwargs): + def has(self, **kwargs: Any) -> "BaseSet": must_match, dont_match = process_has_args(self.source_class, kwargs) self.must_match.update(must_match) self.dont_match.update(dont_match) return self - def order_by(self, *props): + def order_by(self, *props: Any) -> "BaseSet": """ Order by properties. Prepend with minus to do descending. Pass None to remove ordering. @@ -1420,7 +1460,7 @@ def _register_relation_to_fetch( relation_def: Any, alias: TOptional[str] = None, include_in_return: bool = True, - ): + ) -> dict: if isinstance(relation_def, Optional): item = {"path": relation_def.relation, "optional": True} else: @@ -1431,7 +1471,7 @@ def _register_relation_to_fetch( item["alias"] = alias return item - def fetch_relations(self, *relation_names): + def fetch_relations(self, *relation_names: tuple[str, ...]) -> "NodeSet": """Specify a set of relations to traverse and return.""" relations = [] for relation_name in relation_names: @@ -1439,7 +1479,9 @@ def fetch_relations(self, *relation_names): self.relations_to_fetch = relations return self - def traverse_relations(self, *relation_names, **aliased_relation_names): + def traverse_relations( + self, *relation_names: tuple[str, ...], **aliased_relation_names: dict + ) -> "NodeSet": """Specify a set of relations to traverse only.""" relations = [] for relation_name in relation_names: @@ -1456,10 +1498,13 @@ def traverse_relations(self, *relation_names, **aliased_relation_names): self.relations_to_fetch = relations return self - def annotate(self, *vars, **aliased_vars): + def annotate(self, *vars: tuple, **aliased_vars: tuple) -> "NodeSet": """Annotate node set results with extra variables.""" - def register_extra_var(vardef, varname: Union[str, None] = None): + def register_extra_var( + vardef: Union[AggregatingFunction, ScalarFunction, Any], + varname: Union[str, 
None] = None, + ) -> None: if isinstance(vardef, (AggregatingFunction, ScalarFunction)): self._extra_results.append( {"vardef": vardef, "alias": varname if varname else ""} @@ -1474,7 +1519,7 @@ def register_extra_var(vardef, varname: Union[str, None] = None): return self - def _to_subgraph(self, root_node, other_nodes, subgraph): + def _to_subgraph(self, root_node: Any, other_nodes: Any, subgraph: dict) -> Any: """Recursive method to build root_node's relation graph from subgraph.""" root_node._relations = {} for name, relation_def in subgraph.items(): @@ -1644,7 +1689,7 @@ class Traversal(BaseSet): name: str filters: list - def __await__(self): + def __await__(self) -> Any: return self.all().__await__() def __init__(self, source: Any, name: str, definition: dict) -> None: @@ -1679,7 +1724,7 @@ def __init__(self, source: Any, name: str, definition: dict) -> None: self.name = name self.filters: list = [] - def match(self, **kwargs): + def match(self, **kwargs: Any) -> "Traversal": """ Traverse relationships with properties matching the given parameters. diff --git a/neomodel/sync_/relationship_manager.py b/neomodel/sync_/relationship_manager.py index a4a3cc69..e88346f1 100644 --- a/neomodel/sync_/relationship_manager.py +++ b/neomodel/sync_/relationship_manager.py @@ -2,7 +2,7 @@ import inspect import sys from importlib import import_module -from typing import Any +from typing import Any, Callable from neomodel.exceptions import NotConnected, RelationshipClassRedefined from neomodel.sync_.core import db @@ -18,11 +18,11 @@ # check source node is saved and not deleted -def check_source(fn): +def check_source(fn: Callable) -> Callable: fn_name = fn.func_name if hasattr(fn, "func_name") else fn.__name__ @functools.wraps(fn) - def checker(self, *args, **kwargs): + def checker(self: Any, *args: Any, **kwargs: Any) -> Callable: self.source._pre_action_check(self.name + "." 
+ fn_name) return fn(self, *args, **kwargs) @@ -30,7 +30,7 @@ def checker(self, *args, **kwargs): # checks if obj is a direct subclass, 1 level -def is_direct_subclass(obj, classinfo): +def is_direct_subclass(obj: Any, classinfo: Any) -> bool: for base in obj.__bases__: if base == classinfo: return True @@ -56,7 +56,7 @@ def __init__(self, source: Any, key: str, definition: dict): self.name = key self.definition = definition - def __str__(self): + def __str__(self) -> str: direction = "either" if self.definition["direction"] == OUTGOING: direction = "a outgoing" From 804ee44efb52c39565677d80aaa18767e909af82 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Fri, 13 Dec 2024 15:38:04 +0100 Subject: [PATCH 16/20] Type hints for relationships --- neomodel/async_/core.py | 6 +- neomodel/async_/relationship.py | 51 ++++++---- neomodel/async_/relationship_manager.py | 128 ++++++++++++++---------- neomodel/sync_/core.py | 6 +- neomodel/sync_/relationship.py | 51 ++++++---- neomodel/sync_/relationship_manager.py | 120 ++++++++++++---------- 6 files changed, 214 insertions(+), 148 deletions(-) diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py index a8ad8c62..c04bacc9 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/core.py @@ -7,7 +7,7 @@ from functools import wraps from itertools import combinations from threading import local -from typing import Any, Callable, Optional, TextIO, Type, Union +from typing import Any, Callable, Optional, TextIO, Union from urllib.parse import quote, unquote, urlparse from neo4j import ( @@ -1358,7 +1358,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Callable: class NodeMeta(type): - DoesNotExist: Type[DoesNotExist] + DoesNotExist: type[DoesNotExist] __required_properties__: tuple[str, ...] __all_properties__: tuple[tuple[str, Any], ...] __all_aliases__: tuple[tuple[str, Any], ...] 
@@ -1461,7 +1461,7 @@ def build_class_registry(cls: Any) -> None: ) -NodeBase: Type = NodeMeta( +NodeBase: type = NodeMeta( "NodeBase", (AsyncPropertyManager,), {"__abstract_node__": True} ) diff --git a/neomodel/async_/relationship.py b/neomodel/async_/relationship.py index 365dd132..ee7ba59b 100644 --- a/neomodel/async_/relationship.py +++ b/neomodel/async_/relationship.py @@ -1,4 +1,6 @@ -from typing import Type +from typing import Any, Optional + +from neo4j.graph import Relationship from neomodel.async_.core import adb from neomodel.async_.property_manager import AsyncPropertyManager @@ -9,8 +11,10 @@ class RelationshipMeta(type): - def __new__(mcs, name, bases, dct): - inst = super().__new__(mcs, name, bases, dct) + def __new__( + mcs: type, name: str, bases: tuple[type, ...], dct: dict[str, Any] + ) -> Any: + inst: RelationshipMeta = type.__new__(mcs, name, bases, dct) for key, value in dct.items(): if issubclass(value.__class__, Property): if key == "source" or key == "target": @@ -40,7 +44,7 @@ def __new__(mcs, name, bases, dct): return inst -StructuredRelBase: Type = RelationshipMeta( +StructuredRelBase: type = RelationshipMeta( "RelationshipBase", (AsyncPropertyManager,), {} ) @@ -50,27 +54,30 @@ class AsyncStructuredRel(StructuredRelBase): Base class for relationship objects """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: dict) -> None: super().__init__(*args, **kwargs) @property - def element_id(self): + def element_id(self) -> Optional[Any]: if hasattr(self, "element_id_property"): return self.element_id_property + return None @property - def _start_node_element_id(self): + def _start_node_element_id(self) -> Optional[Any]: if hasattr(self, "_start_node_element_id_property"): return self._start_node_element_id_property + return None @property - def _end_node_element_id(self): + def _end_node_element_id(self) -> Optional[Any]: if hasattr(self, "_end_node_element_id_property"): return 
self._end_node_element_id_property
+        return None

     # Version 4.4 support - id is deprecated in version 5.x
     @property
-    def id(self):
+    def id(self) -> int:
         try:
             return int(self.element_id_property)
         except (TypeError, ValueError) as exc:
@@ -78,7 +85,7 @@ def id(self):

     # Version 4.4 support - id is deprecated in version 5.x
     @property
-    def _start_node_id(self):
+    def _start_node_id(self) -> int:
         try:
             return int(self._start_node_element_id_property)
         except (TypeError, ValueError) as exc:
@@ -86,14 +93,14 @@ def _start_node_id(self):

     # Version 4.4 support - id is deprecated in version 5.x
     @property
-    def _end_node_id(self):
+    def _end_node_id(self) -> int:
         try:
             return int(self._end_node_element_id_property)
         except (TypeError, ValueError) as exc:
             raise ValueError(ELEMENT_ID_MIGRATION_NOTICE) from exc

     @hooks
-    async def save(self):
+    async def save(self) -> "AsyncStructuredRel":
         """
         Save the relationship

@@ -108,7 +115,7 @@ async def save(self):

         return self

-    async def start_node(self):
+    async def start_node(self) -> Any:
         """
         Get start node

@@ -127,9 +134,13 @@ async def start_node(self):
             },
             resolve_objects=True,
         )
+        if results is None or results[0] is None or results[0][0] is None:
+            raise ValueError(
+                f"Start node with elementId {self._start_node_element_id} not found"
+            )
         return results[0][0][0]

-    async def end_node(self):
+    async def end_node(self) -> Any:
         """
         Get end node

@@ -148,17 +159,23 @@ async def end_node(self):
             },
             resolve_objects=True,
         )
+        if results is None or results[0] is None or results[0][0] is None:
+            raise ValueError(
+                f"End node with elementId {self._end_node_element_id} not found"
+            )
         return results[0][0][0]

     @classmethod
-    def inflate(cls, rel):
+    def inflate(cls: Any, rel: Relationship) -> "AsyncStructuredRel":
         """
         Inflate a neo4j_driver relationship object to a neomodel object

         :param rel:
         :return: StructuredRel
         """
         srel = super().inflate(rel)
-        srel._start_node_element_id_property = rel.start_node.element_id
-        
srel._end_node_element_id_property = rel.end_node.element_id + if rel.start_node is not None: + srel._start_node_element_id_property = rel.start_node.element_id + if rel.end_node is not None: + srel._end_node_element_id_property = rel.end_node.element_id srel.element_id_property = rel.element_id return srel diff --git a/neomodel/async_/relationship_manager.py b/neomodel/async_/relationship_manager.py index 9bbc5eed..1fe7078e 100644 --- a/neomodel/async_/relationship_manager.py +++ b/neomodel/async_/relationship_manager.py @@ -2,7 +2,7 @@ import inspect import sys from importlib import import_module -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Optional, Union from neomodel.async_.core import adb from neomodel.async_.match import ( @@ -21,6 +21,10 @@ get_graph_entity_properties, ) +if TYPE_CHECKING: + from neomodel import AsyncStructuredNode + from neomodel.async_.match import AsyncBaseSet + # check source node is saved and not deleted def check_source(fn: Callable) -> Callable: @@ -70,10 +74,10 @@ def __str__(self) -> str: return f"{self.description} in {direction} direction of type {self.definition['relation_type']} on node ({self.source.element_id}) of class '{self.source_class.__name__}'" - def __await__(self): + def __await__(self) -> Any: return self.all().__await__() - def _check_node(self, obj): + def _check_node(self, obj: type["AsyncStructuredNode"]) -> None: """check for valid node i.e correct class and is saved""" if not issubclass(type(obj), self.definition["node_class"]): raise ValueError( @@ -83,7 +87,9 @@ def _check_node(self, obj): raise ValueError("Can't perform operation on unsaved node " + repr(obj)) @check_source - async def connect(self, node, properties=None): + async def connect( + self, node: "AsyncStructuredNode", properties: Optional[dict[str, Any]] = None + ) -> Optional[AsyncStructuredRel]: """ Connect a node @@ -135,7 +141,7 @@ async def connect(self, node, properties=None): if not 
rel_model: await self.source.cypher(q, params) - return True + return None results = await self.source.cypher(q + " RETURN r", params) rel_ = results[0][0][0] @@ -147,7 +153,9 @@ async def connect(self, node, properties=None): return rel_instance @check_source - async def replace(self, node, properties=None): + async def replace( + self, node: "AsyncStructuredNode", properties: Optional[dict[str, Any]] = None + ) -> None: """ Disconnect all existing nodes and connect the supplied node @@ -160,7 +168,9 @@ async def replace(self, node, properties=None): await self.connect(node, properties) @check_source - async def relationship(self, node): + async def relationship( + self, node: "AsyncStructuredNode" + ) -> Optional[AsyncStructuredRel]: """ Retrieve the relationship object for this first relationship between self and node. @@ -179,14 +189,16 @@ async def relationship(self, node): ) rels = results[0] if not rels: - return + return None rel_model = self.definition.get("model") or AsyncStructuredRel return self._set_start_end_cls(rel_model.inflate(rels[0][0]), node) @check_source - async def all_relationships(self, node): + async def all_relationships( + self, node: "AsyncStructuredNode" + ) -> list[AsyncStructuredRel]: """ Retrieve all relationship objects between self and node. 
@@ -209,7 +221,9 @@ async def all_relationships(self, node): self._set_start_end_cls(rel_model.inflate(rel[0]), node) for rel in rels ] - def _set_start_end_cls(self, rel_instance, obj): + def _set_start_end_cls( + self, rel_instance: AsyncStructuredRel, obj: "AsyncStructuredNode" + ) -> AsyncStructuredRel: if self.definition["direction"] == INCOMING: rel_instance._start_node_class = obj.__class__ rel_instance._end_node_class = self.source_class @@ -219,7 +233,9 @@ def _set_start_end_cls(self, rel_instance, obj): return rel_instance @check_source - async def reconnect(self, old_node, new_node): + async def reconnect( + self, old_node: "AsyncStructuredNode", new_node: "AsyncStructuredNode" + ) -> None: """ Disconnect old_node and connect new_node copying over any properties on the original relationship. @@ -270,7 +286,7 @@ async def reconnect(self, old_node, new_node): ) @check_source - async def disconnect(self, node): + async def disconnect(self, node: "AsyncStructuredNode") -> None: """ Disconnect a node @@ -287,7 +303,7 @@ async def disconnect(self, node): ) @check_source - async def disconnect_all(self): + async def disconnect_all(self) -> None: """ Disconnect all nodes @@ -303,11 +319,11 @@ async def disconnect_all(self): await self.source.cypher(q) @check_source - def _new_traversal(self): + def _new_traversal(self) -> AsyncTraversal: return AsyncTraversal(self.source, self.name, self.definition) # The methods below simply proxy the match engine. - def get(self, **kwargs): + def get(self, **kwargs: Any) -> AsyncNodeSet: """ Retrieve a related node with the matching node properties. @@ -316,7 +332,7 @@ def get(self, **kwargs): """ return AsyncNodeSet(self._new_traversal()).get(**kwargs) - def get_or_none(self, **kwargs): + def get_or_none(self, **kwargs: dict) -> AsyncNodeSet: """ Retrieve a related node with the matching node properties or return None. 
@@ -325,7 +341,7 @@ def get_or_none(self, **kwargs): """ return AsyncNodeSet(self._new_traversal()).get_or_none(**kwargs) - def filter(self, *args, **kwargs): + def filter(self, *args: Any, **kwargs: dict) -> "AsyncBaseSet": """ Retrieve related nodes matching the provided properties. @@ -335,7 +351,7 @@ def filter(self, *args, **kwargs): """ return AsyncNodeSet(self._new_traversal()).filter(*args, **kwargs) - def order_by(self, *props): + def order_by(self, *props: Any) -> "AsyncBaseSet": """ Order related nodes by specified properties @@ -344,7 +360,7 @@ def order_by(self, *props): """ return AsyncNodeSet(self._new_traversal()).order_by(*props) - def exclude(self, *args, **kwargs): + def exclude(self, *args: Any, **kwargs: dict) -> "AsyncBaseSet": """ Exclude nodes that match the provided properties. @@ -354,7 +370,7 @@ def exclude(self, *args, **kwargs): """ return AsyncNodeSet(self._new_traversal()).exclude(*args, **kwargs) - async def is_connected(self, node): + async def is_connected(self, node: "AsyncStructuredNode") -> bool: """ Check if a node is connected with this relationship type :param node: @@ -362,7 +378,7 @@ async def is_connected(self, node): """ return await self._new_traversal().check_contains(node) - async def single(self): + async def single(self) -> Optional["AsyncStructuredNode"]: """ Get a single related node or none. @@ -372,9 +388,9 @@ async def single(self): rels = await self return rels[0] except IndexError: - pass + return None - def match(self, **kwargs): + def match(self, **kwargs: dict) -> AsyncNodeSet: """ Return set of nodes who's relationship properties match supplied args @@ -383,7 +399,7 @@ def match(self, **kwargs): """ return self._new_traversal().match(**kwargs) - async def all(self): + async def all(self) -> list: """ Return all related nodes. 
@@ -391,34 +407,34 @@ async def all(self): """ return await self._new_traversal().all() - async def __aiter__(self): + async def __aiter__(self) -> AsyncIterator: return self._new_traversal().__aiter__() - async def get_len(self): + async def get_len(self) -> int: return await self._new_traversal().get_len() - async def check_bool(self): + async def check_bool(self) -> bool: return await self._new_traversal().check_bool() - async def check_nonzero(self): + async def check_nonzero(self) -> bool: return self._new_traversal().check_nonzero() - async def check_contains(self, obj): + async def check_contains(self, obj: Any) -> bool: return self._new_traversal().check_contains(obj) - async def get_item(self, key): + async def get_item(self, key: Union[int, slice]) -> Any: return self._new_traversal().get_item(key) class AsyncRelationshipDefinition: def __init__( self, - relation_type, - cls_name, - direction, - manager=AsyncRelationshipManager, - model=None, - ): + relation_type: str, + cls_name: str, + direction: int, + manager: type[AsyncRelationshipManager] = AsyncRelationshipManager, + model: Optional[AsyncStructuredRel] = None, + ) -> None: self._validate_class(cls_name, model) current_frame = inspect.currentframe() @@ -469,14 +485,16 @@ def __init__( # If the mapping does not exist then it is simply created. 
adb._NODE_CLASS_REGISTRY[label_set] = model - def _validate_class(self, cls_name, model): + def _validate_class( + self, cls_name: str, model: Optional[AsyncStructuredRel] = None + ) -> None: if not isinstance(cls_name, (str, object)): raise ValueError("Expected class name or class got " + repr(cls_name)) if model and not issubclass(model, (AsyncStructuredRel,)): raise ValueError("model must be a StructuredRel") - def lookup_node_class(self): + def lookup_node_class(self) -> None: if not isinstance(self._raw_class, str): self.definition["node_class"] = self._raw_class else: @@ -513,7 +531,9 @@ def lookup_node_class(self): module = import_module(namespace).__name__ self.definition["node_class"] = getattr(sys.modules[module], name) - def build_manager(self, source, name): + def build_manager( + self, source: "AsyncStructuredNode", name: str + ) -> AsyncRelationshipManager: self.lookup_node_class() return self.manager(source, name, self.definition) @@ -529,11 +549,11 @@ class AsyncZeroOrMore(AsyncRelationshipManager): class AsyncRelationshipTo(AsyncRelationshipDefinition): def __init__( self, - cls_name, - relation_type, - cardinality=AsyncZeroOrMore, - model=None, - ): + cls_name: str, + relation_type: str, + cardinality: type[AsyncRelationshipManager] = AsyncZeroOrMore, + model: Optional[AsyncStructuredRel] = None, + ) -> None: super().__init__( relation_type, cls_name, OUTGOING, manager=cardinality, model=model ) @@ -542,11 +562,11 @@ def __init__( class AsyncRelationshipFrom(AsyncRelationshipDefinition): def __init__( self, - cls_name, - relation_type, - cardinality=AsyncZeroOrMore, - model=None, - ): + cls_name: str, + relation_type: str, + cardinality: type[AsyncRelationshipManager] = AsyncZeroOrMore, + model: Optional[AsyncStructuredRel] = None, + ) -> None: super().__init__( relation_type, cls_name, INCOMING, manager=cardinality, model=model ) @@ -555,11 +575,11 @@ def __init__( class AsyncRelationship(AsyncRelationshipDefinition): def __init__( self, - 
cls_name, - relation_type, - cardinality=AsyncZeroOrMore, - model=None, - ): + cls_name: str, + relation_type: str, + cardinality: type[AsyncRelationshipManager] = AsyncZeroOrMore, + model: Optional[AsyncStructuredRel] = None, + ) -> None: super().__init__( relation_type, cls_name, EITHER, manager=cardinality, model=model ) diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index b66a9ab6..17a9c74f 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -7,7 +7,7 @@ from functools import wraps from itertools import combinations from threading import local -from typing import Any, Callable, Optional, TextIO, Type, Union +from typing import Any, Callable, Optional, TextIO, Union from urllib.parse import quote, unquote, urlparse from neo4j import ( @@ -1351,7 +1351,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Callable: class NodeMeta(type): - DoesNotExist: Type[DoesNotExist] + DoesNotExist: type[DoesNotExist] __required_properties__: tuple[str, ...] __all_properties__: tuple[tuple[str, Any], ...] __all_aliases__: tuple[tuple[str, Any], ...] 
@@ -1454,7 +1454,7 @@ def build_class_registry(cls: Any) -> None: ) -NodeBase: Type = NodeMeta("NodeBase", (PropertyManager,), {"__abstract_node__": True}) +NodeBase: type = NodeMeta("NodeBase", (PropertyManager,), {"__abstract_node__": True}) class StructuredNode(NodeBase): diff --git a/neomodel/sync_/relationship.py b/neomodel/sync_/relationship.py index 9b3bdf95..b9382f29 100644 --- a/neomodel/sync_/relationship.py +++ b/neomodel/sync_/relationship.py @@ -1,4 +1,6 @@ -from typing import Type +from typing import Any, Optional + +from neo4j.graph import Relationship from neomodel.hooks import hooks from neomodel.properties import Property @@ -9,8 +11,10 @@ class RelationshipMeta(type): - def __new__(mcs, name, bases, dct): - inst = super().__new__(mcs, name, bases, dct) + def __new__( + mcs: type, name: str, bases: tuple[type, ...], dct: dict[str, Any] + ) -> Any: + inst: RelationshipMeta = type.__new__(mcs, name, bases, dct) for key, value in dct.items(): if issubclass(value.__class__, Property): if key == "source" or key == "target": @@ -40,7 +44,7 @@ def __new__(mcs, name, bases, dct): return inst -StructuredRelBase: Type = RelationshipMeta("RelationshipBase", (PropertyManager,), {}) +StructuredRelBase: type = RelationshipMeta("RelationshipBase", (PropertyManager,), {}) class StructuredRel(StructuredRelBase): @@ -48,27 +52,30 @@ class StructuredRel(StructuredRelBase): Base class for relationship objects """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: dict) -> None: super().__init__(*args, **kwargs) @property - def element_id(self): + def element_id(self) -> Optional[Any]: if hasattr(self, "element_id_property"): return self.element_id_property + return None @property - def _start_node_element_id(self): + def _start_node_element_id(self) -> Optional[Any]: if hasattr(self, "_start_node_element_id_property"): return self._start_node_element_id_property + return None @property - def _end_node_element_id(self): + def 
_end_node_element_id(self) -> Optional[Any]:
         if hasattr(self, "_end_node_element_id_property"):
             return self._end_node_element_id_property
+        return None

     # Version 4.4 support - id is deprecated in version 5.x
     @property
-    def id(self):
+    def id(self) -> int:
         try:
             return int(self.element_id_property)
         except (TypeError, ValueError) as exc:
@@ -76,7 +83,7 @@ def id(self):

     # Version 4.4 support - id is deprecated in version 5.x
     @property
-    def _start_node_id(self):
+    def _start_node_id(self) -> int:
         try:
             return int(self._start_node_element_id_property)
         except (TypeError, ValueError) as exc:
@@ -84,14 +91,14 @@ def _start_node_id(self):

     # Version 4.4 support - id is deprecated in version 5.x
     @property
-    def _end_node_id(self):
+    def _end_node_id(self) -> int:
         try:
             return int(self._end_node_element_id_property)
         except (TypeError, ValueError) as exc:
             raise ValueError(ELEMENT_ID_MIGRATION_NOTICE) from exc

     @hooks
-    def save(self):
+    def save(self) -> "StructuredRel":
         """
         Save the relationship

@@ -106,7 +113,7 @@ def save(self):

         return self

-    def start_node(self):
+    def start_node(self) -> Any:
         """
         Get start node

@@ -121,9 +128,13 @@ def start_node(self):
             {"start_node_element_id": db.parse_element_id(self._start_node_element_id)},
             resolve_objects=True,
         )
+        if results is None or results[0] is None or results[0][0] is None:
+            raise ValueError(
+                f"Start node with elementId {self._start_node_element_id} not found"
+            )
         return results[0][0][0]

-    def end_node(self):
+    def end_node(self) -> Any:
         """
         Get end node

@@ -138,17 +149,23 @@ def end_node(self):
             {"end_node_element_id": db.parse_element_id(self._end_node_element_id)},
             resolve_objects=True,
         )
+        if results is None or results[0] is None or results[0][0] is None:
+            raise ValueError(
+                f"End node with elementId {self._end_node_element_id} not found"
+            )
         return results[0][0][0]

     @classmethod
-    def inflate(cls, rel):
+    def inflate(cls: Any, rel: Relationship) -> "StructuredRel":
         """
         Inflate a neo4j_driver relationship object to a 
neomodel object :param rel: :return: StructuredRel """ srel = super().inflate(rel) - srel._start_node_element_id_property = rel.start_node.element_id - srel._end_node_element_id_property = rel.end_node.element_id + if rel.start_node is not None: + srel._start_node_element_id_property = rel.start_node.element_id + if rel.end_node is not None: + srel._end_node_element_id_property = rel.end_node.element_id srel.element_id_property = rel.element_id return srel diff --git a/neomodel/sync_/relationship_manager.py b/neomodel/sync_/relationship_manager.py index e88346f1..25067f1c 100644 --- a/neomodel/sync_/relationship_manager.py +++ b/neomodel/sync_/relationship_manager.py @@ -2,7 +2,7 @@ import inspect import sys from importlib import import_module -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Iterator, Optional, Union from neomodel.exceptions import NotConnected, RelationshipClassRedefined from neomodel.sync_.core import db @@ -16,6 +16,10 @@ get_graph_entity_properties, ) +if TYPE_CHECKING: + from neomodel import StructuredNode + from neomodel.sync_.match import BaseSet + # check source node is saved and not deleted def check_source(fn: Callable) -> Callable: @@ -65,10 +69,10 @@ def __str__(self) -> str: return f"{self.description} in {direction} direction of type {self.definition['relation_type']} on node ({self.source.element_id}) of class '{self.source_class.__name__}'" - def __await__(self): + def __await__(self) -> Any: return self.all().__await__() - def _check_node(self, obj): + def _check_node(self, obj: type["StructuredNode"]) -> None: """check for valid node i.e correct class and is saved""" if not issubclass(type(obj), self.definition["node_class"]): raise ValueError( @@ -78,7 +82,9 @@ def _check_node(self, obj): raise ValueError("Can't perform operation on unsaved node " + repr(obj)) @check_source - def connect(self, node, properties=None): + def connect( + self, node: "StructuredNode", properties: Optional[dict[str, 
Any]] = None + ) -> Optional[StructuredRel]: """ Connect a node @@ -130,7 +136,7 @@ def connect(self, node, properties=None): if not rel_model: self.source.cypher(q, params) - return True + return None results = self.source.cypher(q + " RETURN r", params) rel_ = results[0][0][0] @@ -142,7 +148,9 @@ def connect(self, node, properties=None): return rel_instance @check_source - def replace(self, node, properties=None): + def replace( + self, node: "StructuredNode", properties: Optional[dict[str, Any]] = None + ) -> None: """ Disconnect all existing nodes and connect the supplied node @@ -155,7 +163,7 @@ def replace(self, node, properties=None): self.connect(node, properties) @check_source - def relationship(self, node): + def relationship(self, node: "StructuredNode") -> Optional[StructuredRel]: """ Retrieve the relationship object for this first relationship between self and node. @@ -172,14 +180,14 @@ def relationship(self, node): results = self.source.cypher(q, {"them": db.parse_element_id(node.element_id)}) rels = results[0] if not rels: - return + return None rel_model = self.definition.get("model") or StructuredRel return self._set_start_end_cls(rel_model.inflate(rels[0][0]), node) @check_source - def all_relationships(self, node): + def all_relationships(self, node: "StructuredNode") -> list[StructuredRel]: """ Retrieve all relationship objects between self and node. 
@@ -200,7 +208,9 @@ def all_relationships(self, node): self._set_start_end_cls(rel_model.inflate(rel[0]), node) for rel in rels ] - def _set_start_end_cls(self, rel_instance, obj): + def _set_start_end_cls( + self, rel_instance: StructuredRel, obj: "StructuredNode" + ) -> StructuredRel: if self.definition["direction"] == INCOMING: rel_instance._start_node_class = obj.__class__ rel_instance._end_node_class = self.source_class @@ -210,7 +220,7 @@ def _set_start_end_cls(self, rel_instance, obj): return rel_instance @check_source - def reconnect(self, old_node, new_node): + def reconnect(self, old_node: "StructuredNode", new_node: "StructuredNode") -> None: """ Disconnect old_node and connect new_node copying over any properties on the original relationship. @@ -259,7 +269,7 @@ def reconnect(self, old_node, new_node): self.source.cypher(q, {"old": old_node_element_id, "new": new_node_element_id}) @check_source - def disconnect(self, node): + def disconnect(self, node: "StructuredNode") -> None: """ Disconnect a node @@ -274,7 +284,7 @@ def disconnect(self, node): self.source.cypher(q, {"them": db.parse_element_id(node.element_id)}) @check_source - def disconnect_all(self): + def disconnect_all(self) -> None: """ Disconnect all nodes @@ -286,11 +296,11 @@ def disconnect_all(self): self.source.cypher(q) @check_source - def _new_traversal(self): + def _new_traversal(self) -> Traversal: return Traversal(self.source, self.name, self.definition) # The methods below simply proxy the match engine. - def get(self, **kwargs): + def get(self, **kwargs: Any) -> NodeSet: """ Retrieve a related node with the matching node properties. @@ -299,7 +309,7 @@ def get(self, **kwargs): """ return NodeSet(self._new_traversal()).get(**kwargs) - def get_or_none(self, **kwargs): + def get_or_none(self, **kwargs: dict) -> NodeSet: """ Retrieve a related node with the matching node properties or return None. 
@@ -308,7 +318,7 @@ def get_or_none(self, **kwargs): """ return NodeSet(self._new_traversal()).get_or_none(**kwargs) - def filter(self, *args, **kwargs): + def filter(self, *args: Any, **kwargs: dict) -> "BaseSet": """ Retrieve related nodes matching the provided properties. @@ -318,7 +328,7 @@ def filter(self, *args, **kwargs): """ return NodeSet(self._new_traversal()).filter(*args, **kwargs) - def order_by(self, *props): + def order_by(self, *props: Any) -> "BaseSet": """ Order related nodes by specified properties @@ -327,7 +337,7 @@ def order_by(self, *props): """ return NodeSet(self._new_traversal()).order_by(*props) - def exclude(self, *args, **kwargs): + def exclude(self, *args: Any, **kwargs: dict) -> "BaseSet": """ Exclude nodes that match the provided properties. @@ -337,7 +347,7 @@ def exclude(self, *args, **kwargs): """ return NodeSet(self._new_traversal()).exclude(*args, **kwargs) - def is_connected(self, node): + def is_connected(self, node: "StructuredNode") -> bool: """ Check if a node is connected with this relationship type :param node: @@ -345,7 +355,7 @@ def is_connected(self, node): """ return self._new_traversal().__contains__(node) - def single(self): + def single(self) -> Optional["StructuredNode"]: """ Get a single related node or none. @@ -355,9 +365,9 @@ def single(self): rels = self return rels[0] except IndexError: - pass + return None - def match(self, **kwargs): + def match(self, **kwargs: dict) -> NodeSet: """ Return set of nodes who's relationship properties match supplied args @@ -366,7 +376,7 @@ def match(self, **kwargs): """ return self._new_traversal().match(**kwargs) - def all(self): + def all(self) -> list: """ Return all related nodes. 
@@ -374,34 +384,34 @@ def all(self): """ return self._new_traversal().all() - def __iter__(self): + def __iter__(self) -> Iterator: return self._new_traversal().__iter__() - def __len__(self): + def __len__(self) -> int: return self._new_traversal().__len__() - def __bool__(self): + def __bool__(self) -> bool: return self._new_traversal().__bool__() - def __nonzero__(self): + def __nonzero__(self) -> bool: return self._new_traversal().__nonzero__() - def __contains__(self, obj): + def __contains__(self, obj: Any) -> bool: return self._new_traversal().__contains__(obj) - def __getitem__(self, key): + def __getitem__(self, key: Union[int, slice]) -> Any: return self._new_traversal().__getitem__(key) class RelationshipDefinition: def __init__( self, - relation_type, - cls_name, - direction, - manager=RelationshipManager, - model=None, - ): + relation_type: str, + cls_name: str, + direction: int, + manager: type[RelationshipManager] = RelationshipManager, + model: Optional[StructuredRel] = None, + ) -> None: self._validate_class(cls_name, model) current_frame = inspect.currentframe() @@ -452,14 +462,16 @@ def __init__( # If the mapping does not exist then it is simply created. 
db._NODE_CLASS_REGISTRY[label_set] = model - def _validate_class(self, cls_name, model): + def _validate_class( + self, cls_name: str, model: Optional[StructuredRel] = None + ) -> None: if not isinstance(cls_name, (str, object)): raise ValueError("Expected class name or class got " + repr(cls_name)) if model and not issubclass(model, (StructuredRel,)): raise ValueError("model must be a StructuredRel") - def lookup_node_class(self): + def lookup_node_class(self) -> None: if not isinstance(self._raw_class, str): self.definition["node_class"] = self._raw_class else: @@ -496,7 +508,7 @@ def lookup_node_class(self): module = import_module(namespace).__name__ self.definition["node_class"] = getattr(sys.modules[module], name) - def build_manager(self, source, name): + def build_manager(self, source: "StructuredNode", name: str) -> RelationshipManager: self.lookup_node_class() return self.manager(source, name, self.definition) @@ -512,11 +524,11 @@ class ZeroOrMore(RelationshipManager): class RelationshipTo(RelationshipDefinition): def __init__( self, - cls_name, - relation_type, - cardinality=ZeroOrMore, - model=None, - ): + cls_name: str, + relation_type: str, + cardinality: type[RelationshipManager] = ZeroOrMore, + model: Optional[StructuredRel] = None, + ) -> None: super().__init__( relation_type, cls_name, OUTGOING, manager=cardinality, model=model ) @@ -525,11 +537,11 @@ def __init__( class RelationshipFrom(RelationshipDefinition): def __init__( self, - cls_name, - relation_type, - cardinality=ZeroOrMore, - model=None, - ): + cls_name: str, + relation_type: str, + cardinality: type[RelationshipManager] = ZeroOrMore, + model: Optional[StructuredRel] = None, + ) -> None: super().__init__( relation_type, cls_name, INCOMING, manager=cardinality, model=model ) @@ -538,11 +550,11 @@ def __init__( class Relationship(RelationshipDefinition): def __init__( self, - cls_name, - relation_type, - cardinality=ZeroOrMore, - model=None, - ): + cls_name: str, + relation_type: str, + 
cardinality: type[RelationshipManager] = ZeroOrMore, + model: Optional[StructuredRel] = None, + ) -> None: super().__init__( relation_type, cls_name, EITHER, manager=cardinality, model=model ) From 72d8029aee87ff47534d2449c3d2764574b454ba Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Fri, 13 Dec 2024 15:56:32 +0100 Subject: [PATCH 17/20] Fix parallel runtime test --- test/async_/test_match_api.py | 2 +- test/sync_/test_match_api.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 70c7f351..d9d29b60 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -1174,7 +1174,7 @@ async def test_async_iterator(): def assert_last_query_startswith(mock_func, query) -> bool: - return mock_func.call_args_list[-1].args[0].startswith(query) + return mock_func.call_args_list[-1].kwargs["query"].startswith(query) @mark_async_test diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 94465db2..15778635 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -1158,7 +1158,7 @@ def test_async_iterator(): def assert_last_query_startswith(mock_func, query) -> bool: - return mock_func.call_args_list[-1].args[0].startswith(query) + return mock_func.call_args_list[-1].kwargs["query"].startswith(query) @mark_sync_test From 3c1e043af41cad56e4229260dfdc49004b6c2751 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Fri, 13 Dec 2024 16:07:02 +0100 Subject: [PATCH 18/20] AsyncRelationshipManager.get and .get_or_none should be async #847 --- neomodel/async_/relationship_manager.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/neomodel/async_/relationship_manager.py b/neomodel/async_/relationship_manager.py index 1fe7078e..7c7ea350 100644 --- a/neomodel/async_/relationship_manager.py +++ b/neomodel/async_/relationship_manager.py @@ -323,23 +323,23 @@ def _new_traversal(self) -> AsyncTraversal: return 
AsyncTraversal(self.source, self.name, self.definition) # The methods below simply proxy the match engine. - def get(self, **kwargs: Any) -> AsyncNodeSet: + async def get(self, **kwargs: Any) -> AsyncNodeSet: """ Retrieve a related node with the matching node properties. :param kwargs: same syntax as `NodeSet.filter()` :return: node """ - return AsyncNodeSet(self._new_traversal()).get(**kwargs) + return await AsyncNodeSet(self._new_traversal()).get(**kwargs) - def get_or_none(self, **kwargs: dict) -> AsyncNodeSet: + async def get_or_none(self, **kwargs: dict) -> AsyncNodeSet: """ Retrieve a related node with the matching node properties or return None. :param kwargs: same syntax as `NodeSet.filter()` :return: node """ - return AsyncNodeSet(self._new_traversal()).get_or_none(**kwargs) + return await AsyncNodeSet(self._new_traversal()).get_or_none(**kwargs) def filter(self, *args: Any, **kwargs: dict) -> "AsyncBaseSet": """ From 316a9fd9c3d9cb2ace6ee4e26d6212409e5958fc Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Fri, 13 Dec 2024 16:15:07 +0100 Subject: [PATCH 19/20] mypy ignore await method --- neomodel/async_/match.py | 4 ++-- neomodel/async_/relationship_manager.py | 2 +- neomodel/sync_/match.py | 4 ++-- neomodel/sync_/relationship_manager.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 570f63e6..e46b37d0 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -1306,7 +1306,7 @@ def __init__(self, source: Any) -> None: self._intermediate_transforms: list = [] def __await__(self) -> Any: - return self.all().__await__() + return self.all().__await__() # type: ignore[attr-defined] async def _get( self, limit: TOptional[int] = None, lazy: bool = False, **kwargs: dict[str, Any] @@ -1696,7 +1696,7 @@ class AsyncTraversal(AsyncBaseSet): filters: list def __await__(self) -> Any: - return self.all().__await__() + return self.all().__await__() # type: 
ignore[attr-defined] def __init__(self, source: Any, name: str, definition: dict) -> None: """ diff --git a/neomodel/async_/relationship_manager.py b/neomodel/async_/relationship_manager.py index 7c7ea350..d9227ba0 100644 --- a/neomodel/async_/relationship_manager.py +++ b/neomodel/async_/relationship_manager.py @@ -75,7 +75,7 @@ def __str__(self) -> str: return f"{self.description} in {direction} direction of type {self.definition['relation_type']} on node ({self.source.element_id}) of class '{self.source_class.__name__}'" def __await__(self) -> Any: - return self.all().__await__() + return self.all().__await__() # type: ignore[attr-defined] def _check_node(self, obj: type["AsyncStructuredNode"]) -> None: """check for valid node i.e correct class and is saved""" diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index fe7dc33b..1f5d7087 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -1302,7 +1302,7 @@ def __init__(self, source: Any) -> None: self._intermediate_transforms: list = [] def __await__(self) -> Any: - return self.all().__await__() + return self.all().__await__() # type: ignore[attr-defined] def _get( self, limit: TOptional[int] = None, lazy: bool = False, **kwargs: dict[str, Any] @@ -1690,7 +1690,7 @@ class Traversal(BaseSet): filters: list def __await__(self) -> Any: - return self.all().__await__() + return self.all().__await__() # type: ignore[attr-defined] def __init__(self, source: Any, name: str, definition: dict) -> None: """ diff --git a/neomodel/sync_/relationship_manager.py b/neomodel/sync_/relationship_manager.py index 25067f1c..cba9fae8 100644 --- a/neomodel/sync_/relationship_manager.py +++ b/neomodel/sync_/relationship_manager.py @@ -70,7 +70,7 @@ def __str__(self) -> str: return f"{self.description} in {direction} direction of type {self.definition['relation_type']} on node ({self.source.element_id}) of class '{self.source_class.__name__}'" def __await__(self) -> Any: - return self.all().__await__() + 
return self.all().__await__() # type: ignore[attr-defined] def _check_node(self, obj: type["StructuredNode"]) -> None: """check for valid node i.e correct class and is saved""" From 3c7c701cbfb3a40d1513618a2ab1d9e8e80155ea Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Fri, 13 Dec 2024 16:36:24 +0100 Subject: [PATCH 20/20] Fix changelog --- Changelog | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Changelog b/Changelog index a93e806e..5fa85e80 100644 --- a/Changelog +++ b/Changelog @@ -1,7 +1,8 @@ -Vesion 5.4.2 2024-12 +Version 5.4.2 2024-12 * Add support for Neo4j Rust driver extension : pip install neomodel[rust-driver-ext] * Add initial_context parameter to subqueries * NodeNameResolver can call self to reference top-level node +* Housekeeping : implementing mypy for static typing Version 5.4.1 2024-11 * Add support for Cypher parallel runtime