From b6c9d48c63022547adf9c4f5272071e8742dacef Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Thu, 23 Nov 2023 08:57:50 +0100 Subject: [PATCH 01/73] Propagating async from AsyncDriver ; refactoring Database --- neomodel/__init__.py | 40 +- neomodel/_async/__init__.py | 1 + neomodel/_async/core.py | 1505 +++++++++++++++++ neomodel/config.py | 2 +- neomodel/contrib/__init__.py | 2 +- neomodel/contrib/semi_structured.py | 6 +- neomodel/core.py | 784 --------- neomodel/integration/numpy.py | 2 +- neomodel/integration/pandas.py | 2 +- neomodel/match.py | 36 +- neomodel/path.py | 23 +- neomodel/properties.py | 4 +- neomodel/relationship.py | 24 +- neomodel/relationship_manager.py | 32 +- neomodel/scripts/neomodel_inspect_database.py | 32 +- neomodel/scripts/neomodel_install_labels.py | 6 +- neomodel/scripts/neomodel_remove_labels.py | 6 +- neomodel/util.py | 626 ------- pyproject.toml | 1 + test/async_/conftest.py | 48 + test/conftest.py | 47 - test/test_alias.py | 8 +- test/test_batch.py | 38 +- test/test_cardinality.py | 42 +- test/test_connection.py | 51 +- test/test_contrib/test_semi_structured.py | 6 +- test/test_contrib/test_spatial_properties.py | 25 +- test/test_cypher.py | 76 +- test/test_database_management.py | 47 +- test/test_dbms_awareness.py | 20 +- test/test_driver_options.py | 34 +- test/test_exceptions.py | 4 +- test/test_hooks.py | 8 +- test/test_indexing.py | 31 +- test/test_issue112.py | 8 +- test/test_issue283.py | 150 +- test/test_issue600.py | 28 +- test/test_label_drop.py | 20 +- test/test_label_install.py | 88 +- test/test_match_api.py | 110 +- test/test_migration_neo4j_5.py | 10 +- test/test_models.py | 117 +- test/test_multiprocessing.py | 8 +- test/test_paths.py | 72 +- test/test_properties.py | 84 +- test/test_relationship_models.py | 32 +- test/test_relationships.py | 58 +- test/test_relative_relationships.py | 11 +- test/test_scripts.py | 38 +- test/test_transactions.py | 96 +- 50 files changed, 2341 insertions(+), 2208 deletions(-) create mode 100644 neomodel/_async/__init__.py create mode 100644 neomodel/_async/core.py delete mode 100644 neomodel/core.py create mode 100644 test/async_/conftest.py diff --git a/neomodel/__init__.py b/neomodel/__init__.py index 23e0142a..013104f6 100644 --- a/neomodel/__init__.py +++ b/neomodel/__init__.py @@ -1,21 +1,21 @@ # pep8: noqa -import pkg_resources - +# TODO : Check imports here +from neomodel._async.core import ( + StructuredNodeAsync, + change_neo4j_password_async, + clear_neo4j_database_async, + drop_constraints_async, + drop_indexes_async, + install_all_labels_async, + install_labels_async, + remove_all_labels_async, +) +from neomodel.cardinality import One, OneOrMore, ZeroOrMore, ZeroOrOne from neomodel.exceptions import * from neomodel.match import EITHER, INCOMING, OUTGOING, NodeSet, Traversal from neomodel.match_q import Q # noqa -from neomodel.relationship_manager import ( - NotConnected, - Relationship, - RelationshipDefinition, - RelationshipFrom, - RelationshipManager, - RelationshipTo, -) - -from .cardinality import One, OneOrMore, ZeroOrMore, ZeroOrOne -from .core import * -from .properties import ( +from neomodel.path import NeomodelPath +from neomodel.properties import ( AliasProperty, ArrayProperty, BooleanProperty, @@ -31,9 +31,15 @@ StringProperty, UniqueIdProperty, ) -from .relationship import StructuredRel -from .util import change_neo4j_password, clear_neo4j_database -from .path import NeomodelPath +from neomodel.relationship import StructuredRel +from neomodel.relationship_manager import ( + NotConnected, + Relationship, + RelationshipDefinition, + RelationshipFrom, + RelationshipManager, + RelationshipTo, +) __author__ = "Robin Edwards" __email__ = "robin.ge@gmail.com" diff --git a/neomodel/_async/__init__.py b/neomodel/_async/__init__.py new file mode 100644 index 00000000..95bbd58a --- /dev/null +++ b/neomodel/_async/__init__.py @@ -0,0 +1 @@ +# from neomodel._async.core import adb diff --git a/neomodel/_async/core.py b/neomodel/_async/core.py new file mode 100644 index 00000000..1df0e771 --- /dev/null +++ b/neomodel/_async/core.py @@ -0,0 +1,1505 @@ +import logging +import os +import sys +import time +import warnings +from itertools import combinations +from threading import local +from typing import Optional, Sequence, Tuple +from urllib.parse import quote, unquote, urlparse + +from neo4j import ( + DEFAULT_DATABASE, + AsyncDriver, + AsyncGraphDatabase, + AsyncResult, + AsyncSession, + AsyncTransaction, + basic_auth, +) +from neo4j.api import Bookmarks +from neo4j.exceptions import ClientError, ServiceUnavailable, SessionExpired +from neo4j.graph import Node, Path, Relationship + +from neomodel import config +from neomodel.exceptions import ( + ConstraintValidationFailed, + DoesNotExist, + FeatureNotSupported, + NodeClassAlreadyDefined, + NodeClassNotDefined, + RelationshipClassNotDefined, + UniqueProperty, +) +from neomodel.hooks import hooks +from neomodel.properties import Property, PropertyManager +from neomodel.util import ( + _get_node_properties, + _UnsavedNode, + classproperty, + deprecated, + version_tag_to_integer, +) + +logger = logging.getLogger(__name__) + +RULE_ALREADY_EXISTS = "Neo.ClientError.Schema.EquivalentSchemaRuleAlreadyExists" +INDEX_ALREADY_EXISTS = "Neo.ClientError.Schema.IndexAlreadyExists" +CONSTRAINT_ALREADY_EXISTS = "Neo.ClientError.Schema.ConstraintAlreadyExists" +STREAMING_WARNING = "streaming is not supported by bolt, please remove the kwarg" + + +# make sure the connection url has been set prior to executing the wrapped function +def ensure_connection(func): + async def wrapper(self, *args, **kwargs): + # Sort out where to find url + if hasattr(self, "db"): + _db = self.db + else: + _db = self + + if not _db.driver: + if hasattr(config, "DRIVER") and config.DRIVER: + await _db.set_connection_async(driver=config.DRIVER) + elif config.DATABASE_URL: + await _db.set_connection_async(url=config.DATABASE_URL) + + return func(self, *args, **kwargs) + + return wrapper + + +class AsyncDatabase(local): + """ + A singleton object via which all operations from neomodel to the Neo4j backend are handled with. + """ + + _NODE_CLASS_REGISTRY = {} + + def __init__(self): + self._active_transaction = None + self.url = None + self.driver = None + self._session = None + self._pid = None + self._database_name = DEFAULT_DATABASE + self.protocol_version = None + self._database_version = None + self._database_edition = None + self.impersonated_user = None + + async def set_connection_async(self, url: str = None, driver: AsyncDriver = None): + """ + Sets the connection up and relevant internal. This can be done using a Neo4j URL or a driver instance. + + Args: + url (str): Optionally, Neo4j URL in the form protocol://username:password@hostname:port/dbname. + When provided, a Neo4j driver instance will be created by neomodel. + + driver (neo4j.Driver): Optionally, a pre-created driver instance. + When provided, neomodel will not create a driver instance but use this one instead. + """ + if driver: + self.driver = driver + if hasattr(config, "DATABASE_NAME") and config.DATABASE_NAME: + self._database_name = config.DATABASE_NAME + elif url: + self._parse_driver_from_url(url=url) + + self._pid = os.getpid() + self._active_transaction = None + # Set to default database if it hasn't been set before + if self._database_name is None: + self._database_name = DEFAULT_DATABASE + + # Getting the information about the database version requires a connection to the database + self._database_version = None + self._database_edition = None + self._update_database_version_async() + + def _parse_driver_from_url(self, url: str) -> None: + """Parse the driver information from the given URL and initialize the driver. + + Args: + url (str): The URL to parse. + + Raises: + ValueError: If the URL format is not as expected. + + Returns: + None - Sets the driver and database_name as class properties + """ + p_start = url.replace(":", "", 1).find(":") + 2 + p_end = url.rfind("@") + password = url[p_start:p_end] + url = url.replace(password, quote(password)) + parsed_url = urlparse(url) + + valid_schemas = [ + "bolt", + "bolt+s", + "bolt+ssc", + "bolt+routing", + "neo4j", + "neo4j+s", + "neo4j+ssc", + ] + + if parsed_url.netloc.find("@") > -1 and parsed_url.scheme in valid_schemas: + credentials, hostname = parsed_url.netloc.rsplit("@", 1) + username, password = credentials.split(":") + password = unquote(password) + database_name = parsed_url.path.strip("/") + else: + raise ValueError( + f"Expecting url format: bolt://user:password@localhost:7687 got {url}" + ) + + options = { + "auth": basic_auth(username, password), + "connection_acquisition_timeout": config.CONNECTION_ACQUISITION_TIMEOUT, + "connection_timeout": config.CONNECTION_TIMEOUT, + "keep_alive": config.KEEP_ALIVE, + "max_connection_lifetime": config.MAX_CONNECTION_LIFETIME, + "max_connection_pool_size": config.MAX_CONNECTION_POOL_SIZE, + "max_transaction_retry_time": config.MAX_TRANSACTION_RETRY_TIME, + "resolver": config.RESOLVER, + "user_agent": config.USER_AGENT, + } + + if "+s" not in parsed_url.scheme: + options["encrypted"] = config.ENCRYPTED + options["trusted_certificates"] = config.TRUSTED_CERTIFICATES + + self.driver = AsyncGraphDatabase.driver( + parsed_url.scheme + "://" + hostname, **options + ) + self.url = url + # The database name can be provided through the url or the config + if database_name == "": + if hasattr(config, "DATABASE_NAME") and config.DATABASE_NAME: + self._database_name = config.DATABASE_NAME + else: + self._database_name = database_name + + async def close_connection_async(self): + """ + Closes the currently open driver. + The driver should always be closed at the end of the application's lifecyle. + """ + self._database_version = None + self._database_edition = None + self._database_name = None + await self.driver.close() + self.driver = None + + @property + def database_version(self): + if self._database_version is None: + self._update_database_version_async() + + return self._database_version + + @property + def database_edition(self): + if self._database_edition is None: + self._update_database_version_async() + + return self._database_edition + + @property + def transaction(self): + """ + Returns the current transaction object + """ + return TransactionProxyAsync(self) + + @property + def write_transaction(self): + return TransactionProxyAsync(self, access_mode="WRITE") + + @property + def read_transaction(self): + return TransactionProxyAsync(self, access_mode="READ") + + def impersonate(self, user: str) -> "ImpersonationHandler": + """All queries executed within this context manager will be executed as impersonated user + + Args: + user (str): User to impersonate + + Returns: + ImpersonationHandler: Context manager to set/unset the user to impersonate + """ + if self.database_edition != "enterprise": + raise FeatureNotSupported( + "Impersonation is only available in Neo4j Enterprise edition" + ) + return ImpersonationHandler(self, impersonated_user=user) + + @ensure_connection + async def begin_async(self, access_mode=None, **parameters): + """ + Begins a new transaction. Raises SystemError if a transaction is already active. + """ + if ( + hasattr(self, "_active_transaction") + and self._active_transaction is not None + ): + raise SystemError("Transaction in progress") + self._session: AsyncSession = await self.driver.session( + default_access_mode=access_mode, + database=self._database_name, + impersonated_user=self.impersonated_user, + **parameters, + ) + self._active_transaction: AsyncTransaction = ( + await self._session.begin_transaction() + ) + + @ensure_connection + async def commit_async(self): + """ + Commits the current transaction and closes its session + + :return: last_bookmarks + """ + try: + await self._active_transaction.commit() + last_bookmarks: Bookmarks = await self._session.last_bookmarks() + finally: + # In case when something went wrong during + # committing changes to the database + # we have to close an active transaction and session. + await self._active_transaction.close() + await self._session.close() + self._active_transaction = None + self._session = None + + return last_bookmarks + + @ensure_connection + async def rollback_async(self): + """ + Rolls back the current transaction and closes its session + """ + try: + await self._active_transaction.rollback() + finally: + # In case when something went wrong during changes rollback, + # we have to close an active transaction and session + await self._active_transaction.close() + await self._session.close() + self._active_transaction = None + self._session = None + + async def _update_database_version_async(self): + """ + Updates the database server information when it is required + """ + try: + results = await self.cypher_query_async( + "CALL dbms.components() yield versions, edition return versions[0], edition" + ) + self._database_version = results[0][0][0] + self._database_edition = results[0][0][1] + except ServiceUnavailable: + # The database server is not running yet + pass + + def _object_resolution(self, object_to_resolve): + """ + Performs in place automatic object resolution on a result + returned by cypher_query. + + The function operates recursively in order to be able to resolve Nodes + within nested list structures and Path objects. Not meant to be called + directly, used primarily by _result_resolution. + + :param object_to_resolve: A result as returned by cypher_query. + :type Any: + + :return: An instantiated object. + """ + # Below is the original comment that came with the code extracted in + # this method. It is not very clear but I decided to keep it just in + # case + # + # + # For some reason, while the type of `a_result_attribute[1]` + # as reported by the neo4j driver is `Node` for Node-type data + # retrieved from the database. + # When the retrieved data are Relationship-Type, + # the returned type is `abc.[REL_LABEL]` which is however + # a descendant of Relationship. + # Consequently, the type checking was changed for both + # Node, Relationship objects + if isinstance(object_to_resolve, Node): + return self._NODE_CLASS_REGISTRY[ + frozenset(object_to_resolve.labels) + ].inflate(object_to_resolve) + + if isinstance(object_to_resolve, Relationship): + rel_type = frozenset([object_to_resolve.type]) + return self._NODE_CLASS_REGISTRY[rel_type].inflate(object_to_resolve) + + if isinstance(object_to_resolve, Path): + from neomodel.path import NeomodelPath + + return NeomodelPath(object_to_resolve) + + if isinstance(object_to_resolve, list): + return self._result_resolution([object_to_resolve]) + + return object_to_resolve + + def _result_resolution(self, result_list): + """ + Performs in place automatic object resolution on a set of results + returned by cypher_query. + + The function operates recursively in order to be able to resolve Nodes + within nested list structures. Not meant to be called directly, + used primarily by cypher_query. + + :param result_list: A list of results as returned by cypher_query. + :type list: + + :return: A list of instantiated objects. + """ + + # Object resolution occurs in-place + for a_result_item in enumerate(result_list): + for a_result_attribute in enumerate(a_result_item[1]): + try: + # Primitive types should remain primitive types, + # Nodes to be resolved to native objects + resolved_object = a_result_attribute[1] + + resolved_object = self._object_resolution(resolved_object) + + result_list[a_result_item[0]][ + a_result_attribute[0] + ] = resolved_object + + except KeyError as exc: + # Not being able to match the label set of a node with a known object results + # in a KeyError in the internal dictionary used for resolution. If it is impossible + # to match, then raise an exception with more details about the error. + if isinstance(a_result_attribute[1], Node): + raise NodeClassNotDefined( + a_result_attribute[1], self._NODE_CLASS_REGISTRY + ) from exc + + if isinstance(a_result_attribute[1], Relationship): + raise RelationshipClassNotDefined( + a_result_attribute[1], self._NODE_CLASS_REGISTRY + ) from exc + + return result_list + + @ensure_connection + async def cypher_query_async( + self, + query, + params=None, + handle_unique=True, + retry_on_session_expire=False, + resolve_objects=False, + ) -> (list[list], Tuple[str, ...]): + """ + Runs a query on the database and returns a list of results and their headers. + + :param query: A CYPHER query + :type: str + :param params: Dictionary of parameters + :type: dict + :param handle_unique: Whether or not to raise UniqueProperty exception on Cypher's ConstraintValidation errors + :type: bool + :param retry_on_session_expire: Whether or not to attempt the same query again if the transaction has expired. + If you use neomodel with your own driver, you must catch SessionExpired exceptions yourself and retry with a new driver instance. + :type: bool + :param resolve_objects: Whether to attempt to resolve the returned nodes to data model objects automatically + :type: bool + + :return: A tuple containing a list of results and a tuple of headers. + """ + + if self._active_transaction: + # Use current session is a transaction is currently active + results, meta = await self._run_cypher_query_async( + self._active_transaction, + query, + params, + handle_unique, + retry_on_session_expire, + resolve_objects, + ) + else: + # Otherwise create a new session in a with to dispose of it after it has been run + with await self.driver.session( + database=self._database_name, impersonated_user=self.impersonated_user + ) as session: + results, meta = await self._run_cypher_query_async( + session, + query, + params, + handle_unique, + retry_on_session_expire, + resolve_objects, + ) + + return results, meta + + async def _run_cypher_query_async( + self, + session: AsyncSession, + query, + params, + handle_unique, + retry_on_session_expire, + resolve_objects, + ) -> (list[list], Tuple[str, ...]): + try: + # Retrieve the data + start = time.time() + response: AsyncResult = await session.run(query, params) + results, meta = [list(r.values()) for r in response], response.keys() + end = time.time() + + if resolve_objects: + # Do any automatic resolution required + results = self._result_resolution(results) + + except ClientError as e: + if e.code == "Neo.ClientError.Schema.ConstraintValidationFailed": + if "already exists with label" in e.message and handle_unique: + raise UniqueProperty(e.message) from e + + raise ConstraintValidationFailed(e.message) from e + exc_info = sys.exc_info() + raise exc_info[1].with_traceback(exc_info[2]) + except SessionExpired: + if retry_on_session_expire: + await self.set_connection_async(url=self.url) + return await self.cypher_query_async( + query=query, + params=params, + handle_unique=handle_unique, + retry_on_session_expire=False, + ) + raise + + tte = end - start + if os.environ.get("NEOMODEL_CYPHER_DEBUG", False) and tte > float( + os.environ.get("NEOMODEL_SLOW_QUERIES", 0) + ): + logger.debug( + "query: " + + query + + "\nparams: " + + repr(params) + + f"\ntook: {tte:.2g}s\n" + ) + + return results, meta + + def get_id_method(self) -> str: + if self.database_version.startswith("4"): + return "id" + else: + return "elementId" + + async def list_indexes_async(self, exclude_token_lookup=False) -> Sequence[dict]: + """Returns all indexes existing in the database + + Arguments: + exclude_token_lookup[bool]: Exclude automatically create token lookup indexes + + Returns: + Sequence[dict]: List of dictionaries, each entry being an index definition + """ + indexes, meta_indexes = await self.cypher_query_async("SHOW INDEXES") + indexes_as_dict = [dict(zip(meta_indexes, row)) for row in indexes] + + if exclude_token_lookup: + indexes_as_dict = [ + obj for obj in indexes_as_dict if obj["type"] != "LOOKUP" + ] + + return indexes_as_dict + + async def list_constraints_async(self) -> Sequence[dict]: + """Returns all constraints existing in the database + + Returns: + Sequence[dict]: List of dictionaries, each entry being a constraint definition + """ + constraints, meta_constraints = await self.cypher_query_async( + "SHOW CONSTRAINTS" + ) + constraints_as_dict = [dict(zip(meta_constraints, row)) for row in constraints] + + return constraints_as_dict + + def version_is_higher_than(self, version_tag: str) -> bool: + """Returns true if the database version is higher or equal to a given tag + + Args: + version_tag (str): The version to compare against + + Returns: + bool: True if the database version is higher or equal to the given version + """ + return version_tag_to_integer(self.database_version) >= version_tag_to_integer( + version_tag + ) + + def edition_is_enterprise(self) -> bool: + """Returns true if the database edition is enterprise + + Returns: + bool: True if the database edition is enterprise + """ + return self.database_edition == "enterprise" + + async def change_neo4j_password_async(self, user, new_password): + await self.cypher_query_async( + f"ALTER USER {user} SET PASSWORD '{new_password}'" + ) + + async def clear_neo4j_database_async( + self, clear_constraints=False, clear_indexes=False + ): + await self.cypher_query_async( + """ + MATCH (a) + CALL { WITH a DETACH DELETE a } + IN TRANSACTIONS OF 5000 rows + """ + ) + if clear_constraints: + await drop_constraints_async() + if clear_indexes: + await drop_indexes_async() + + async def drop_constraints_async(self, quiet=True, stdout=None): + """ + Discover and drop all constraints. + + :type: bool + :return: None + """ + if not stdout or stdout is None: + stdout = sys.stdout + + results, meta = await self.cypher_query_async("SHOW CONSTRAINTS") + + results_as_dict = [dict(zip(meta, row)) for row in results] + for constraint in results_as_dict: + await self.cypher_query_async("DROP CONSTRAINT " + constraint["name"]) + if not quiet: + stdout.write( + ( + " - Dropping unique constraint and index" + f" on label {constraint['labelsOrTypes'][0]}" + f" with property {constraint['properties'][0]}.\n" + ) + ) + if not quiet: + stdout.write("\n") + + async def drop_indexes_async(self, quiet=True, stdout=None): + """ + Discover and drop all indexes, except the automatically created token lookup indexes. + + :type: bool + :return: None + """ + if not stdout or stdout is None: + stdout = sys.stdout + + indexes = await self.list_indexes_async(exclude_token_lookup=True) + for index in indexes: + await self.cypher_query_async("DROP INDEX " + index["name"]) + if not quiet: + stdout.write( + f' - Dropping index on labels {",".join(index["labelsOrTypes"])} with properties {",".join(index["properties"])}.\n' + ) + if not quiet: + stdout.write("\n") + + async def remove_all_labels_async(self, stdout=None): + """ + Calls functions for dropping constraints and indexes. + + :param stdout: output stream + :return: None + """ + + if not stdout: + stdout = sys.stdout + + stdout.write("Dropping constraints...\n") + await self.drop_constraints_async(quiet=False, stdout=stdout) + + stdout.write("Dropping indexes...\n") + await self.drop_indexes_async(quiet=False, stdout=stdout) + + async def install_all_labels_async(self, stdout=None): + """ + Discover all subclasses of StructuredNode in your application and execute install_labels on each. + Note: code must be loaded (imported) in order for a class to be discovered. + + :param stdout: output stream + :return: None + """ + + if not stdout or stdout is None: + stdout = sys.stdout + + def subsub(cls): # recursively return all subclasses + subclasses = cls.__subclasses__() + if not subclasses: # base case: no more subclasses + return [] + return subclasses + [g for s in cls.__subclasses__() for g in subsub(s)] + + stdout.write("Setting up indexes and constraints...\n\n") + + i = 0 + for cls in subsub(StructuredNodeAsync): + stdout.write(f"Found {cls.__module__}.{cls.__name__}\n") + await install_labels_async(cls, quiet=False, stdout=stdout) + i += 1 + + if i: + stdout.write("\n") + + stdout.write(f"Finished {i} classes.\n") + + async def install_labels_async(self, cls, quiet=True, stdout=None): + """ + Setup labels with indexes and constraints for a given class + + :param cls: StructuredNode class + :type: class + :param quiet: (default true) enable standard output + :param stdout: stdout stream + :type: bool + :return: None + """ + if not stdout or stdout is None: + stdout = sys.stdout + + if not hasattr(cls, "__label__"): + if not quiet: + stdout.write( + f" ! Skipping class {cls.__module__}.{cls.__name__} is abstract\n" + ) + return + + for name, property in cls.defined_properties(aliases=False, rels=False).items(): + await self._install_node_async(cls, name, property, quiet, stdout) + + for _, relationship in cls.defined_properties( + aliases=False, rels=True, properties=False + ).items(): + await self._install_relationship_async(cls, relationship, quiet, stdout) + + async def _create_node_index_async(self, label: str, property_name: str, stdout): + try: + await self.cypher_query_async( + f"CREATE INDEX index_{label}_{property_name} FOR (n:{label}) ON (n.{property_name}); " + ) + except ClientError as e: + if e.code in ( + RULE_ALREADY_EXISTS, + INDEX_ALREADY_EXISTS, + ): + stdout.write(f"{str(e)}\n") + else: + raise + + async def _create_node_constraint_async( + self, label: str, property_name: str, stdout + ): + try: + await self.cypher_query_async( + f"""CREATE CONSTRAINT constraint_unique_{label}_{property_name} + FOR (n:{label}) REQUIRE n.{property_name} IS UNIQUE""" + ) + except ClientError as e: + if e.code in ( + RULE_ALREADY_EXISTS, + CONSTRAINT_ALREADY_EXISTS, + ): + stdout.write(f"{str(e)}\n") + else: + raise + + async def _create_relationship_index_async( + self, relationship_type: str, property_name: str, stdout + ): + try: + await self.cypher_query_async( + f"CREATE INDEX index_{relationship_type}_{property_name} FOR ()-[r:{relationship_type}]-() ON (r.{property_name}); " + ) + except ClientError as e: + if e.code in ( + RULE_ALREADY_EXISTS, + INDEX_ALREADY_EXISTS, + ): + stdout.write(f"{str(e)}\n") + else: + raise + + async def _create_relationship_constraint_async( + self, relationship_type: str, property_name: str, stdout + ): + if self.version_is_higher_than("5.7"): + try: + await self.cypher_query_async( + f"""CREATE CONSTRAINT constraint_unique_{relationship_type}_{property_name} + FOR ()-[r:{relationship_type}]-() REQUIRE r.{property_name} IS UNIQUE""" + ) + except ClientError as e: + if e.code in ( + RULE_ALREADY_EXISTS, + CONSTRAINT_ALREADY_EXISTS, + ): + stdout.write(f"{str(e)}\n") + else: + raise + else: + raise FeatureNotSupported( + f"Unique indexes on relationships are not supported in Neo4j version {self.database_version}. Please upgrade to Neo4j 5.7 or higher." + ) + + async def _install_node_async(self, cls, name, property, quiet, stdout): + # Create indexes and constraints for node property + db_property = property.db_property or name + if property.index: + if not quiet: + stdout.write( + f" + Creating node index {name} on label {cls.__label__} for class {cls.__module__}.{cls.__name__}\n" + ) + await self._create_node_index_async( + label=cls.__label__, property_name=db_property, stdout=stdout + ) + + elif property.unique_index: + if not quiet: + stdout.write( + f" + Creating node unique constraint for {name} on label {cls.__label__} for class {cls.__module__}.{cls.__name__}\n" + ) + await self._create_node_constraint_async( + label=cls.__label__, property_name=db_property, stdout=stdout + ) + + async def _install_relationship_async(self, cls, relationship, quiet, stdout): + # Create indexes and constraints for relationship property + relationship_cls = relationship.definition["model"] + if relationship_cls is not None: + relationship_type = relationship.definition["relation_type"] + for prop_name, property in relationship_cls.defined_properties( + aliases=False, rels=False + ).items(): + db_property = property.db_property or prop_name + if property.index: + if not quiet: + stdout.write( + f" + Creating relationship index {prop_name} on relationship type {relationship_type} for relationship model {cls.__module__}.{relationship_cls.__name__}\n" + ) + await self._create_relationship_index_async( + relationship_type=relationship_type, + property_name=db_property, + stdout=stdout, + ) + elif property.unique_index: + if not quiet: + stdout.write( + f" + Creating relationship unique constraint for {prop_name} on relationship type {relationship_type} for relationship model {cls.__module__}.{relationship_cls.__name__}\n" + ) + await self._create_relationship_constraint_async( + relationship_type=relationship_type, + property_name=db_property, + stdout=stdout, + ) + + +# Create a singleton instance of the database object +adb = AsyncDatabase() + + +# Deprecated methods +async def change_neo4j_password_async(db, user, new_password): + deprecated( + """ + This method has been moved to the Database singleton (db for sync, adb for async). + Please use adb.change_neo4j_password_async(user, new_password) instead. + This direct call will be removed in an upcoming version. + """ + ) + await db.change_neo4j_password_async(user, new_password) + + +async def clear_neo4j_database_async(db, clear_constraints=False, clear_indexes=False): + deprecated( + """ + This method has been moved to the Database singleton (db for sync, adb for async). + Please use adb.clear_neo4j_database_async(clear_constraints, clear_indexes) instead. + This direct call will be removed in an upcoming version. + """ + ) + await db.clear_neo4j_database_async(clear_constraints, clear_indexes) + + +async def drop_constraints_async(quiet=True, stdout=None): + deprecated( + """ + This method has been moved to the Database singleton (db for sync, adb for async). + Please use adb.drop_constraints_async(quiet, stdout) instead. + This direct call will be removed in an upcoming version. + """ + ) + await adb.drop_constraints_async(quiet, stdout) + + +async def drop_indexes_async(quiet=True, stdout=None): + deprecated( + """ + This method has been moved to the Database singleton (db for sync, adb for async). + Please use adb.drop_indexes_async(quiet, stdout) instead. + This direct call will be removed in an upcoming version. + """ + ) + await adb.drop_indexes_async(quiet, stdout) + + +async def remove_all_labels_async(stdout=None): + deprecated( + """ + This method has been moved to the Database singleton (db for sync, adb for async). + Please use adb.remove_all_labels_async(stdout) instead. + This direct call will be removed in an upcoming version. + """ + ) + await adb.remove_all_labels_async(stdout) + + +async def install_labels_async(cls, quiet=True, stdout=None): + deprecated( + """ + This method has been moved to the Database singleton (db for sync, adb for async). + Please use adb.install_labels_async(cls, quiet, stdout) instead. + This direct call will be removed in an upcoming version. + """ + ) + await adb.install_labels_async(cls, quiet, stdout) + + +async def install_all_labels_async(stdout=None): + deprecated( + """ + This method has been moved to the Database singleton (db for sync, adb for async). + Please use adb.install_all_labels_async(stdout) instead. + This direct call will be removed in an upcoming version. + """ + ) + await adb.install_all_labels_async(stdout) + + +class TransactionProxyAsync: + bookmarks: Optional[Bookmarks] = None + + def __init__(self, db: AsyncDatabase, access_mode=None): + self.db = db + self.access_mode = access_mode + + @ensure_connection + async def __enter__(self): + await self.db.begin_async( + access_mode=self.access_mode, bookmarks=self.bookmarks + ) + self.bookmarks = None + return self + + async def __exit__(self, exc_type, exc_value, traceback): + if exc_value: + await self.db.rollback_async() + + if ( + exc_type is ClientError + and exc_value.code == "Neo.ClientError.Schema.ConstraintValidationFailed" + ): + raise UniqueProperty(exc_value.message) + + if not exc_value: + self.last_bookmark = await self.db.commit_async() + + def __call__(self, func): + def wrapper(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return wrapper + + @property + def with_bookmark(self): + return BookmarkingTransactionProxyAsync(self.db, self.access_mode) + + +class ImpersonationHandler: + def __init__(self, db: AsyncDatabase, impersonated_user: str): + self.db = db + self.impersonated_user = impersonated_user + + def __enter__(self): + self.db.impersonated_user = self.impersonated_user + return self + + def __exit__(self, exception_type, exception_value, exception_traceback): + self.db.impersonated_user = None + + print("\nException type:", exception_type) + print("\nException value:", exception_value) + print("\nTraceback:", exception_traceback) + + def __call__(self, func): + def wrapper(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return wrapper + + +class BookmarkingTransactionProxyAsync(TransactionProxyAsync): + def __call__(self, func): + def wrapper(*args, **kwargs): + self.bookmarks = kwargs.pop("bookmarks", None) + + with self: + result = func(*args, **kwargs) + self.last_bookmark = None + + return result, self.last_bookmark + + return wrapper + + +# TODO : Either deprecate auto_install_labels +# Or make it work with async +class NodeMeta(type): + def __new__(mcs, name, bases, namespace): + namespace["DoesNotExist"] = type(name + "DoesNotExist", (DoesNotExist,), {}) + cls = super().__new__(mcs, name, bases, namespace) + cls.DoesNotExist._model_class = cls + + if hasattr(cls, "__abstract_node__"): + delattr(cls, "__abstract_node__") + else: + if "deleted" in namespace: + raise ValueError( + "Property name 'deleted' is not allowed as it conflicts with neomodel internals." + ) + elif "id" in namespace: + raise ValueError( + """ + Property name 'id' is not allowed as it conflicts with neomodel internals. + Consider using 'uid' or 'identifier' as id is also a Neo4j internal. + """ + ) + elif "element_id" in namespace: + raise ValueError( + """ + Property name 'element_id' is not allowed as it conflicts with neomodel internals. + Consider using 'uid' or 'identifier' as element_id is also a Neo4j internal. + """ + ) + for key, value in ( + (x, y) for x, y in namespace.items() if isinstance(y, Property) + ): + value.name, value.owner = key, cls + if hasattr(value, "setup") and callable(value.setup): + value.setup() + + # cache various groups of properies + cls.__required_properties__ = tuple( + name + for name, property in cls.defined_properties( + aliases=False, rels=False + ).items() + if property.required or property.unique_index + ) + cls.__all_properties__ = tuple( + cls.defined_properties(aliases=False, rels=False).items() + ) + cls.__all_aliases__ = tuple( + cls.defined_properties(properties=False, rels=False).items() + ) + cls.__all_relationships__ = tuple( + cls.defined_properties(aliases=False, properties=False).items() + ) + + cls.__label__ = namespace.get("__label__", name) + cls.__optional_labels__ = namespace.get("__optional_labels__", []) + + # TODO : See previous TODO comment + # if config.AUTO_INSTALL_LABELS: + # await install_labels_async(cls, quiet=False) + + build_class_registry(cls) + + return cls + + +def build_class_registry(cls): + base_label_set = frozenset(cls.inherited_labels()) + optional_label_set = set(cls.inherited_optional_labels()) + + # Construct all possible combinations of labels + optional labels + possible_label_combinations = [ + frozenset(set(x).union(base_label_set)) + for i in range(1, len(optional_label_set) + 1) + for x in combinations(optional_label_set, i) + ] + possible_label_combinations.append(base_label_set) + + for label_set in possible_label_combinations: + if label_set not in adb._NODE_CLASS_REGISTRY: + adb._NODE_CLASS_REGISTRY[label_set] = cls + else: + raise NodeClassAlreadyDefined(cls, adb._NODE_CLASS_REGISTRY) + + +NodeBase = NodeMeta("NodeBase", (PropertyManager,), {"__abstract_node__": True}) + + +class StructuredNodeAsync(NodeBase): + """ + Base class for all node definitions to inherit from. + + If you want to create your own abstract classes set: + __abstract_node__ = True + """ + + # static properties + + __abstract_node__ = True + + # magic methods + + def __init__(self, *args, **kwargs): + if "deleted" in kwargs: + raise ValueError("deleted property is reserved for neomodel") + + for key, val in self.__all_relationships__: + self.__dict__[key] = val.build_manager(self, key) + + super().__init__(*args, **kwargs) + + def __eq__(self, other): + if not isinstance(other, (StructuredNodeAsync,)): + return False + if hasattr(self, "element_id") and hasattr(other, "element_id"): + return self.element_id == other.element_id + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def __repr__(self): + return f"<{self.__class__.__name__}: {self}>" + + def __str__(self): + return repr(self.__properties__) + + # dynamic properties + + @classproperty + def nodes(cls): + """ + Returns a NodeSet object representing all nodes of the classes label + :return: NodeSet + :rtype: NodeSet + """ + from neomodel.match import NodeSet + + return NodeSet(cls) + + @property + def element_id(self): + if hasattr(self, "element_id_property"): + return ( + int(self.element_id_property) + if adb.database_version.startswith("4") + else self.element_id_property + ) + return None + + # Version 4.4 support - id is deprecated in version 5.x + @property + def id(self): + try: + return int(self.element_id_property) + except (TypeError, ValueError): + raise ValueError( + "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()." + ) + + # methods + + @classmethod + def _build_merge_query( + cls, merge_params, update_existing=False, lazy=False, relationship=None + ): + """ + Get a tuple of a CYPHER query and a params dict for the specified MERGE query. + + :param merge_params: The target node match parameters, each node must have a "create" key and optional "update". + :type merge_params: list of dict + :param update_existing: True to update properties of existing nodes, default False to keep existing values. + :type update_existing: bool + :rtype: tuple + """ + query_params = dict(merge_params=merge_params) + n_merge_labels = ":".join(cls.inherited_labels()) + n_merge_prm = ", ".join( + ( + f"{getattr(cls, p).db_property or p}: params.create.{getattr(cls, p).db_property or p}" + for p in cls.__required_properties__ + ) + ) + n_merge = f"n:{n_merge_labels} {{{n_merge_prm}}}" + if relationship is None: + # create "simple" unwind query + query = f"UNWIND $merge_params as params\n MERGE ({n_merge})\n " + else: + # validate relationship + if not isinstance(relationship.source, StructuredNodeAsync): + raise ValueError( + f"relationship source [{repr(relationship.source)}] is not a StructuredNode" + ) + relation_type = relationship.definition.get("relation_type") + if not relation_type: + raise ValueError( + "No relation_type is specified on provided relationship" + ) + + from neomodel.match import _rel_helper + + query_params["source_id"] = relationship.source.element_id + query = f"MATCH (source:{relationship.source.__label__}) WHERE {adb.get_id_method()}(source) = $source_id\n " + query += "WITH source\n UNWIND $merge_params as params \n " + query += "MERGE " + query += _rel_helper( + lhs="source", + rhs=n_merge, + ident=None, + relation_type=relation_type, + direction=relationship.definition["direction"], + ) + + query += "ON CREATE SET n = params.create\n " + # if update_existing, write properties on match as well + if update_existing is True: + query += "ON MATCH SET n += params.update\n" + + # close query + if lazy: + query += f"RETURN {adb.get_id_method()}(n)" + else: + query += "RETURN n" + + return query, query_params + + @classmethod + async def create_async(cls, *props, **kwargs): + """ + Call to CREATE with parameters map. A new instance will be created and saved. + + :param props: dict of properties to create the nodes. + :type props: tuple + :param lazy: False by default, specify True to get nodes with id only without the parameters. + :type: bool + :rtype: list + """ + + if "streaming" in kwargs: + warnings.warn( + STREAMING_WARNING, + category=DeprecationWarning, + stacklevel=1, + ) + + lazy = kwargs.get("lazy", False) + # create mapped query + query = f"CREATE (n:{':'.join(cls.inherited_labels())} $create_params)" + + # close query + if lazy: + query += f" RETURN {adb.get_id_method()}(n)" + else: + query += " RETURN n" + + results = [] + for item in [ + cls.deflate(p, obj=_UnsavedNode(), skip_empty=True) for p in props + ]: + node, _ = await adb.cypher_query_async(query, {"create_params": item}) + results.extend(node[0]) + + nodes = [cls.inflate(node) for node in results] + + if not lazy and hasattr(cls, "post_create"): + for node in nodes: + node.post_create() + + return nodes + + @classmethod + async def create_or_update_async(cls, *props, **kwargs): + """ + Call to MERGE with parameters map. A new instance will be created and saved if does not already exists, + this is an atomic operation. If an instance already exists all optional properties specified will be updated. + + Note that the post_create hook isn't called after create_or_update + + :param props: List of dict arguments to get or create the entities with. + :type props: tuple + :param relationship: Optional, relationship to get/create on when new entity is created. + :param lazy: False by default, specify True to get nodes with id only without the parameters. + :rtype: list + """ + lazy = kwargs.get("lazy", False) + relationship = kwargs.get("relationship") + + # build merge query, make sure to update only explicitly specified properties + create_or_update_params = [] + for specified, deflated in [ + (p, cls.deflate(p, skip_empty=True)) for p in props + ]: + create_or_update_params.append( + { + "create": deflated, + "update": dict( + (k, v) for k, v in deflated.items() if k in specified + ), + } + ) + query, params = cls._build_merge_query( + create_or_update_params, + update_existing=True, + relationship=relationship, + lazy=lazy, + ) + + if "streaming" in kwargs: + warnings.warn( + STREAMING_WARNING, + category=DeprecationWarning, + stacklevel=1, + ) + + # fetch and build instance for each result + results = await adb.cypher_query_async(query, params) + return [cls.inflate(r[0]) for r in results[0]] + + async def cypher_async(self, query, params=None): + """ + Execute a cypher query with the param 'self' pre-populated with the nodes neo4j id. + + :param query: cypher query string + :type: string + :param params: query parameters + :type: dict + :return: list containing query results + :rtype: list + """ + self._pre_action_check("cypher") + params = params or {} + params.update({"self": self.element_id}) + return await adb.cypher_query_async(query, params) + + @hooks + async def delete_async(self): + """ + Delete a node and its relationships + + :return: True + """ + self._pre_action_check("delete") + await self.cypher_async( + f"MATCH (self) WHERE {adb.get_id_method()}(self)=$self DETACH DELETE self" + ) + delattr(self, "element_id_property") + self.deleted = True + return True + + @classmethod + async def get_or_create_async(cls, *props, **kwargs): + """ + Call to MERGE with parameters map. A new instance will be created and saved if does not already exist, + this is an atomic operation. + Parameters must contain all required properties, any non required properties with defaults will be generated. + + Note that the post_create hook isn't called after get_or_create + + :param props: Arguments to get_or_create as tuple of dict with property names and values to get or create + the entities with. + :type props: tuple + :param relationship: Optional, relationship to get/create on when new entity is created. + :param lazy: False by default, specify True to get nodes with id only without the parameters. + :rtype: list + """ + lazy = kwargs.get("lazy", False) + relationship = kwargs.get("relationship") + + # build merge query + get_or_create_params = [ + {"create": cls.deflate(p, skip_empty=True)} for p in props + ] + query, params = cls._build_merge_query( + get_or_create_params, relationship=relationship, lazy=lazy + ) + + if "streaming" in kwargs: + warnings.warn( + STREAMING_WARNING, + category=DeprecationWarning, + stacklevel=1, + ) + + # fetch and build instance for each result + results = await adb.cypher_query_async(query, params) + return [cls.inflate(r[0]) for r in results[0]] + + @classmethod + def inflate(cls, node): + """ + Inflate a raw neo4j_driver node to a neomodel node + :param node: + :return: node object + """ + # support lazy loading + if isinstance(node, str) or isinstance(node, int): + snode = cls() + snode.element_id_property = node + else: + node_properties = _get_node_properties(node) + props = {} + for key, prop in cls.__all_properties__: + # map property name from database to object property + db_property = prop.db_property or key + + if db_property in node_properties: + props[key] = prop.inflate(node_properties[db_property], node) + elif prop.has_default: + props[key] = prop.default_value() + else: + props[key] = None + + snode = cls(**props) + snode.element_id_property = node.element_id + + return snode + + @classmethod + def inherited_labels(cls): + """ + Return list of labels from nodes class hierarchy. + + :return: list + """ + return [ + scls.__label__ + for scls in cls.mro() + if hasattr(scls, "__label__") and not hasattr(scls, "__abstract_node__") + ] + + @classmethod + def inherited_optional_labels(cls): + """ + Return list of optional labels from nodes class hierarchy. + + :return: list + :rtype: list + """ + return [ + label + for scls in cls.mro() + for label in getattr(scls, "__optional_labels__", []) + if not hasattr(scls, "__abstract_node__") + ] + + async def labels_async(self): + """ + Returns list of labels tied to the node from neo4j. + + :return: list of labels + :rtype: list + """ + self._pre_action_check("labels") + return await self.cypher_async( + f"MATCH (n) WHERE {adb.get_id_method()}(n)=$self " "RETURN labels(n)" + )[0][0][0] + + def _pre_action_check(self, action): + if hasattr(self, "deleted") and self.deleted: + raise ValueError( + f"{self.__class__.__name__}.{action}() attempted on deleted node" + ) + if not hasattr(self, "element_id"): + raise ValueError( + f"{self.__class__.__name__}.{action}() attempted on unsaved node" + ) + + async def refresh_async(self): + """ + Reload the node from neo4j + """ + self._pre_action_check("refresh") + if hasattr(self, "element_id"): + request = await self.cypher_async( + f"MATCH (n) WHERE {adb.get_id_method()}(n)=$self RETURN n" + )[0] + if not request or not request[0]: + raise self.__class__.DoesNotExist("Can't refresh non existent node") + node = self.inflate(request[0][0]) + for key, val in node.__properties__.items(): + setattr(self, key, val) + else: + raise ValueError("Can't refresh unsaved node") + + @hooks + async def save_async(self): + """ + Save the node to neo4j or raise an exception + + :return: the node instance + """ + + # create or update instance node + if hasattr(self, "element_id_property"): + # update + params = self.deflate(self.__properties__, self) + query = f"MATCH (n) WHERE {adb.get_id_method()}(n)=$self\n" + + if params: + query += "SET " + query += ",\n".join([f"n.{key} = ${key}" for key in params]) + query += "\n" + if self.inherited_labels(): + query += "\n".join( + [f"SET n:`{label}`" for label in self.inherited_labels()] + ) + await self.cypher_async(query, params) + elif hasattr(self, "deleted") and self.deleted: + raise ValueError( + f"{self.__class__.__name__}.save() attempted on deleted node" + ) + else: # create + result = await self.create_async(self.__properties__) + created_node = result[0] + self.element_id_property = created_node.element_id + return self diff --git a/neomodel/config.py b/neomodel/config.py index b54aa806..26a0d626 100644 --- a/neomodel/config.py +++ b/neomodel/config.py @@ -1,6 +1,6 @@ import neo4j -from ._version import __version__ +from neomodel._version import __version__ AUTO_INSTALL_LABELS = False diff --git a/neomodel/contrib/__init__.py b/neomodel/contrib/__init__.py index 3be00b41..15a59660 100644 --- a/neomodel/contrib/__init__.py +++ b/neomodel/contrib/__init__.py @@ -1 +1 @@ -from .semi_structured import SemiStructuredNode +from neomodel.semi_structured import SemiStructuredNode diff --git a/neomodel/contrib/semi_structured.py b/neomodel/contrib/semi_structured.py index 9c719983..580514ba 100644 --- a/neomodel/contrib/semi_structured.py +++ b/neomodel/contrib/semi_structured.py @@ -1,9 +1,9 @@ -from neomodel.core import StructuredNode +from neomodel._async.core import StructuredNodeAsync from neomodel.exceptions import DeflateConflict, InflateConflict from neomodel.util import _get_node_properties -class SemiStructuredNode(StructuredNode): +class SemiStructuredNode(StructuredNodeAsync): """ A base class allowing properties to be stored on a node that aren't specified in its definition. Conflicting properties are signaled with the @@ -57,7 +57,7 @@ def inflate(cls, node): def deflate(cls, node_props, obj=None, skip_empty=False): deflated = super().deflate(node_props, obj, skip_empty=skip_empty) for key in [k for k in node_props if k not in deflated]: - if hasattr(cls, key) and (getattr(cls,key).required or not skip_empty): + if hasattr(cls, key) and (getattr(cls, key).required or not skip_empty): raise DeflateConflict(cls, key, deflated[key], obj.element_id) node_props.update(deflated) diff --git a/neomodel/core.py b/neomodel/core.py deleted file mode 100644 index 415a97af..00000000 --- a/neomodel/core.py +++ /dev/null @@ -1,784 +0,0 @@ -import sys -import warnings -from itertools import combinations - -from neo4j.exceptions import ClientError - -from neomodel import config -from neomodel.exceptions import ( - DoesNotExist, - FeatureNotSupported, - NodeClassAlreadyDefined, -) -from neomodel.hooks import hooks -from neomodel.properties import Property, PropertyManager -from neomodel.util import Database, _get_node_properties, _UnsavedNode, classproperty - -db = Database() - -RULE_ALREADY_EXISTS = "Neo.ClientError.Schema.EquivalentSchemaRuleAlreadyExists" -INDEX_ALREADY_EXISTS = "Neo.ClientError.Schema.IndexAlreadyExists" -CONSTRAINT_ALREADY_EXISTS = "Neo.ClientError.Schema.ConstraintAlreadyExists" -STREAMING_WARNING = "streaming is not supported by bolt, please remove the kwarg" - - -def drop_constraints(quiet=True, stdout=None): - """ - Discover and drop all constraints. - - :type: bool - :return: None - """ - if not stdout or stdout is None: - stdout = sys.stdout - - results, meta = db.cypher_query("SHOW CONSTRAINTS") - - results_as_dict = [dict(zip(meta, row)) for row in results] - for constraint in results_as_dict: - db.cypher_query("DROP CONSTRAINT " + constraint["name"]) - if not quiet: - stdout.write( - ( - " - Dropping unique constraint and index" - f" on label {constraint['labelsOrTypes'][0]}" - f" with property {constraint['properties'][0]}.\n" - ) - ) - if not quiet: - stdout.write("\n") - - -def drop_indexes(quiet=True, stdout=None): - """ - Discover and drop all indexes, except the automatically created token lookup indexes. - - :type: bool - :return: None - """ - if not stdout or stdout is None: - stdout = sys.stdout - - indexes = db.list_indexes(exclude_token_lookup=True) - for index in indexes: - db.cypher_query("DROP INDEX " + index["name"]) - if not quiet: - stdout.write( - f' - Dropping index on labels {",".join(index["labelsOrTypes"])} with properties {",".join(index["properties"])}.\n' - ) - if not quiet: - stdout.write("\n") - - -def remove_all_labels(stdout=None): - """ - Calls functions for dropping constraints and indexes. - - :param stdout: output stream - :return: None - """ - - if not stdout: - stdout = sys.stdout - - stdout.write("Dropping constraints...\n") - drop_constraints(quiet=False, stdout=stdout) - - stdout.write("Dropping indexes...\n") - drop_indexes(quiet=False, stdout=stdout) - - -def install_labels(cls, quiet=True, stdout=None): - """ - Setup labels with indexes and constraints for a given class - - :param cls: StructuredNode class - :type: class - :param quiet: (default true) enable standard output - :param stdout: stdout stream - :type: bool - :return: None - """ - if not stdout or stdout is None: - stdout = sys.stdout - - if not hasattr(cls, "__label__"): - if not quiet: - stdout.write( - f" ! Skipping class {cls.__module__}.{cls.__name__} is abstract\n" - ) - return - - for name, property in cls.defined_properties(aliases=False, rels=False).items(): - _install_node(cls, name, property, quiet, stdout) - - for _, relationship in cls.defined_properties( - aliases=False, rels=True, properties=False - ).items(): - _install_relationship(cls, relationship, quiet, stdout) - - -def _create_node_index(label: str, property_name: str, stdout): - try: - db.cypher_query( - f"CREATE INDEX index_{label}_{property_name} FOR (n:{label}) ON (n.{property_name}); " - ) - except ClientError as e: - if e.code in ( - RULE_ALREADY_EXISTS, - INDEX_ALREADY_EXISTS, - ): - stdout.write(f"{str(e)}\n") - else: - raise - - -def _create_node_constraint(label: str, property_name: str, stdout): - try: - db.cypher_query( - f"""CREATE CONSTRAINT constraint_unique_{label}_{property_name} - FOR (n:{label}) REQUIRE n.{property_name} IS UNIQUE""" - ) - except ClientError as e: - if e.code in ( - RULE_ALREADY_EXISTS, - CONSTRAINT_ALREADY_EXISTS, - ): - stdout.write(f"{str(e)}\n") - else: - raise - - -def _create_relationship_index(relationship_type: str, property_name: str, stdout): - try: - db.cypher_query( - f"CREATE INDEX index_{relationship_type}_{property_name} FOR ()-[r:{relationship_type}]-() ON (r.{property_name}); " - ) - except ClientError as e: - if e.code in ( - RULE_ALREADY_EXISTS, - INDEX_ALREADY_EXISTS, - ): - stdout.write(f"{str(e)}\n") - else: - raise - - -def _create_relationship_constraint(relationship_type: str, property_name: str, stdout): - if db.version_is_higher_than("5.7"): - try: - db.cypher_query( - f"""CREATE CONSTRAINT constraint_unique_{relationship_type}_{property_name} - FOR ()-[r:{relationship_type}]-() REQUIRE r.{property_name} IS UNIQUE""" - ) - except ClientError as e: - if e.code in ( - RULE_ALREADY_EXISTS, - CONSTRAINT_ALREADY_EXISTS, - ): - stdout.write(f"{str(e)}\n") - else: - raise - else: - raise FeatureNotSupported( - f"Unique indexes on relationships are not supported in Neo4j version {db.database_version}. Please upgrade to Neo4j 5.7 or higher." - ) - - -def _install_node(cls, name, property, quiet, stdout): - # Create indexes and constraints for node property - db_property = property.db_property or name - if property.index: - if not quiet: - stdout.write( - f" + Creating node index {name} on label {cls.__label__} for class {cls.__module__}.{cls.__name__}\n" - ) - _create_node_index( - label=cls.__label__, property_name=db_property, stdout=stdout - ) - - elif property.unique_index: - if not quiet: - stdout.write( - f" + Creating node unique constraint for {name} on label {cls.__label__} for class {cls.__module__}.{cls.__name__}\n" - ) - _create_node_constraint( - label=cls.__label__, property_name=db_property, stdout=stdout - ) - - -def _install_relationship(cls, relationship, quiet, stdout): - # Create indexes and constraints for relationship property - relationship_cls = relationship.definition["model"] - if relationship_cls is not None: - relationship_type = relationship.definition["relation_type"] - for prop_name, property in relationship_cls.defined_properties( - aliases=False, rels=False - ).items(): - db_property = property.db_property or prop_name - if property.index: - if not quiet: - stdout.write( - f" + Creating relationship index {prop_name} on relationship type {relationship_type} for relationship model {cls.__module__}.{relationship_cls.__name__}\n" - ) - _create_relationship_index( - relationship_type=relationship_type, - property_name=db_property, - stdout=stdout, - ) - elif property.unique_index: - if not quiet: - stdout.write( - f" + Creating relationship unique constraint for {prop_name} on relationship type {relationship_type} for relationship model {cls.__module__}.{relationship_cls.__name__}\n" - ) - _create_relationship_constraint( - relationship_type=relationship_type, - property_name=db_property, - stdout=stdout, - ) - - -def install_all_labels(stdout=None): - """ - Discover all subclasses of StructuredNode in your application and execute install_labels on each. - Note: code must be loaded (imported) in order for a class to be discovered. - - :param stdout: output stream - :return: None - """ - - if not stdout or stdout is None: - stdout = sys.stdout - - def subsub(cls): # recursively return all subclasses - subclasses = cls.__subclasses__() - if not subclasses: # base case: no more subclasses - return [] - return subclasses + [g for s in cls.__subclasses__() for g in subsub(s)] - - stdout.write("Setting up indexes and constraints...\n\n") - - i = 0 - for cls in subsub(StructuredNode): - stdout.write(f"Found {cls.__module__}.{cls.__name__}\n") - install_labels(cls, quiet=False, stdout=stdout) - i += 1 - - if i: - stdout.write("\n") - - stdout.write(f"Finished {i} classes.\n") - - -class NodeMeta(type): - def __new__(mcs, name, bases, namespace): - namespace["DoesNotExist"] = type(name + "DoesNotExist", (DoesNotExist,), {}) - cls = super().__new__(mcs, name, bases, namespace) - cls.DoesNotExist._model_class = cls - - if hasattr(cls, "__abstract_node__"): - delattr(cls, "__abstract_node__") - else: - if "deleted" in namespace: - raise ValueError( - "Property name 'deleted' is not allowed as it conflicts with neomodel internals." - ) - elif "id" in namespace: - raise ValueError( - """ - Property name 'id' is not allowed as it conflicts with neomodel internals. - Consider using 'uid' or 'identifier' as id is also a Neo4j internal. - """ - ) - elif "element_id" in namespace: - raise ValueError( - """ - Property name 'element_id' is not allowed as it conflicts with neomodel internals. - Consider using 'uid' or 'identifier' as element_id is also a Neo4j internal. - """ - ) - for key, value in ( - (x, y) for x, y in namespace.items() if isinstance(y, Property) - ): - value.name, value.owner = key, cls - if hasattr(value, "setup") and callable(value.setup): - value.setup() - - # cache various groups of properies - cls.__required_properties__ = tuple( - name - for name, property in cls.defined_properties( - aliases=False, rels=False - ).items() - if property.required or property.unique_index - ) - cls.__all_properties__ = tuple( - cls.defined_properties(aliases=False, rels=False).items() - ) - cls.__all_aliases__ = tuple( - cls.defined_properties(properties=False, rels=False).items() - ) - cls.__all_relationships__ = tuple( - cls.defined_properties(aliases=False, properties=False).items() - ) - - cls.__label__ = namespace.get("__label__", name) - cls.__optional_labels__ = namespace.get("__optional_labels__", []) - - if config.AUTO_INSTALL_LABELS: - install_labels(cls, quiet=False) - - build_class_registry(cls) - - return cls - - -def build_class_registry(cls): - base_label_set = frozenset(cls.inherited_labels()) - optional_label_set = set(cls.inherited_optional_labels()) - - # Construct all possible combinations of labels + optional labels - possible_label_combinations = [ - frozenset(set(x).union(base_label_set)) - for i in range(1, len(optional_label_set) + 1) - for x in combinations(optional_label_set, i) - ] - possible_label_combinations.append(base_label_set) - - for label_set in possible_label_combinations: - if label_set not in db._NODE_CLASS_REGISTRY: - db._NODE_CLASS_REGISTRY[label_set] = cls - else: - raise NodeClassAlreadyDefined(cls, db._NODE_CLASS_REGISTRY) - - -NodeBase = NodeMeta("NodeBase", (PropertyManager,), {"__abstract_node__": True}) - - -class StructuredNode(NodeBase): - """ - Base class for all node definitions to inherit from. - - If you want to create your own abstract classes set: - __abstract_node__ = True - """ - - # static properties - - __abstract_node__ = True - - # magic methods - - def __init__(self, *args, **kwargs): - if "deleted" in kwargs: - raise ValueError("deleted property is reserved for neomodel") - - for key, val in self.__all_relationships__: - self.__dict__[key] = val.build_manager(self, key) - - super().__init__(*args, **kwargs) - - def __eq__(self, other): - if not isinstance(other, (StructuredNode,)): - return False - if hasattr(self, "element_id") and hasattr(other, "element_id"): - return self.element_id == other.element_id - return False - - def __ne__(self, other): - return not self.__eq__(other) - - def __repr__(self): - return f"<{self.__class__.__name__}: {self}>" - - def __str__(self): - return repr(self.__properties__) - - # dynamic properties - - @classproperty - def nodes(cls): - """ - Returns a NodeSet object representing all nodes of the classes label - :return: NodeSet - :rtype: NodeSet - """ - from .match import NodeSet - - return NodeSet(cls) - - @property - def element_id(self): - if hasattr(self, "element_id_property"): - return ( - int(self.element_id_property) - if db.database_version.startswith("4") - else self.element_id_property - ) - return None - - # Version 4.4 support - id is deprecated in version 5.x - @property - def id(self): - try: - return int(self.element_id_property) - except (TypeError, ValueError): - raise ValueError( - "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()." - ) - - # methods - - @classmethod - def _build_merge_query( - cls, merge_params, update_existing=False, lazy=False, relationship=None - ): - """ - Get a tuple of a CYPHER query and a params dict for the specified MERGE query. - - :param merge_params: The target node match parameters, each node must have a "create" key and optional "update". - :type merge_params: list of dict - :param update_existing: True to update properties of existing nodes, default False to keep existing values. - :type update_existing: bool - :rtype: tuple - """ - query_params = dict(merge_params=merge_params) - n_merge_labels = ":".join(cls.inherited_labels()) - n_merge_prm = ", ".join( - ( - f"{getattr(cls, p).db_property or p}: params.create.{getattr(cls, p).db_property or p}" - for p in cls.__required_properties__ - ) - ) - n_merge = f"n:{n_merge_labels} {{{n_merge_prm}}}" - if relationship is None: - # create "simple" unwind query - query = f"UNWIND $merge_params as params\n MERGE ({n_merge})\n " - else: - # validate relationship - if not isinstance(relationship.source, StructuredNode): - raise ValueError( - f"relationship source [{repr(relationship.source)}] is not a StructuredNode" - ) - relation_type = relationship.definition.get("relation_type") - if not relation_type: - raise ValueError( - "No relation_type is specified on provided relationship" - ) - - from .match import _rel_helper - - query_params["source_id"] = relationship.source.element_id - query = f"MATCH (source:{relationship.source.__label__}) WHERE {db.get_id_method()}(source) = $source_id\n " - query += "WITH source\n UNWIND $merge_params as params \n " - query += "MERGE " - query += _rel_helper( - lhs="source", - rhs=n_merge, - ident=None, - relation_type=relation_type, - direction=relationship.definition["direction"], - ) - - query += "ON CREATE SET n = params.create\n " - # if update_existing, write properties on match as well - if update_existing is True: - query += "ON MATCH SET n += params.update\n" - - # close query - if lazy: - query += f"RETURN {db.get_id_method()}(n)" - else: - query += "RETURN n" - - return query, query_params - - @classmethod - def create(cls, *props, **kwargs): - """ - Call to CREATE with parameters map. A new instance will be created and saved. - - :param props: dict of properties to create the nodes. - :type props: tuple - :param lazy: False by default, specify True to get nodes with id only without the parameters. - :type: bool - :rtype: list - """ - - if "streaming" in kwargs: - warnings.warn( - STREAMING_WARNING, - category=DeprecationWarning, - stacklevel=1, - ) - - lazy = kwargs.get("lazy", False) - # create mapped query - query = f"CREATE (n:{':'.join(cls.inherited_labels())} $create_params)" - - # close query - if lazy: - query += f" RETURN {db.get_id_method()}(n)" - else: - query += " RETURN n" - - results = [] - for item in [ - cls.deflate(p, obj=_UnsavedNode(), skip_empty=True) for p in props - ]: - node, _ = db.cypher_query(query, {"create_params": item}) - results.extend(node[0]) - - nodes = [cls.inflate(node) for node in results] - - if not lazy and hasattr(cls, "post_create"): - for node in nodes: - node.post_create() - - return nodes - - @classmethod - def create_or_update(cls, *props, **kwargs): - """ - Call to MERGE with parameters map. A new instance will be created and saved if does not already exists, - this is an atomic operation. If an instance already exists all optional properties specified will be updated. - - Note that the post_create hook isn't called after create_or_update - - :param props: List of dict arguments to get or create the entities with. - :type props: tuple - :param relationship: Optional, relationship to get/create on when new entity is created. - :param lazy: False by default, specify True to get nodes with id only without the parameters. - :rtype: list - """ - lazy = kwargs.get("lazy", False) - relationship = kwargs.get("relationship") - - # build merge query, make sure to update only explicitly specified properties - create_or_update_params = [] - for specified, deflated in [ - (p, cls.deflate(p, skip_empty=True)) for p in props - ]: - create_or_update_params.append( - { - "create": deflated, - "update": dict( - (k, v) for k, v in deflated.items() if k in specified - ), - } - ) - query, params = cls._build_merge_query( - create_or_update_params, - update_existing=True, - relationship=relationship, - lazy=lazy, - ) - - if "streaming" in kwargs: - warnings.warn( - STREAMING_WARNING, - category=DeprecationWarning, - stacklevel=1, - ) - - # fetch and build instance for each result - results = db.cypher_query(query, params) - return [cls.inflate(r[0]) for r in results[0]] - - def cypher(self, query, params=None): - """ - Execute a cypher query with the param 'self' pre-populated with the nodes neo4j id. - - :param query: cypher query string - :type: string - :param params: query parameters - :type: dict - :return: list containing query results - :rtype: list - """ - self._pre_action_check("cypher") - params = params or {} - params.update({"self": self.element_id}) - return db.cypher_query(query, params) - - @hooks - def delete(self): - """ - Delete a node and its relationships - - :return: True - """ - self._pre_action_check("delete") - self.cypher( - f"MATCH (self) WHERE {db.get_id_method()}(self)=$self DETACH DELETE self" - ) - delattr(self, "element_id_property") - self.deleted = True - return True - - @classmethod - def get_or_create(cls, *props, **kwargs): - """ - Call to MERGE with parameters map. A new instance will be created and saved if does not already exist, - this is an atomic operation. - Parameters must contain all required properties, any non required properties with defaults will be generated. - - Note that the post_create hook isn't called after get_or_create - - :param props: Arguments to get_or_create as tuple of dict with property names and values to get or create - the entities with. - :type props: tuple - :param relationship: Optional, relationship to get/create on when new entity is created. - :param lazy: False by default, specify True to get nodes with id only without the parameters. - :rtype: list - """ - lazy = kwargs.get("lazy", False) - relationship = kwargs.get("relationship") - - # build merge query - get_or_create_params = [ - {"create": cls.deflate(p, skip_empty=True)} for p in props - ] - query, params = cls._build_merge_query( - get_or_create_params, relationship=relationship, lazy=lazy - ) - - if "streaming" in kwargs: - warnings.warn( - STREAMING_WARNING, - category=DeprecationWarning, - stacklevel=1, - ) - - # fetch and build instance for each result - results = db.cypher_query(query, params) - return [cls.inflate(r[0]) for r in results[0]] - - @classmethod - def inflate(cls, node): - """ - Inflate a raw neo4j_driver node to a neomodel node - :param node: - :return: node object - """ - # support lazy loading - if isinstance(node, str) or isinstance(node, int): - snode = cls() - snode.element_id_property = node - else: - node_properties = _get_node_properties(node) - props = {} - for key, prop in cls.__all_properties__: - # map property name from database to object property - db_property = prop.db_property or key - - if db_property in node_properties: - props[key] = prop.inflate(node_properties[db_property], node) - elif prop.has_default: - props[key] = prop.default_value() - else: - props[key] = None - - snode = cls(**props) - snode.element_id_property = node.element_id - - return snode - - @classmethod - def inherited_labels(cls): - """ - Return list of labels from nodes class hierarchy. - - :return: list - """ - return [ - scls.__label__ - for scls in cls.mro() - if hasattr(scls, "__label__") and not hasattr(scls, "__abstract_node__") - ] - - @classmethod - def inherited_optional_labels(cls): - """ - Return list of optional labels from nodes class hierarchy. - - :return: list - :rtype: list - """ - return [ - label - for scls in cls.mro() - for label in getattr(scls, "__optional_labels__", []) - if not hasattr(scls, "__abstract_node__") - ] - - def labels(self): - """ - Returns list of labels tied to the node from neo4j. - - :return: list of labels - :rtype: list - """ - self._pre_action_check("labels") - return self.cypher( - f"MATCH (n) WHERE {db.get_id_method()}(n)=$self " "RETURN labels(n)" - )[0][0][0] - - def _pre_action_check(self, action): - if hasattr(self, "deleted") and self.deleted: - raise ValueError( - f"{self.__class__.__name__}.{action}() attempted on deleted node" - ) - if not hasattr(self, "element_id"): - raise ValueError( - f"{self.__class__.__name__}.{action}() attempted on unsaved node" - ) - - def refresh(self): - """ - Reload the node from neo4j - """ - self._pre_action_check("refresh") - if hasattr(self, "element_id"): - request = self.cypher( - f"MATCH (n) WHERE {db.get_id_method()}(n)=$self RETURN n" - )[0] - if not request or not request[0]: - raise self.__class__.DoesNotExist("Can't refresh non existent node") - node = self.inflate(request[0][0]) - for key, val in node.__properties__.items(): - setattr(self, key, val) - else: - raise ValueError("Can't refresh unsaved node") - - @hooks - def save(self): - """ - Save the node to neo4j or raise an exception - - :return: the node instance - """ - - # create or update instance node - if hasattr(self, "element_id_property"): - # update - params = self.deflate(self.__properties__, self) - query = f"MATCH (n) WHERE {db.get_id_method()}(n)=$self\n" - - if params: - query += "SET " - query += ",\n".join([f"n.{key} = ${key}" for key in params]) - query += "\n" - if self.inherited_labels(): - query += "\n".join( - [f"SET n:`{label}`" for label in self.inherited_labels()] - ) - self.cypher(query, params) - elif hasattr(self, "deleted") and self.deleted: - raise ValueError( - f"{self.__class__.__name__}.save() attempted on deleted node" - ) - else: # create - created_node = self.create(self.__properties__)[0] - self.element_id_property = created_node.element_id - return self diff --git a/neomodel/integration/numpy.py b/neomodel/integration/numpy.py index a04508c4..5dc6da80 100644 --- a/neomodel/integration/numpy.py +++ b/neomodel/integration/numpy.py @@ -7,7 +7,7 @@ Example: - >>> from neomodel import db + >>> from neomodel._async import db >>> from neomodel.integration.numpy import to_nparray >>> db.set_connection('bolt://neo4j:secret@localhost:7687') >>> df = to_nparray(db.cypher_query("MATCH (u:User) RETURN u.email AS email, u.name AS name")) diff --git a/neomodel/integration/pandas.py b/neomodel/integration/pandas.py index 1ad19871..845c8e50 100644 --- a/neomodel/integration/pandas.py +++ b/neomodel/integration/pandas.py @@ -7,7 +7,7 @@ Example: - >>> from neomodel import db + >>> from neomodel._async import db >>> from neomodel.integration.pandas import to_dataframe >>> db.set_connection('bolt://neo4j:secret@localhost:7687') >>> df = to_dataframe(db.cypher_query("MATCH (u:User) RETURN u.email AS email, u.name AS name")) diff --git a/neomodel/match.py b/neomodel/match.py index 2773d6e5..815b02ec 100644 --- a/neomodel/match.py +++ b/neomodel/match.py @@ -4,10 +4,10 @@ from dataclasses import dataclass from typing import Optional -from .core import StructuredNode, db -from .exceptions import MultipleNodesReturned -from .match_q import Q, QBase -from .properties import AliasProperty +from neomodel._async.core import StructuredNodeAsync, adb +from neomodel.exceptions import MultipleNodesReturned +from neomodel.match_q import Q, QBase +from neomodel.properties import AliasProperty OUTGOING, INCOMING, EITHER = 1, -1, 0 @@ -379,7 +379,7 @@ def build_source(self, source): return self.build_traversal(source) if isinstance(source, NodeSet): if inspect.isclass(source.source) and issubclass( - source.source, StructuredNode + source.source, StructuredNodeAsync ): ident = self.build_label(source.source.__label__.lower(), source.source) else: @@ -399,7 +399,7 @@ def build_source(self, source): ) return ident - if isinstance(source, StructuredNode): + if isinstance(source, StructuredNodeAsync): return self.build_node(source) raise ValueError("Unknown source type " + repr(source)) @@ -499,7 +499,7 @@ def build_node(self, node): place_holder = self._register_place_holder(ident) # Hack to emulate START to lookup a node by id - _node_lookup = f"MATCH ({ident}) WHERE {db.get_id_method()}({ident})=${place_holder} WITH {ident}" + _node_lookup = f"MATCH ({ident}) WHERE {adb.get_id_method()}({ident})=${place_holder} WITH {ident}" self._ast.lookup = _node_lookup self._query_params[place_holder] = node.element_id @@ -664,7 +664,7 @@ def _count(self): # drop additional_return to avoid unexpected result self._ast.additional_return = None query = self.build_query() - results, _ = db.cypher_query(query, self._query_params) + results, _ = adb.cypher_query_async(query, self._query_params) return int(results[0][0]) def _contains(self, node_element_id): @@ -674,7 +674,7 @@ def _contains(self, node_element_id): self._ast.return_clause = self._ast.additional_return[0] ident = self._ast.return_clause place_holder = self._register_place_holder(ident + "_contains") - self._ast.where.append(f"{db.get_id_method()}({ident}) = ${place_holder}") + self._ast.where.append(f"{adb.get_id_method()}({ident}) = ${place_holder}") self._query_params[place_holder] = node_element_id return self._count() >= 1 @@ -683,15 +683,17 @@ def _execute(self, lazy=False): # inject id() into return or return_set if self._ast.return_clause: self._ast.return_clause = ( - f"{db.get_id_method()}({self._ast.return_clause})" + f"{adb.get_id_method()}({self._ast.return_clause})" ) else: self._ast.additional_return = [ - f"{db.get_id_method()}({item})" + f"{adb.get_id_method()}({item})" for item in self._ast.additional_return ] query = self.build_query() - results, _ = db.cypher_query(query, self._query_params, resolve_objects=True) + results, _ = adb.cypher_query_async( + query, self._query_params, resolve_objects=True + ) # The following is not as elegant as it could be but had to be copied from the # version prior to cypher_query with the resolve_objects capability. # It seems that certain calls are only supposed to be focusing to the first @@ -732,7 +734,7 @@ def __nonzero__(self): return self.query_cls(self).build_ast()._count() > 0 def __contains__(self, obj): - if isinstance(obj, StructuredNode): + if isinstance(obj, StructuredNodeAsync): if hasattr(obj, "element_id") and obj.element_id is not None: return self.query_cls(self).build_ast()._contains(obj.element_id) raise ValueError("Unsaved node: " + repr(obj)) @@ -776,9 +778,9 @@ def __init__(self, source): self.source = source # could be a Traverse object or a node class if isinstance(source, Traversal): self.source_class = source.target_class - elif inspect.isclass(source) and issubclass(source, StructuredNode): + elif inspect.isclass(source) and issubclass(source, StructuredNodeAsync): self.source_class = source - elif isinstance(source, StructuredNode): + elif isinstance(source, StructuredNodeAsync): self.source_class = source.__class__ else: raise ValueError("Bad source for nodeset " + repr(source)) @@ -980,9 +982,9 @@ def __init__(self, source, name, definition): if isinstance(source, Traversal): self.source_class = source.target_class - elif inspect.isclass(source) and issubclass(source, StructuredNode): + elif inspect.isclass(source) and issubclass(source, StructuredNodeAsync): self.source_class = source - elif isinstance(source, StructuredNode): + elif isinstance(source, StructuredNodeAsync): self.source_class = source.__class__ elif isinstance(source, NodeSet): self.source_class = source.source_class diff --git a/neomodel/path.py b/neomodel/path.py index 5f063d11..85a92ec7 100644 --- a/neomodel/path.py +++ b/neomodel/path.py @@ -1,14 +1,15 @@ from neo4j.graph import Path -from .core import db -from .relationship import StructuredRel -from .exceptions import RelationshipClassNotDefined + +from neomodel._async.core import adb +from neomodel.exceptions import RelationshipClassNotDefined +from neomodel.relationship import StructuredRel class NeomodelPath(Path): """ Represents paths within neomodel. - This object is instantiated when you include whole paths in your ``cypher_query()`` + This object is instantiated when you include whole paths in your ``cypher_query()`` result sets and turn ``resolve_objects`` to True. That is, any query of the form: @@ -16,7 +17,7 @@ class NeomodelPath(Path): MATCH p=(:SOME_NODE_LABELS)-[:SOME_REL_LABELS]-(:SOME_OTHER_NODE_LABELS) return p - ``NeomodelPath`` are simple objects that reference their nodes and relationships, each of which is already + ``NeomodelPath`` are simple objects that reference their nodes and relationships, each of which is already resolved to their neomodel objects if such mapping is possible. @@ -25,23 +26,25 @@ class NeomodelPath(Path): :type nodes: List[StructuredNode] :type relationships: List[StructuredRel] """ + def __init__(self, a_neopath): - self._nodes=[] + self._nodes = [] self._relationships = [] for a_node in a_neopath.nodes: - self._nodes.append(db._object_resolution(a_node)) + self._nodes.append(adb._object_resolution(a_node)) for a_relationship in a_neopath.relationships: # This check is required here because if the relationship does not bear data # then it does not have an entry in the registry. In that case, we instantiate # an "unspecified" StructuredRel. rel_type = frozenset([a_relationship.type]) - if rel_type in db._NODE_CLASS_REGISTRY: - new_rel = db._object_resolution(a_relationship) + if rel_type in adb._NODE_CLASS_REGISTRY: + new_rel = adb._object_resolution(a_relationship) else: new_rel = StructuredRel.inflate(a_relationship) self._relationships.append(new_rel) + @property def nodes(self): return self._nodes @@ -49,5 +52,3 @@ def nodes(self): @property def relationships(self): return self._relationships - - diff --git a/neomodel/properties.py b/neomodel/properties.py index e28b4ead..737bbfac 100644 --- a/neomodel/properties.py +++ b/neomodel/properties.py @@ -67,7 +67,7 @@ def __init__(self, **kwargs): @property def __properties__(self): - from .relationship_manager import RelationshipManager + from neomodel.relationship_manager import RelationshipManager return dict( (name, value) @@ -101,7 +101,7 @@ def deflate(cls, properties, obj=None, skip_empty=False): @classmethod def defined_properties(cls, aliases=True, properties=True, rels=True): - from .relationship_manager import RelationshipDefinition + from neomodel.relationship_manager import RelationshipDefinition props = {} for baseclass in reversed(cls.__mro__): diff --git a/neomodel/relationship.py b/neomodel/relationship.py index 8df56c47..bea17660 100644 --- a/neomodel/relationship.py +++ b/neomodel/relationship.py @@ -1,8 +1,8 @@ import warnings -from .core import db -from .hooks import hooks -from .properties import Property, PropertyManager +from neomodel._async.core import adb +from neomodel.hooks import hooks +from neomodel.properties import Property, PropertyManager class RelationshipMeta(type): @@ -52,7 +52,7 @@ def __init__(self, *args, **kwargs): def element_id(self): return ( int(self.element_id_property) - if db.database_version.startswith("4") + if adb.database_version.startswith("4") else self.element_id_property ) @@ -60,7 +60,7 @@ def element_id(self): def _start_node_element_id(self): return ( int(self._start_node_element_id_property) - if db.database_version.startswith("4") + if adb.database_version.startswith("4") else self._start_node_element_id_property ) @@ -68,7 +68,7 @@ def _start_node_element_id(self): def _end_node_element_id(self): return ( int(self._end_node_element_id_property) - if db.database_version.startswith("4") + if adb.database_version.startswith("4") else self._end_node_element_id_property ) @@ -110,11 +110,11 @@ def save(self): :return: self """ props = self.deflate(self.__properties__) - query = f"MATCH ()-[r]->() WHERE {db.get_id_method()}(r)=$self " + query = f"MATCH ()-[r]->() WHERE {adb.get_id_method()}(r)=$self " query += "".join([f" SET r.{key} = ${key}" for key in props]) props["self"] = self.element_id - db.cypher_query(query, props) + adb.cypher_query_async(query, props) return self @@ -124,10 +124,10 @@ def start_node(self): :return: StructuredNode """ - test = db.cypher_query( + test = adb.cypher_query_async( f""" MATCH (aNode) - WHERE {db.get_id_method()}(aNode)=$start_node_element_id + WHERE {adb.get_id_method()}(aNode)=$start_node_element_id RETURN aNode """, {"start_node_element_id": self._start_node_element_id}, @@ -141,10 +141,10 @@ def end_node(self): :return: StructuredNode """ - return db.cypher_query( + return adb.cypher_query_async( f""" MATCH (aNode) - WHERE {db.get_id_method()}(aNode)=$end_node_element_id + WHERE {adb.get_id_method()}(aNode)=$end_node_element_id RETURN aNode """, {"end_node_element_id": self._end_node_element_id}, diff --git a/neomodel/relationship_manager.py b/neomodel/relationship_manager.py index 1e9cf79e..103559a1 100644 --- a/neomodel/relationship_manager.py +++ b/neomodel/relationship_manager.py @@ -3,9 +3,9 @@ import sys from importlib import import_module -from .core import db -from .exceptions import NotConnected, RelationshipClassRedefined -from .match import ( +from neomodel._async.core import adb +from neomodel.exceptions import NotConnected, RelationshipClassRedefined +from neomodel.match import ( EITHER, INCOMING, OUTGOING, @@ -14,8 +14,8 @@ _rel_helper, _rel_merge_helper, ) -from .relationship import StructuredRel -from .util import _get_node_properties, enumerate_traceback +from neomodel.relationship import StructuredRel +from neomodel.util import _get_node_properties, enumerate_traceback # basestring python 3.x fallback try: @@ -120,7 +120,7 @@ def connect(self, node, properties=None): **self.definition, ) q = ( - f"MATCH (them), (us) WHERE {db.get_id_method()}(them)=$them and {db.get_id_method()}(us)=$self " + f"MATCH (them), (us) WHERE {adb.get_id_method()}(them)=$them and {adb.get_id_method()}(us)=$self " "MERGE" + new_rel ) @@ -164,7 +164,7 @@ def relationship(self, node): q = ( "MATCH " + my_rel - + f" WHERE {db.get_id_method()}(them)=$them and {db.get_id_method()}(us)=$self RETURN r LIMIT 1" + + f" WHERE {adb.get_id_method()}(them)=$them and {adb.get_id_method()}(us)=$self RETURN r LIMIT 1" ) rels = self.source.cypher(q, {"them": node.element_id})[0] if not rels: @@ -185,7 +185,7 @@ def all_relationships(self, node): self._check_node(node) my_rel = _rel_helper(lhs="us", rhs="them", ident="r", **self.definition) - q = f"MATCH {my_rel} WHERE {db.get_id_method()}(them)=$them and {db.get_id_method()}(us)=$self RETURN r " + q = f"MATCH {my_rel} WHERE {adb.get_id_method()}(them)=$them and {adb.get_id_method()}(us)=$self RETURN r " rels = self.source.cypher(q, {"them": node.element_id})[0] if not rels: return [] @@ -225,7 +225,7 @@ def reconnect(self, old_node, new_node): # get list of properties on the existing rel result, _ = self.source.cypher( f""" - MATCH (us), (old) WHERE {db.get_id_method()}(us)=$self and {db.get_id_method()}(old)=$old + MATCH (us), (old) WHERE {adb.get_id_method()}(us)=$self and {adb.get_id_method()}(old)=$old MATCH {old_rel} RETURN r """, {"old": old_node.element_id}, @@ -240,7 +240,7 @@ def reconnect(self, old_node, new_node): new_rel = _rel_merge_helper(lhs="us", rhs="new", ident="r2", **self.definition) q = ( "MATCH (us), (old), (new) " - f"WHERE {db.get_id_method()}(us)=$self and {db.get_id_method()}(old)=$old and {db.get_id_method()}(new)=$new " + f"WHERE {adb.get_id_method()}(us)=$self and {adb.get_id_method()}(old)=$old and {adb.get_id_method()}(new)=$new " "MATCH " + old_rel ) q += " MERGE" + new_rel @@ -261,7 +261,7 @@ def disconnect(self, node): """ rel = _rel_helper(lhs="a", rhs="b", ident="r", **self.definition) q = f""" - MATCH (a), (b) WHERE {db.get_id_method()}(a)=$self and {db.get_id_method()}(b)=$them + MATCH (a), (b) WHERE {adb.get_id_method()}(a)=$self and {adb.get_id_method()}(b)=$them MATCH {rel} DELETE r """ self.source.cypher(q, {"them": node.element_id}) @@ -275,7 +275,7 @@ def disconnect_all(self): """ rhs = "b:" + self.definition["node_class"].__label__ rel = _rel_helper(lhs="a", rhs=rhs, ident="r", **self.definition) - q = f"MATCH (a) WHERE {db.get_id_method()}(a)=$self MATCH " + rel + " DELETE r" + q = f"MATCH (a) WHERE {adb.get_id_method()}(a)=$self MATCH " + rel + " DELETE r" self.source.cypher(q) @check_source @@ -428,18 +428,18 @@ def __init__( # In this case, it has to be ensured that the class # that is overriding the relationship is a descendant # of the already existing class. - model_from_registry = db._NODE_CLASS_REGISTRY[label_set] + model_from_registry = adb._NODE_CLASS_REGISTRY[label_set] if not issubclass(model, model_from_registry): is_parent = issubclass(model_from_registry, model) if is_direct_subclass(model, StructuredRel) and not is_parent: raise RelationshipClassRedefined( - relation_type, db._NODE_CLASS_REGISTRY, model + relation_type, adb._NODE_CLASS_REGISTRY, model ) else: - db._NODE_CLASS_REGISTRY[label_set] = model + adb._NODE_CLASS_REGISTRY[label_set] = model except KeyError: # If the mapping does not exist then it is simply created. - db._NODE_CLASS_REGISTRY[label_set] = model + adb._NODE_CLASS_REGISTRY[label_set] = model def _validate_class(self, cls_name, model): if not isinstance(cls_name, (basestring, object)): diff --git a/neomodel/scripts/neomodel_inspect_database.py b/neomodel/scripts/neomodel_inspect_database.py index 1a24675a..c994dd0c 100644 --- a/neomodel/scripts/neomodel_inspect_database.py +++ b/neomodel/scripts/neomodel_inspect_database.py @@ -31,7 +31,7 @@ import textwrap from os import environ -from neomodel import db +from neomodel._async.core import adb IMPORTS = [] @@ -78,13 +78,13 @@ def get_properties_for_label(label): ORDER BY size(properties) DESC RETURN apoc.meta.cypher.types(properties(sampleNode)) AS properties LIMIT 1 """ - result, _ = db.cypher_query(query) + result, _ = adb.cypher_query_async(query) if result is not None and len(result) > 0: return result[0][0] @staticmethod def get_constraints_for_label(label): - constraints, meta_constraints = db.cypher_query( + constraints, meta_constraints = adb.cypher_query_async( f"SHOW CONSTRAINTS WHERE entityType='NODE' AND '{label}' IN labelsOrTypes AND type='UNIQUENESS'" ) constraints_as_dict = [dict(zip(meta_constraints, row)) for row in constraints] @@ -97,12 +97,12 @@ def get_constraints_for_label(label): @staticmethod def get_indexed_properties_for_label(label): - if db.version_is_higher_than("5.0"): - indexes, meta_indexes = db.cypher_query( + if adb.version_is_higher_than("5.0"): + indexes, meta_indexes = adb.cypher_query_async( f"SHOW INDEXES WHERE entityType='NODE' AND '{label}' IN labelsOrTypes AND type='RANGE' AND owningConstraint IS NULL" ) else: - indexes, meta_indexes = db.cypher_query( + indexes, meta_indexes = adb.cypher_query_async( f"SHOW INDEXES WHERE entityType='NODE' AND '{label}' IN labelsOrTypes AND type='BTREE' AND uniqueness='NONUNIQUE'" ) indexes_as_dict = [dict(zip(meta_indexes, row)) for row in indexes] @@ -123,12 +123,12 @@ def outgoing_relationships(cls, start_label): ORDER BY size(properties) DESC RETURN rel_type, target_label, apoc.meta.cypher.types(properties(sampleRel)) AS properties LIMIT 1 """ - result, _ = db.cypher_query(query) + result, _ = adb.cypher_query_async(query) return [(record[0], record[1], record[2]) for record in result] @staticmethod def get_constraints_for_type(rel_type): - constraints, meta_constraints = db.cypher_query( + constraints, meta_constraints = adb.cypher_query_async( f"SHOW CONSTRAINTS WHERE entityType='RELATIONSHIP' AND '{rel_type}' IN labelsOrTypes AND type='RELATIONSHIP_UNIQUENESS'" ) constraints_as_dict = [dict(zip(meta_constraints, row)) for row in constraints] @@ -141,12 +141,12 @@ def get_constraints_for_type(rel_type): @staticmethod def get_indexed_properties_for_type(rel_type): - if db.version_is_higher_than("5.0"): - indexes, meta_indexes = db.cypher_query( + if adb.version_is_higher_than("5.0"): + indexes, meta_indexes = adb.cypher_query_async( f"SHOW INDEXES WHERE entityType='RELATIONSHIP' AND '{rel_type}' IN labelsOrTypes AND type='RANGE' AND owningConstraint IS NULL" ) else: - indexes, meta_indexes = db.cypher_query( + indexes, meta_indexes = adb.cypher_query_async( f"SHOW INDEXES WHERE entityType='RELATIONSHIP' AND '{rel_type}' IN labelsOrTypes AND type='BTREE' AND uniqueness='NONUNIQUE'" ) indexes_as_dict = [dict(zip(meta_indexes, row)) for row in indexes] @@ -160,7 +160,7 @@ def get_indexed_properties_for_type(rel_type): @staticmethod def infer_cardinality(rel_type, start_label): range_start_query = f"MATCH (n:`{start_label}`) WHERE NOT EXISTS ((n)-[:`{rel_type}`]->()) WITH n LIMIT 1 RETURN count(n)" - result, _ = db.cypher_query(range_start_query) + result, _ = adb.cypher_query_async(range_start_query) is_start_zero = result[0][0] > 0 range_end_query = f""" @@ -170,7 +170,7 @@ def infer_cardinality(rel_type, start_label): WITH n LIMIT 1 RETURN count(n) """ - result, _ = db.cypher_query(range_end_query) + result, _ = adb.cypher_query_async(range_end_query) is_end_one = result[0][0] == 0 cardinality = "Zero" if is_start_zero else "One" @@ -184,7 +184,7 @@ def infer_cardinality(rel_type, start_label): def get_node_labels(): query = "CALL db.labels()" - result, _ = db.cypher_query(query) + result, _ = adb.cypher_query_async(query) return [record[0] for record in result] @@ -234,7 +234,7 @@ def build_rel_type_definition(label, outgoing_relationships, defined_rel_types): unique_properties = ( RelationshipInspector.get_constraints_for_type(rel_type) - if db.version_is_higher_than("5.7") + if adb.version_is_higher_than("5.7") else [] ) indexed_properties = RelationshipInspector.get_indexed_properties_for_type( @@ -268,7 +268,7 @@ def build_rel_type_definition(label, outgoing_relationships, defined_rel_types): def inspect_database(bolt_url): # Connect to the database print(f"Connecting to {bolt_url}") - db.set_connection(bolt_url) + adb.set_connection_async(bolt_url) node_labels = get_node_labels() defined_rel_types = [] diff --git a/neomodel/scripts/neomodel_install_labels.py b/neomodel/scripts/neomodel_install_labels.py index 8bd5119f..5b7f65a8 100755 --- a/neomodel/scripts/neomodel_install_labels.py +++ b/neomodel/scripts/neomodel_install_labels.py @@ -32,7 +32,7 @@ from importlib import import_module from os import environ, path -from .. import db, install_all_labels +from neomodel._async.core import adb def load_python_module_or_file(name): @@ -109,9 +109,9 @@ def main(): # Connect after to override any code in the module that may set the connection print(f"Connecting to {bolt_url}") - db.set_connection(url=bolt_url) + adb.set_connection_async(url=bolt_url) - install_all_labels() + adb.install_all_labels_async() if __name__ == "__main__": diff --git a/neomodel/scripts/neomodel_remove_labels.py b/neomodel/scripts/neomodel_remove_labels.py index 1ad6cc34..8eb7273b 100755 --- a/neomodel/scripts/neomodel_remove_labels.py +++ b/neomodel/scripts/neomodel_remove_labels.py @@ -27,7 +27,7 @@ from argparse import ArgumentParser, RawDescriptionHelpFormatter from os import environ -from .. import db, remove_all_labels +from neomodel._async.core import adb def main(): @@ -61,9 +61,9 @@ def main(): # Connect after to override any code in the module that may set the connection print(f"Connecting to {bolt_url}") - db.set_connection(url=bolt_url) + adb.set_connection_async(url=bolt_url) - remove_all_labels() + adb.remove_all_labels_async() if __name__ == "__main__": diff --git a/neomodel/util.py b/neomodel/util.py index cf4230e3..6dd6c95d 100644 --- a/neomodel/util.py +++ b/neomodel/util.py @@ -1,630 +1,4 @@ -import logging -import os -import sys -import time import warnings -from threading import local -from typing import Optional, Sequence -from urllib.parse import quote, unquote, urlparse - -from neo4j import DEFAULT_DATABASE, Driver, GraphDatabase, basic_auth -from neo4j.api import Bookmarks -from neo4j.exceptions import ClientError, ServiceUnavailable, SessionExpired -from neo4j.graph import Node, Path, Relationship - -from neomodel import config, core -from neomodel.exceptions import ( - ConstraintValidationFailed, - FeatureNotSupported, - NodeClassNotDefined, - RelationshipClassNotDefined, - UniqueProperty, -) - -logger = logging.getLogger(__name__) - - -# make sure the connection url has been set prior to executing the wrapped function -def ensure_connection(func): - def wrapper(self, *args, **kwargs): - # Sort out where to find url - if hasattr(self, "db"): - _db = self.db - else: - _db = self - - if not _db.driver: - if hasattr(config, "DRIVER") and config.DRIVER: - _db.set_connection(driver=config.DRIVER) - elif config.DATABASE_URL: - _db.set_connection(url=config.DATABASE_URL) - - return func(self, *args, **kwargs) - - return wrapper - - -def change_neo4j_password(db, user, new_password): - db.cypher_query(f"ALTER USER {user} SET PASSWORD '{new_password}'") - - -def clear_neo4j_database(db, clear_constraints=False, clear_indexes=False): - db.cypher_query( - """ - MATCH (a) - CALL { WITH a DETACH DELETE a } - IN TRANSACTIONS OF 5000 rows - """ - ) - if clear_constraints: - core.drop_constraints() - if clear_indexes: - core.drop_indexes() - - -class Database(local): - """ - A singleton object via which all operations from neomodel to the Neo4j backend are handled with. - """ - - _NODE_CLASS_REGISTRY = {} - - def __init__(self): - self._active_transaction = None - self.url = None - self.driver = None - self._session = None - self._pid = None - self._database_name = DEFAULT_DATABASE - self.protocol_version = None - self._database_version = None - self._database_edition = None - self.impersonated_user = None - - def set_connection(self, url: str = None, driver: Driver = None): - """ - Sets the connection up and relevant internal. This can be done using a Neo4j URL or a driver instance. - - Args: - url (str): Optionally, Neo4j URL in the form protocol://username:password@hostname:port/dbname. - When provided, a Neo4j driver instance will be created by neomodel. - - driver (neo4j.Driver): Optionally, a pre-created driver instance. - When provided, neomodel will not create a driver instance but use this one instead. - """ - if driver: - self.driver = driver - if hasattr(config, "DATABASE_NAME") and config.DATABASE_NAME: - self._database_name = config.DATABASE_NAME - elif url: - self._parse_driver_from_url(url=url) - - self._pid = os.getpid() - self._active_transaction = None - # Set to default database if it hasn't been set before - if self._database_name is None: - self._database_name = DEFAULT_DATABASE - - # Getting the information about the database version requires a connection to the database - self._database_version = None - self._database_edition = None - self._update_database_version() - - def _parse_driver_from_url(self, url: str) -> None: - """Parse the driver information from the given URL and initialize the driver. - - Args: - url (str): The URL to parse. - - Raises: - ValueError: If the URL format is not as expected. - - Returns: - None - Sets the driver and database_name as class properties - """ - p_start = url.replace(":", "", 1).find(":") + 2 - p_end = url.rfind("@") - password = url[p_start:p_end] - url = url.replace(password, quote(password)) - parsed_url = urlparse(url) - - valid_schemas = [ - "bolt", - "bolt+s", - "bolt+ssc", - "bolt+routing", - "neo4j", - "neo4j+s", - "neo4j+ssc", - ] - - if parsed_url.netloc.find("@") > -1 and parsed_url.scheme in valid_schemas: - credentials, hostname = parsed_url.netloc.rsplit("@", 1) - username, password = credentials.split(":") - password = unquote(password) - database_name = parsed_url.path.strip("/") - else: - raise ValueError( - f"Expecting url format: bolt://user:password@localhost:7687 got {url}" - ) - - options = { - "auth": basic_auth(username, password), - "connection_acquisition_timeout": config.CONNECTION_ACQUISITION_TIMEOUT, - "connection_timeout": config.CONNECTION_TIMEOUT, - "keep_alive": config.KEEP_ALIVE, - "max_connection_lifetime": config.MAX_CONNECTION_LIFETIME, - "max_connection_pool_size": config.MAX_CONNECTION_POOL_SIZE, - "max_transaction_retry_time": config.MAX_TRANSACTION_RETRY_TIME, - "resolver": config.RESOLVER, - "user_agent": config.USER_AGENT, - } - - if "+s" not in parsed_url.scheme: - options["encrypted"] = config.ENCRYPTED - options["trusted_certificates"] = config.TRUSTED_CERTIFICATES - - self.driver = GraphDatabase.driver( - parsed_url.scheme + "://" + hostname, **options - ) - self.url = url - # The database name can be provided through the url or the config - if database_name == "": - if hasattr(config, "DATABASE_NAME") and config.DATABASE_NAME: - self._database_name = config.DATABASE_NAME - else: - self._database_name = database_name - - def close_connection(self): - """ - Closes the currently open driver. - The driver should always be closed at the end of the application's lifecyle. - """ - self._database_version = None - self._database_edition = None - self._database_name = None - self.driver.close() - self.driver = None - - @property - def database_version(self): - if self._database_version is None: - self._update_database_version() - - return self._database_version - - @property - def database_edition(self): - if self._database_edition is None: - self._update_database_version() - - return self._database_edition - - @property - def transaction(self): - """ - Returns the current transaction object - """ - return TransactionProxy(self) - - @property - def write_transaction(self): - return TransactionProxy(self, access_mode="WRITE") - - @property - def read_transaction(self): - return TransactionProxy(self, access_mode="READ") - - def impersonate(self, user: str) -> "ImpersonationHandler": - """All queries executed within this context manager will be executed as impersonated user - - Args: - user (str): User to impersonate - - Returns: - ImpersonationHandler: Context manager to set/unset the user to impersonate - """ - if self.database_edition != "enterprise": - raise FeatureNotSupported( - "Impersonation is only available in Neo4j Enterprise edition" - ) - return ImpersonationHandler(self, impersonated_user=user) - - @ensure_connection - def begin(self, access_mode=None, **parameters): - """ - Begins a new transaction. Raises SystemError if a transaction is already active. - """ - if ( - hasattr(self, "_active_transaction") - and self._active_transaction is not None - ): - raise SystemError("Transaction in progress") - self._session = self.driver.session( - default_access_mode=access_mode, - database=self._database_name, - impersonated_user=self.impersonated_user, - **parameters, - ) - self._active_transaction = self._session.begin_transaction() - - @ensure_connection - def commit(self): - """ - Commits the current transaction and closes its session - - :return: last_bookmarks - """ - try: - self._active_transaction.commit() - last_bookmarks = self._session.last_bookmarks() - finally: - # In case when something went wrong during - # committing changes to the database - # we have to close an active transaction and session. - self._active_transaction.close() - self._session.close() - self._active_transaction = None - self._session = None - - return last_bookmarks - - @ensure_connection - def rollback(self): - """ - Rolls back the current transaction and closes its session - """ - try: - self._active_transaction.rollback() - finally: - # In case when something went wrong during changes rollback, - # we have to close an active transaction and session - self._active_transaction.close() - self._session.close() - self._active_transaction = None - self._session = None - - def _update_database_version(self): - """ - Updates the database server information when it is required - """ - try: - results = self.cypher_query( - "CALL dbms.components() yield versions, edition return versions[0], edition" - ) - self._database_version = results[0][0][0] - self._database_edition = results[0][0][1] - except ServiceUnavailable: - # The database server is not running yet - pass - - def _object_resolution(self, object_to_resolve): - """ - Performs in place automatic object resolution on a result - returned by cypher_query. - - The function operates recursively in order to be able to resolve Nodes - within nested list structures and Path objects. Not meant to be called - directly, used primarily by _result_resolution. - - :param object_to_resolve: A result as returned by cypher_query. - :type Any: - - :return: An instantiated object. - """ - # Below is the original comment that came with the code extracted in - # this method. It is not very clear but I decided to keep it just in - # case - # - # - # For some reason, while the type of `a_result_attribute[1]` - # as reported by the neo4j driver is `Node` for Node-type data - # retrieved from the database. - # When the retrieved data are Relationship-Type, - # the returned type is `abc.[REL_LABEL]` which is however - # a descendant of Relationship. - # Consequently, the type checking was changed for both - # Node, Relationship objects - if isinstance(object_to_resolve, Node): - return self._NODE_CLASS_REGISTRY[ - frozenset(object_to_resolve.labels) - ].inflate(object_to_resolve) - - if isinstance(object_to_resolve, Relationship): - rel_type = frozenset([object_to_resolve.type]) - return self._NODE_CLASS_REGISTRY[rel_type].inflate(object_to_resolve) - - if isinstance(object_to_resolve, Path): - from .path import NeomodelPath - - return NeomodelPath(object_to_resolve) - - if isinstance(object_to_resolve, list): - return self._result_resolution([object_to_resolve]) - - return object_to_resolve - - def _result_resolution(self, result_list): - """ - Performs in place automatic object resolution on a set of results - returned by cypher_query. - - The function operates recursively in order to be able to resolve Nodes - within nested list structures. Not meant to be called directly, - used primarily by cypher_query. - - :param result_list: A list of results as returned by cypher_query. - :type list: - - :return: A list of instantiated objects. - """ - - # Object resolution occurs in-place - for a_result_item in enumerate(result_list): - for a_result_attribute in enumerate(a_result_item[1]): - try: - # Primitive types should remain primitive types, - # Nodes to be resolved to native objects - resolved_object = a_result_attribute[1] - - resolved_object = self._object_resolution(resolved_object) - - result_list[a_result_item[0]][ - a_result_attribute[0] - ] = resolved_object - - except KeyError as exc: - # Not being able to match the label set of a node with a known object results - # in a KeyError in the internal dictionary used for resolution. If it is impossible - # to match, then raise an exception with more details about the error. - if isinstance(a_result_attribute[1], Node): - raise NodeClassNotDefined( - a_result_attribute[1], self._NODE_CLASS_REGISTRY - ) from exc - - if isinstance(a_result_attribute[1], Relationship): - raise RelationshipClassNotDefined( - a_result_attribute[1], self._NODE_CLASS_REGISTRY - ) from exc - - return result_list - - @ensure_connection - def cypher_query( - self, - query, - params=None, - handle_unique=True, - retry_on_session_expire=False, - resolve_objects=False, - ): - """ - Runs a query on the database and returns a list of results and their headers. - - :param query: A CYPHER query - :type: str - :param params: Dictionary of parameters - :type: dict - :param handle_unique: Whether or not to raise UniqueProperty exception on Cypher's ConstraintValidation errors - :type: bool - :param retry_on_session_expire: Whether or not to attempt the same query again if the transaction has expired. - If you use neomodel with your own driver, you must catch SessionExpired exceptions yourself and retry with a new driver instance. - :type: bool - :param resolve_objects: Whether to attempt to resolve the returned nodes to data model objects automatically - :type: bool - """ - - if self._active_transaction: - # Use current session is a transaction is currently active - results, meta = self._run_cypher_query( - self._active_transaction, - query, - params, - handle_unique, - retry_on_session_expire, - resolve_objects, - ) - else: - # Otherwise create a new session in a with to dispose of it after it has been run - with self.driver.session( - database=self._database_name, impersonated_user=self.impersonated_user - ) as session: - results, meta = self._run_cypher_query( - session, - query, - params, - handle_unique, - retry_on_session_expire, - resolve_objects, - ) - - return results, meta - - def _run_cypher_query( - self, - session, - query, - params, - handle_unique, - retry_on_session_expire, - resolve_objects, - ): - try: - # Retrieve the data - start = time.time() - response = session.run(query, params) - results, meta = [list(r.values()) for r in response], response.keys() - end = time.time() - - if resolve_objects: - # Do any automatic resolution required - results = self._result_resolution(results) - - except ClientError as e: - if e.code == "Neo.ClientError.Schema.ConstraintValidationFailed": - if "already exists with label" in e.message and handle_unique: - raise UniqueProperty(e.message) from e - - raise ConstraintValidationFailed(e.message) from e - exc_info = sys.exc_info() - raise exc_info[1].with_traceback(exc_info[2]) - except SessionExpired: - if retry_on_session_expire: - self.set_connection(url=self.url) - return self.cypher_query( - query=query, - params=params, - handle_unique=handle_unique, - retry_on_session_expire=False, - ) - raise - - tte = end - start - if os.environ.get("NEOMODEL_CYPHER_DEBUG", False) and tte > float( - os.environ.get("NEOMODEL_SLOW_QUERIES", 0) - ): - logger.debug( - "query: " - + query - + "\nparams: " - + repr(params) - + f"\ntook: {tte:.2g}s\n" - ) - - return results, meta - - def get_id_method(self) -> str: - if self.database_version.startswith("4"): - return "id" - else: - return "elementId" - - def list_indexes(self, exclude_token_lookup=False) -> Sequence[dict]: - """Returns all indexes existing in the database - - Arguments: - exclude_token_lookup[bool]: Exclude automatically create token lookup indexes - - Returns: - Sequence[dict]: List of dictionaries, each entry being an index definition - """ - indexes, meta_indexes = self.cypher_query("SHOW INDEXES") - indexes_as_dict = [dict(zip(meta_indexes, row)) for row in indexes] - - if exclude_token_lookup: - indexes_as_dict = [ - obj for obj in indexes_as_dict if obj["type"] != "LOOKUP" - ] - - return indexes_as_dict - - def list_constraints(self) -> Sequence[dict]: - """Returns all constraints existing in the database - - Returns: - Sequence[dict]: List of dictionaries, each entry being a constraint definition - """ - constraints, meta_constraints = self.cypher_query("SHOW CONSTRAINTS") - constraints_as_dict = [dict(zip(meta_constraints, row)) for row in constraints] - - return constraints_as_dict - - def version_is_higher_than(self, version_tag: str) -> bool: - """Returns true if the database version is higher or equal to a given tag - - Args: - version_tag (str): The version to compare against - - Returns: - bool: True if the database version is higher or equal to the given version - """ - return version_tag_to_integer(self.database_version) >= version_tag_to_integer( - version_tag - ) - - def edition_is_enterprise(self) -> bool: - """Returns true if the database edition is enterprise - - Returns: - bool: True if the database edition is enterprise - """ - return self.database_edition == "enterprise" - - -class TransactionProxy: - bookmarks: Optional[Bookmarks] = None - - def __init__(self, db, access_mode=None): - self.db = db - self.access_mode = access_mode - - @ensure_connection - def __enter__(self): - self.db.begin(access_mode=self.access_mode, bookmarks=self.bookmarks) - self.bookmarks = None - return self - - def __exit__(self, exc_type, exc_value, traceback): - if exc_value: - self.db.rollback() - - if ( - exc_type is ClientError - and exc_value.code == "Neo.ClientError.Schema.ConstraintValidationFailed" - ): - raise UniqueProperty(exc_value.message) - - if not exc_value: - self.last_bookmark = self.db.commit() - - def __call__(self, func): - def wrapper(*args, **kwargs): - with self: - return func(*args, **kwargs) - - return wrapper - - @property - def with_bookmark(self): - return BookmarkingTransactionProxy(self.db, self.access_mode) - - -class ImpersonationHandler: - def __init__(self, db, impersonated_user: str): - self.db = db - self.impersonated_user = impersonated_user - - def __enter__(self): - self.db.impersonated_user = self.impersonated_user - return self - - def __exit__(self, exception_type, exception_value, exception_traceback): - self.db.impersonated_user = None - - print("\nException type:", exception_type) - print("\nException value:", exception_value) - print("\nTraceback:", exception_traceback) - - def __call__(self, func): - def wrapper(*args, **kwargs): - with self: - return func(*args, **kwargs) - - return wrapper - - -class BookmarkingTransactionProxy(TransactionProxy): - def __call__(self, func): - def wrapper(*args, **kwargs): - self.bookmarks = kwargs.pop("bookmarks", None) - - with self: - result = func(*args, **kwargs) - self.last_bookmark = None - - return result, self.last_bookmark - - return wrapper def deprecated(message): diff --git a/pyproject.toml b/pyproject.toml index d635911d..83eaad9e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ changelog = "https://github.com/neo4j-contrib/neomodel/releases" [project.optional-dependencies] dev = [ "pytest>=7.1", + "pytest-asyncio", "pytest-cov>=4.0", "pre-commit", "black", diff --git a/test/async_/conftest.py b/test/async_/conftest.py new file mode 100644 index 00000000..587559fa --- /dev/null +++ b/test/async_/conftest.py @@ -0,0 +1,48 @@ +import os +import warnings + +import pytest_asyncio + +from neomodel import config +from neomodel._async.core import adb + + +@pytest_asyncio.fixture(scope="session", autouse=True) +async def setup_neo4j_session(request): + """ + Provides initial connection to the database and sets up the rest of the test suite + + :param request: The request object. Please see `_ + :type Request object: For more information please see `_ + """ + + warnings.simplefilter("default") + + config.DATABASE_URL = os.environ.get( + "NEO4J_BOLT_URL", "bolt://neo4j:foobarbaz@localhost:7687" + ) + config.AUTO_INSTALL_LABELS = True + + # Clear the database if required + database_is_populated, _ = await adb.cypher_query_async( + "MATCH (a) return count(a)>0 as database_is_populated" + ) + if database_is_populated[0][0] and not request.config.getoption("resetdb"): + raise SystemError( + "Please note: The database seems to be populated.\n\tEither delete all nodes and edges manually, or set the --resetdb parameter when calling pytest\n\n\tpytest --resetdb." + ) + else: + await adb.clear_neo4j_database_async(clear_constraints=True, clear_indexes=True) + + await adb.cypher_query_async( + "CREATE OR REPLACE USER troygreene SET PASSWORD 'foobarbaz' CHANGE NOT REQUIRED" + ) + if adb.database_edition == "enterprise": + await adb.cypher_query_async("GRANT ROLE publisher TO troygreene") + await adb.cypher_query_async("GRANT IMPERSONATE (troygreene) ON DBMS TO admin") + + +@pytest_asyncio.fixture(scope="session", autouse=True) +async def cleanup(): + yield + await adb.close_connection_async() diff --git a/test/conftest.py b/test/conftest.py index 7ec261d6..48d3088b 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,11 +1,9 @@ from __future__ import print_function import os -import warnings import pytest -from neomodel import clear_neo4j_database, config, db from neomodel.util import version_tag_to_integer NEO4J_URL = os.environ.get("NEO4J_URL", "bolt://localhost:7687") @@ -50,46 +48,6 @@ def pytest_collection_modifyitems(items): items[:] = new_order -@pytest.hookimpl -def pytest_sessionstart(session): - """ - Provides initial connection to the database and sets up the rest of the test suite - - :param session: The session object. Please see `_ - :type Session object: For more information please see `_ - """ - - warnings.simplefilter("default") - - config.DATABASE_URL = os.environ.get( - "NEO4J_BOLT_URL", "bolt://neo4j:foobarbaz@localhost:7687" - ) - config.AUTO_INSTALL_LABELS = True - - # Clear the database if required - database_is_populated, _ = db.cypher_query( - "MATCH (a) return count(a)>0 as database_is_populated" - ) - if database_is_populated[0][0] and not session.config.getoption("resetdb"): - raise SystemError( - "Please note: The database seems to be populated.\n\tEither delete all nodes and edges manually, or set the --resetdb parameter when calling pytest\n\n\tpytest --resetdb." - ) - else: - clear_neo4j_database(db, clear_constraints=True, clear_indexes=True) - - db.cypher_query( - "CREATE OR REPLACE USER troygreene SET PASSWORD 'foobarbaz' CHANGE NOT REQUIRED" - ) - if db.database_edition == "enterprise": - db.cypher_query("GRANT ROLE publisher TO troygreene") - db.cypher_query("GRANT IMPERSONATE (troygreene) ON DBMS TO admin") - - -@pytest.hookimpl -def pytest_unconfigure(): - db.close_connection() - - def check_and_skip_neo4j_least_version(required_least_neo4j_version, message): """ Checks if the NEO4J_VERSION is at least `required_least_neo4j_version` and skips a test if not. @@ -112,8 +70,3 @@ def check_and_skip_neo4j_least_version(required_least_neo4j_version, message): "Neo4j version: {}. {}." "Skipping test.".format(os.environ["NEO4J_VERSION"], message) ) - - -@pytest.fixture -def skip_neo4j_before_330(): - check_and_skip_neo4j_least_version(330, "Neo4J version does not support this test") diff --git a/test/test_alias.py b/test/test_alias.py index c63119aa..78f39969 100644 --- a/test/test_alias.py +++ b/test/test_alias.py @@ -1,4 +1,4 @@ -from neomodel import AliasProperty, StringProperty, StructuredNode +from neomodel import AliasProperty, StringProperty, StructuredNodeAsync class MagicProperty(AliasProperty): @@ -6,20 +6,20 @@ def setup(self): self.owner.setup_hook_called = True -class AliasTestNode(StructuredNode): +class AliasTestNode(StructuredNodeAsync): name = StringProperty(unique_index=True) full_name = AliasProperty(to="name") long_name = MagicProperty(to="name") def test_property_setup_hook(): - tim = AliasTestNode(long_name="tim").save() + tim = AliasTestNode(long_name="tim").save_async() assert AliasTestNode.setup_hook_called assert tim.name == "tim" def test_alias(): - jim = AliasTestNode(full_name="Jim").save() + jim = AliasTestNode(full_name="Jim").save_async() assert jim.name == "Jim" assert jim.full_name == "Jim" assert "full_name" not in AliasTestNode.deflate(jim.__properties__) diff --git a/test/test_batch.py b/test/test_batch.py index c2d1ec86..cf0b6a4a 100644 --- a/test/test_batch.py +++ b/test/test_batch.py @@ -5,7 +5,7 @@ RelationshipFrom, RelationshipTo, StringProperty, - StructuredNode, + StructuredNodeAsync, UniqueIdProperty, config, ) @@ -14,30 +14,34 @@ config.AUTO_INSTALL_LABELS = True -class UniqueUser(StructuredNode): +class UniqueUser(StructuredNodeAsync): uid = UniqueIdProperty() name = StringProperty() age = IntegerProperty() def test_unique_id_property_batch(): - users = UniqueUser.create({"name": "bob", "age": 2}, {"name": "ben", "age": 3}) + users = UniqueUser.create_async( + {"name": "bob", "age": 2}, {"name": "ben", "age": 3} + ) assert users[0].uid != users[1].uid - users = UniqueUser.get_or_create({"uid": users[0].uid}, {"name": "bill", "age": 4}) + users = UniqueUser.get_or_create_async( + {"uid": users[0].uid}, {"name": "bill", "age": 4} + ) assert users[0].name == "bob" assert users[1].uid -class Customer(StructuredNode): +class Customer(StructuredNodeAsync): email = StringProperty(unique_index=True, required=True) age = IntegerProperty(index=True) def test_batch_create(): - users = Customer.create( + users = Customer.create_async( {"email": "jim1@aol.com", "age": 11}, {"email": "jim2@aol.com", "age": 7}, {"email": "jim3@aol.com", "age": 9}, @@ -52,7 +56,7 @@ def test_batch_create(): def test_batch_create_or_update(): - users = Customer.create_or_update( + users = Customer.create_or_update_async( {"email": "merge1@aol.com", "age": 11}, {"email": "merge2@aol.com"}, {"email": "merge3@aol.com", "age": 1}, @@ -62,7 +66,7 @@ def test_batch_create_or_update(): assert users[1] == users[3] assert Customer.nodes.get(email="merge1@aol.com").age == 11 - more_users = Customer.create_or_update( + more_users = Customer.create_or_update_async( {"email": "merge1@aol.com", "age": 22}, {"email": "merge4@aol.com", "age": None}, ) @@ -74,7 +78,7 @@ def test_batch_create_or_update(): def test_batch_validation(): # test validation in batch create with raises(DeflateError): - Customer.create( + Customer.create_async( {"email": "jim1@aol.com", "age": "x"}, ) @@ -83,12 +87,12 @@ def test_batch_index_violation(): for u in Customer.nodes.all(): u.delete() - users = Customer.create( + users = Customer.create_async( {"email": "jim6@aol.com", "age": 3}, ) assert users with raises(UniqueProperty): - Customer.create( + Customer.create_async( {"email": "jim6@aol.com", "age": 3}, {"email": "jim7@aol.com", "age": 5}, ) @@ -97,22 +101,22 @@ def test_batch_index_violation(): assert not Customer.nodes.filter(email="jim7@aol.com") -class Dog(StructuredNode): +class Dog(StructuredNodeAsync): name = StringProperty(required=True) owner = RelationshipTo("Person", "owner") -class Person(StructuredNode): +class Person(StructuredNodeAsync): name = StringProperty(unique_index=True) pets = RelationshipFrom("Dog", "owner") def test_get_or_create_with_rel(): - bob = Person.get_or_create({"name": "Bob"})[0] - bobs_gizmo = Dog.get_or_create({"name": "Gizmo"}, relationship=bob.pets) + bob = Person.get_or_create_async({"name": "Bob"})[0] + bobs_gizmo = Dog.get_or_create_async({"name": "Gizmo"}, relationship=bob.pets) - tim = Person.get_or_create({"name": "Tim"})[0] - tims_gizmo = Dog.get_or_create({"name": "Gizmo"}, relationship=tim.pets) + tim = Person.get_or_create_async({"name": "Tim"})[0] + tims_gizmo = Dog.get_or_create_async({"name": "Gizmo"}, relationship=tim.pets) # not the same gizmo assert bobs_gizmo[0] != tims_gizmo[0] diff --git a/test/test_cardinality.py b/test/test_cardinality.py index 3c850db0..119aef6c 100644 --- a/test/test_cardinality.py +++ b/test/test_cardinality.py @@ -8,26 +8,26 @@ OneOrMore, RelationshipTo, StringProperty, - StructuredNode, + StructuredNodeAsync, ZeroOrMore, ZeroOrOne, - db, + adb, ) -class HairDryer(StructuredNode): +class HairDryer(StructuredNodeAsync): version = IntegerProperty() -class ScrewDriver(StructuredNode): +class ScrewDriver(StructuredNodeAsync): version = IntegerProperty() -class Car(StructuredNode): +class Car(StructuredNodeAsync): version = IntegerProperty() -class Monkey(StructuredNode): +class Monkey(StructuredNodeAsync): name = StringProperty() dryers = RelationshipTo("HairDryer", "OWNS_DRYER", cardinality=ZeroOrMore) driver = RelationshipTo("ScrewDriver", "HAS_SCREWDRIVER", cardinality=ZeroOrOne) @@ -35,15 +35,15 @@ class Monkey(StructuredNode): toothbrush = RelationshipTo("ToothBrush", "HAS_TOOTHBRUSH", cardinality=One) -class ToothBrush(StructuredNode): +class ToothBrush(StructuredNodeAsync): name = StringProperty() def test_cardinality_zero_or_more(): - m = Monkey(name="tim").save() + m = Monkey(name="tim").save_async() assert m.dryers.all() == [] assert m.dryers.single() is None - h = HairDryer(version=1).save() + h = HairDryer(version=1).save_async() m.dryers.connect(h) assert len(m.dryers.all()) == 1 @@ -53,7 +53,7 @@ def test_cardinality_zero_or_more(): assert m.dryers.all() == [] assert m.dryers.single() is None - h2 = HairDryer(version=2).save() + h2 = HairDryer(version=2).save_async() m.dryers.connect(h) m.dryers.connect(h2) m.dryers.disconnect_all() @@ -62,16 +62,16 @@ def test_cardinality_zero_or_more(): def test_cardinality_zero_or_one(): - m = Monkey(name="bob").save() + m = Monkey(name="bob").save_async() assert m.driver.all() == [] assert m.driver.single() is None - h = ScrewDriver(version=1).save() + h = ScrewDriver(version=1).save_async() m.driver.connect(h) assert len(m.driver.all()) == 1 assert m.driver.single().version == 1 - j = ScrewDriver(version=2).save() + j = ScrewDriver(version=2).save_async() with raises(AttemptedCardinalityViolation): m.driver.connect(j) @@ -80,7 +80,7 @@ def test_cardinality_zero_or_one(): # Forcing creation of a second ToothBrush to go around # AttemptedCardinalityViolation - db.cypher_query( + adb.cypher_query( """ MATCH (m:Monkey WHERE m.name="bob") CREATE (s:ScrewDriver {version:3}) @@ -95,7 +95,7 @@ def test_cardinality_zero_or_one(): def test_cardinality_one_or_more(): - m = Monkey(name="jerry").save() + m = Monkey(name="jerry").save_async() with raises(CardinalityViolation): m.car.all() @@ -103,7 +103,7 @@ def test_cardinality_one_or_more(): with raises(CardinalityViolation): m.car.single() - c = Car(version=2).save() + c = Car(version=2).save_async() m.car.connect(c) assert m.car.single().version == 2 @@ -113,7 +113,7 @@ def test_cardinality_one_or_more(): with raises(AttemptedCardinalityViolation): m.car.disconnect(c) - d = Car(version=3).save() + d = Car(version=3).save_async() m.car.connect(d) cars = m.car.all() assert len(cars) == 2 @@ -124,7 +124,7 @@ def test_cardinality_one_or_more(): def test_cardinality_one(): - m = Monkey(name="jerry").save() + m = Monkey(name="jerry").save_async() with raises( CardinalityViolation, match=r"CardinalityViolation: Expected: .*, got: none." @@ -134,11 +134,11 @@ def test_cardinality_one(): with raises(CardinalityViolation): m.toothbrush.single() - b = ToothBrush(name="Jim").save() + b = ToothBrush(name="Jim").save_async() m.toothbrush.connect(b) assert m.toothbrush.single().name == "Jim" - x = ToothBrush(name="Jim").save + x = ToothBrush(name="Jim").save_async with raises(AttemptedCardinalityViolation): m.toothbrush.connect(x) @@ -150,7 +150,7 @@ def test_cardinality_one(): # Forcing creation of a second ToothBrush to go around # AttemptedCardinalityViolation - db.cypher_query( + adb.cypher_query( """ MATCH (m:Monkey WHERE m.name="jerry") CREATE (t:ToothBrush {name:"Jim"}) diff --git a/test/test_connection.py b/test/test_connection.py index fb3524bb..26e74e07 100644 --- a/test/test_connection.py +++ b/test/test_connection.py @@ -5,9 +5,8 @@ from neo4j import GraphDatabase from neo4j.debug import watch -from neomodel import StringProperty, StructuredNode, config, db - -from .conftest import NEO4J_PASSWORD, NEO4J_URL, NEO4J_USERNAME +from neomodel import StringProperty, StructuredNodeAsync, adb, config +from neomodel.conftest import NEO4J_PASSWORD, NEO4J_URL, NEO4J_USERNAME @pytest.fixture(autouse=True) @@ -15,8 +14,8 @@ def setup_teardown(): yield # Teardown actions after tests have run # Reconnect to initial URL for potential subsequent tests - db.close_connection() - db.set_connection(url=config.DATABASE_URL) + adb.close_connection() + adb.set_connection_async(url=config.DATABASE_URL) @pytest.fixture(autouse=True, scope="session") @@ -32,38 +31,38 @@ def get_current_database_name() -> str: Returns: - str: The name of the current database. """ - results, meta = db.cypher_query("CALL db.info") + results, meta = adb.cypher_query("CALL db.info") results_as_dict = [dict(zip(meta, row)) for row in results] return results_as_dict[0]["name"] -class Pastry(StructuredNode): +class Pastry(StructuredNodeAsync): name = StringProperty(unique_index=True) def test_set_connection_driver_works(): # Verify that current connection is up - assert Pastry(name="Chocolatine").save() - db.close_connection() + assert Pastry(name="Chocolatine").save_async() + adb.close_connection() # Test connection using a driver - db.set_connection( + adb.set_connection_async( driver=GraphDatabase().driver(NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) ) - assert Pastry(name="Croissant").save() + assert Pastry(name="Croissant").save_async() def test_config_driver_works(): # Verify that current connection is up - assert Pastry(name="Chausson aux pommes").save() - db.close_connection() + assert Pastry(name="Chausson aux pommes").save_async() + adb.close_connection() # Test connection using a driver defined in config driver = GraphDatabase().driver(NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) config.DRIVER = driver - assert Pastry(name="Grignette").save() + assert Pastry(name="Grignette").save_async() # Clear config # No need to close connection - pytest teardown will do it @@ -71,31 +70,31 @@ def test_config_driver_works(): @pytest.mark.skipif( - db.database_edition != "enterprise", + adb.database_edition != "enterprise", reason="Skipping test for community edition - no multi database in CE", ) def test_connect_to_non_default_database(): database_name = "pastries" - db.cypher_query(f"CREATE DATABASE {database_name} IF NOT EXISTS") - db.close_connection() + adb.cypher_query(f"CREATE DATABASE {database_name} IF NOT EXISTS") + adb.close_connection() # Set database name in url - for url init only - db.set_connection(url=f"{config.DATABASE_URL}/{database_name}") + adb.set_connection_async(url=f"{config.DATABASE_URL}/{database_name}") assert get_current_database_name() == "pastries" - db.close_connection() + adb.close_connection() # Set database name in config - for both url and driver init config.DATABASE_NAME = database_name # url init - db.set_connection(url=config.DATABASE_URL) + adb.set_connection_async(url=config.DATABASE_URL) assert get_current_database_name() == "pastries" - db.close_connection() + adb.close_connection() # driver init - db.set_connection( + adb.set_connection_async( driver=GraphDatabase().driver(NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) ) assert get_current_database_name() == "pastries" @@ -113,17 +112,17 @@ def test_wrong_url_format(url): ValueError, match=rf"Expecting url format: bolt://user:password@localhost:7687 got {url}", ): - db.set_connection(url=url) + adb.set_connection_async(url=url) @pytest.mark.parametrize("protocol", ["neo4j+s", "neo4j+ssc", "bolt+s", "bolt+ssc"]) def test_connect_to_aura(protocol): cypher_return = "hello world" default_cypher_query = f"RETURN '{cypher_return}'" - db.close_connection() + adb.close_connection() _set_connection(protocol=protocol) - result, _ = db.cypher_query(default_cypher_query) + result, _ = adb.cypher_query(default_cypher_query) assert len(result) > 0 assert result[0][0] == cypher_return @@ -135,4 +134,4 @@ def _set_connection(protocol): AURA_TEST_DB_HOSTNAME = os.environ["AURA_TEST_DB_HOSTNAME"] database_url = f"{protocol}://{AURA_TEST_DB_USER}:{AURA_TEST_DB_PASSWORD}@{AURA_TEST_DB_HOSTNAME}" - db.set_connection(url=database_url) + adb.set_connection_async(url=database_url) diff --git a/test/test_contrib/test_semi_structured.py b/test/test_contrib/test_semi_structured.py index fe73a2bd..c04530f8 100644 --- a/test/test_contrib/test_semi_structured.py +++ b/test/test_contrib/test_semi_structured.py @@ -13,13 +13,13 @@ class Dummy(SemiStructuredNode): def test_to_save_to_model_with_required_only(): u = UserProf(email="dummy@test.com") - assert u.save() + assert u.save_async() def test_save_to_model_with_extras(): u = UserProf(email="jim@test.com", age=3, bar=99) u.foo = True - assert u.save() + assert u.save_async() u = UserProf.nodes.get(age=3) assert u.foo is True assert u.bar == 99 @@ -27,4 +27,4 @@ def test_save_to_model_with_extras(): def test_save_empty_model(): dummy = Dummy() - assert dummy.save() + assert dummy.save_async() diff --git a/test/test_contrib/test_spatial_properties.py b/test/test_contrib/test_spatial_properties.py index 03177c02..e009527a 100644 --- a/test/test_contrib/test_spatial_properties.py +++ b/test/test_contrib/test_spatial_properties.py @@ -12,8 +12,7 @@ import neomodel import neomodel.contrib.spatial_properties - -from .test_spatial_datatypes import ( +from neomodel.test_spatial_datatypes import ( basic_type_assertions, check_and_skip_neo4j_least_version, ) @@ -167,7 +166,7 @@ def get_some_point(): (random.random(), random.random()) ) - class LocalisableEntity(neomodel.StructuredNode): + class LocalisableEntity(neomodel.StructuredNodeAsync): """ A very simple entity to try out the default value assignment. """ @@ -183,7 +182,7 @@ class LocalisableEntity(neomodel.StructuredNode): ) # Save an object - an_object = LocalisableEntity().save() + an_object = LocalisableEntity().save_async() coords = an_object.location.coords[0] # Retrieve it retrieved_object = LocalisableEntity.nodes.get(identifier=an_object.identifier) @@ -201,7 +200,7 @@ def test_array_of_points(): :return: """ - class AnotherLocalisableEntity(neomodel.StructuredNode): + class AnotherLocalisableEntity(neomodel.StructuredNodeAsync): """ A very simple entity with an array of locations """ @@ -221,7 +220,7 @@ class AnotherLocalisableEntity(neomodel.StructuredNode): neomodel.contrib.spatial_properties.NeomodelPoint((0.0, 0.0)), neomodel.contrib.spatial_properties.NeomodelPoint((1.0, 0.0)), ] - ).save() + ).save_async() retrieved_object = AnotherLocalisableEntity.nodes.get( identifier=an_object.identifier @@ -243,7 +242,7 @@ def test_simple_storage_retrieval(): :return: """ - class TestStorageRetrievalProperty(neomodel.StructuredNode): + class TestStorageRetrievalProperty(neomodel.StructuredNodeAsync): uid = neomodel.UniqueIdProperty() description = neomodel.StringProperty() location = neomodel.contrib.spatial_properties.PointProperty(crs="cartesian") @@ -256,7 +255,7 @@ class TestStorageRetrievalProperty(neomodel.StructuredNode): a_restaurant = TestStorageRetrievalProperty( description="Milliways", location=neomodel.contrib.spatial_properties.NeomodelPoint((0, 0)), - ).save() + ).save_async() a_property = TestStorageRetrievalProperty.nodes.get( location=neomodel.contrib.spatial_properties.NeomodelPoint((0, 0)) @@ -264,6 +263,7 @@ class TestStorageRetrievalProperty(neomodel.StructuredNode): assert a_restaurant.description == a_property.description + def test_equality_with_other_objects(): """ Performs equality tests and ensures tha ``NeomodelPoint`` can be compared with ShapelyPoint and NeomodelPoint only. @@ -277,6 +277,9 @@ def test_equality_with_other_objects(): if int("".join(__version__.split(".")[0:3])) < 200: pytest.skip(f"Shapely 2.0 not present (Current version is {__version__}") - assert neomodel.contrib.spatial_properties.NeomodelPoint((0,0)) == neomodel.contrib.spatial_properties.NeomodelPoint(x=0, y=0) - assert neomodel.contrib.spatial_properties.NeomodelPoint((0,0)) == shapely.geometry.Point((0,0)) - + assert neomodel.contrib.spatial_properties.NeomodelPoint( + (0, 0) + ) == neomodel.contrib.spatial_properties.NeomodelPoint(x=0, y=0) + assert neomodel.contrib.spatial_properties.NeomodelPoint( + (0, 0) + ) == shapely.geometry.Point((0, 0)) diff --git a/test/test_cypher.py b/test/test_cypher.py index 7c2e6fd6..ef07bdc8 100644 --- a/test/test_cypher.py +++ b/test/test_cypher.py @@ -5,21 +5,21 @@ from numpy import ndarray from pandas import DataFrame, Series -from neomodel import StringProperty, StructuredNode -from neomodel.core import db +from neomodel import StringProperty, StructuredNodeAsync +from neomodel._async.core import adb -class User2(StructuredNode): +class User2(StructuredNodeAsync): name = StringProperty() email = StringProperty() -class UserPandas(StructuredNode): +class UserPandas(StructuredNodeAsync): name = StringProperty() email = StringProperty() -class UserNP(StructuredNode): +class UserNP(StructuredNodeAsync): name = StringProperty() email = StringProperty() @@ -36,21 +36,22 @@ def mocked_import(name, *args, **kwargs): monkeypatch.setattr(builtins, "__import__", mocked_import) -def test_cypher(): +@pytest.mark.asyncio +async def test_cypher_async(): """ test result format is backward compatible with earlier versions of neomodel """ - jim = User2(email="jim1@test.com").save() - data, meta = jim.cypher( - f"MATCH (a) WHERE {db.get_id_method()}(a)=$self RETURN a.email" + jim = await User2(email="jim1@test.com").save_async() + data, meta = await jim.cypher_async( + f"MATCH (a) WHERE {adb.get_id_method()}(a)=$self RETURN a.email" ) assert data[0][0] == "jim1@test.com" assert "a.email" in meta - data, meta = jim.cypher( + data, meta = await jim.cypher_async( f""" - MATCH (a) WHERE {db.get_id_method()}(a)=$self + MATCH (a) WHERE {adb.get_id_method()}(a)=$self MATCH (a)<-[:USER2]-(b) RETURN a, b, 3 """ @@ -58,10 +59,13 @@ def test_cypher(): assert "a" in meta and "b" in meta -def test_cypher_syntax_error(): - jim = User2(email="jim1@test.com").save() +@pytest.mark.asyncio +async def test_cypher_syntax_error_async(): + jim = await User2(email="jim1@test.com").save_async() try: - jim.cypher(f"MATCH a WHERE {db.get_id_method()}(a)={{self}} RETURN xx") + await jim.cypher_async( + f"MATCH a WHERE {adb.get_id_method()}(a)={{self}} RETURN xx" + ) except CypherError as e: assert hasattr(e, "message") assert hasattr(e, "code") @@ -69,8 +73,9 @@ def test_cypher_syntax_error(): assert False, "CypherError not raised." +@pytest.mark.asyncio @pytest.mark.parametrize("hide_available_pkg", ["pandas"], indirect=True) -def test_pandas_not_installed(hide_available_pkg): +async def test_pandas_not_installed_async(hide_available_pkg): with pytest.raises(ImportError): with pytest.warns( UserWarning, @@ -78,18 +83,23 @@ def test_pandas_not_installed(hide_available_pkg): ): from neomodel.integration.pandas import to_dataframe - _ = to_dataframe(db.cypher_query("MATCH (a) RETURN a.name AS name")) + _ = to_dataframe( + await adb.cypher_query_async("MATCH (a) RETURN a.name AS name") + ) -def test_pandas_integration(): +@pytest.mark.asyncio +async def test_pandas_integration_async(): from neomodel.integration.pandas import to_dataframe, to_series - jimla = UserPandas(email="jimla@test.com", name="jimla").save() - jimlo = UserPandas(email="jimlo@test.com", name="jimlo").save() + jimla = await UserPandas(email="jimla@test.com", name="jimla").save_async() + jimlo = await UserPandas(email="jimlo@test.com", name="jimlo").save_async() # Test to_dataframe df = to_dataframe( - db.cypher_query("MATCH (a:UserPandas) RETURN a.name AS name, a.email AS email") + await adb.cypher_query_async( + "MATCH (a:UserPandas) RETURN a.name AS name, a.email AS email" + ) ) assert isinstance(df, DataFrame) @@ -98,7 +108,9 @@ def test_pandas_integration(): # Also test passing an index and dtype to to_dataframe df = to_dataframe( - db.cypher_query("MATCH (a:UserPandas) RETURN a.name AS name, a.email AS email"), + await adb.cypher_query_async( + "MATCH (a:UserPandas) RETURN a.name AS name, a.email AS email" + ), index=df["email"], dtype=str, ) @@ -106,15 +118,18 @@ def test_pandas_integration(): assert df.index.inferred_type == "string" # Next test to_series - series = to_series(db.cypher_query("MATCH (a:UserPandas) RETURN a.name AS name")) + series = to_series( + await adb.cypher_query_async("MATCH (a:UserPandas) RETURN a.name AS name") + ) assert isinstance(series, Series) assert series.shape == (2,) assert df["name"].tolist() == ["jimla", "jimlo"] +@pytest.mark.asyncio @pytest.mark.parametrize("hide_available_pkg", ["numpy"], indirect=True) -def test_numpy_not_installed(hide_available_pkg): +async def test_numpy_not_installed_async(hide_available_pkg): with pytest.raises(ImportError): with pytest.warns( UserWarning, @@ -122,17 +137,22 @@ def test_numpy_not_installed(hide_available_pkg): ): from neomodel.integration.numpy import to_ndarray - _ = to_ndarray(db.cypher_query("MATCH (a) RETURN a.name AS name")) + _ = to_ndarray( + await adb.cypher_query_async("MATCH (a) RETURN a.name AS name") + ) -def test_numpy_integration(): +@pytest.mark.asyncio +async def test_numpy_integration_async(): from neomodel.integration.numpy import to_ndarray - jimly = UserNP(email="jimly@test.com", name="jimly").save() - jimlu = UserNP(email="jimlu@test.com", name="jimlu").save() + jimly = await UserNP(email="jimly@test.com", name="jimly").save_async() + jimlu = await UserNP(email="jimlu@test.com", name="jimlu").save_async() array = to_ndarray( - db.cypher_query("MATCH (a:UserNP) RETURN a.name AS name, a.email AS email") + await adb.cypher_query_async( + "MATCH (a:UserNP) RETURN a.name AS name, a.email AS email" + ) ) assert isinstance(array, ndarray) diff --git a/test/test_database_management.py b/test/test_database_management.py index 2a2ece34..e3f89f19 100644 --- a/test/test_database_management.py +++ b/test/test_database_management.py @@ -5,14 +5,13 @@ IntegerProperty, RelationshipTo, StringProperty, - StructuredNode, + StructuredNodeAsync, StructuredRel, - db, - util, ) +from neomodel._async.core import adb -class City(StructuredNode): +class City(StructuredNodeAsync): name = StringProperty() @@ -20,35 +19,35 @@ class InCity(StructuredRel): creation_year = IntegerProperty(index=True) -class Venue(StructuredNode): +class Venue(StructuredNodeAsync): name = StringProperty(unique_index=True) creator = StringProperty(index=True) in_city = RelationshipTo(City, relation_type="IN", model=InCity) def test_clear_database(): - venue = Venue(name="Royal Albert Hall", creator="Queen Victoria").save() - city = City(name="London").save() + venue = Venue(name="Royal Albert Hall", creator="Queen Victoria").save_async() + city = City(name="London").save_async() venue.in_city.connect(city) # Clear only the data - util.clear_neo4j_database(db) - database_is_populated, _ = db.cypher_query( + adb.clear_neo4j_database_async() + database_is_populated, _ = adb.cypher_query( "MATCH (a) return count(a)>0 as database_is_populated" ) assert database_is_populated[0][0] is False - indexes = db.list_indexes(exclude_token_lookup=True) - constraints = db.list_constraints() + indexes = adb.lise_indexes_async(exclude_token_lookup=True) + constraints = adb.list_constraints_async() assert len(indexes) > 0 assert len(constraints) > 0 # Clear constraints and indexes too - util.clear_neo4j_database(db, clear_constraints=True, clear_indexes=True) + adb.clear_neo4j_database_async(clear_constraints=True, clear_indexes=True) - indexes = db.list_indexes(exclude_token_lookup=True) - constraints = db.list_constraints() + indexes = adb.lise_indexes_async(exclude_token_lookup=True) + constraints = adb.list_constraints_async() assert len(indexes) == 0 assert len(constraints) == 0 @@ -59,19 +58,19 @@ def test_change_password(): prev_url = f"bolt://neo4j:{prev_password}@localhost:7687" new_url = f"bolt://neo4j:{new_password}@localhost:7687" - util.change_neo4j_password(db, "neo4j", new_password) - db.close_connection() + adb.change_neo4j_password_async("neo4j", new_password) + adb.close_connection() - db.set_connection(url=new_url) - db.close_connection() + adb.set_connection_async(url=new_url) + adb.close_connection() with pytest.raises(AuthError): - db.set_connection(url=prev_url) + adb.set_connection_async(url=prev_url) - db.close_connection() + adb.close_connection() - db.set_connection(url=new_url) - util.change_neo4j_password(db, "neo4j", prev_password) - db.close_connection() + adb.set_connection_async(url=new_url) + adb.change_neo4j_password_async("neo4j", prev_password) + adb.close_connection() - db.set_connection(url=prev_url) + adb.set_connection_async(url=prev_url) diff --git a/test/test_dbms_awareness.py b/test/test_dbms_awareness.py index fa041afe..0927eb4c 100644 --- a/test/test_dbms_awareness.py +++ b/test/test_dbms_awareness.py @@ -1,22 +1,22 @@ from pytest import mark -from neomodel import db +from neomodel._async.core import adb @mark.skipif( - db.database_version != "5.7.0", reason="Testing a specific database version" + adb.database_version != "5.7.0", reason="Testing a specific database version" ) def test_version_awareness(): - assert db.database_version == "5.7.0" - assert db.version_is_higher_than("5.7") - assert db.version_is_higher_than("5") - assert db.version_is_higher_than("4") + assert adb.database_version == "5.7.0" + assert adb.version_is_higher_than("5.7") + assert adb.version_is_higher_than("5") + assert adb.version_is_higher_than("4") - assert not db.version_is_higher_than("5.8") + assert not adb.version_is_higher_than("5.8") def test_edition_awareness(): - if db.database_edition == "enterprise": - assert db.edition_is_enterprise() + if adb.database_edition == "enterprise": + assert adb.edition_is_enterprise() else: - assert not db.edition_is_enterprise() + assert not adb.edition_is_enterprise() diff --git a/test/test_driver_options.py b/test/test_driver_options.py index 26f16640..d896d257 100644 --- a/test/test_driver_options.py +++ b/test/test_driver_options.py @@ -2,49 +2,49 @@ from neo4j.exceptions import ClientError from pytest import raises -from neomodel import db +from neomodel._async.core import adb from neomodel.exceptions import FeatureNotSupported @pytest.mark.skipif( - not db.edition_is_enterprise(), reason="Skipping test for community edition" + not adb.edition_is_enterprise(), reason="Skipping test for community edition" ) def test_impersonate(): - with db.impersonate(user="troygreene"): - results, _ = db.cypher_query("RETURN 'Doo Wacko !'") + with adb.impersonate(user="troygreene"): + results, _ = adb.cypher_query_async("RETURN 'Doo Wacko !'") assert results[0][0] == "Doo Wacko !" @pytest.mark.skipif( - not db.edition_is_enterprise(), reason="Skipping test for community edition" + not adb.edition_is_enterprise(), reason="Skipping test for community edition" ) def test_impersonate_unauthorized(): - with db.impersonate(user="unknownuser"): + with adb.impersonate(user="unknownuser"): with raises(ClientError): - _ = db.cypher_query("RETURN 'Gabagool'") + _ = adb.cypher_query_async("RETURN 'Gabagool'") @pytest.mark.skipif( - not db.edition_is_enterprise(), reason="Skipping test for community edition" + not adb.edition_is_enterprise(), reason="Skipping test for community edition" ) def test_impersonate_multiple_transactions(): - with db.impersonate(user="troygreene"): - with db.transaction: - results, _ = db.cypher_query("RETURN 'Doo Wacko !'") + with adb.impersonate(user="troygreene"): + with adb.transaction: + results, _ = adb.cypher_query_async("RETURN 'Doo Wacko !'") assert results[0][0] == "Doo Wacko !" - with db.transaction: - results, _ = db.cypher_query("SHOW CURRENT USER") + with adb.transaction: + results, _ = adb.cypher_query_async("SHOW CURRENT USER") assert results[0][0] == "troygreene" - results, _ = db.cypher_query("SHOW CURRENT USER") + results, _ = adb.cypher_query_async("SHOW CURRENT USER") assert results[0][0] == "neo4j" @pytest.mark.skipif( - db.edition_is_enterprise(), reason="Skipping test for enterprise edition" + adb.edition_is_enterprise(), reason="Skipping test for enterprise edition" ) def test_impersonate_community(): with raises(FeatureNotSupported): - with db.impersonate(user="troygreene"): - _ = db.cypher_query("RETURN 'Gabagoogoo'") + with adb.impersonate(user="troygreene"): + _ = adb.cypher_query_async("RETURN 'Gabagoogoo'") diff --git a/test/test_exceptions.py b/test/test_exceptions.py index 546c13fe..f631fa4b 100644 --- a/test/test_exceptions.py +++ b/test/test_exceptions.py @@ -1,9 +1,9 @@ import pickle -from neomodel import DoesNotExist, StringProperty, StructuredNode +from neomodel import DoesNotExist, StringProperty, StructuredNodeAsync -class EPerson(StructuredNode): +class EPerson(StructuredNodeAsync): name = StringProperty(unique_index=True) diff --git a/test/test_hooks.py b/test/test_hooks.py index 158db079..06c49247 100644 --- a/test/test_hooks.py +++ b/test/test_hooks.py @@ -1,9 +1,9 @@ -from neomodel import StringProperty, StructuredNode +from neomodel import StringProperty, StructuredNodeAsync HOOKS_CALLED = {} -class HookTest(StructuredNode): +class HookTest(StructuredNodeAsync): name = StringProperty() def post_create(self): @@ -23,8 +23,8 @@ def post_delete(self): def test_hooks(): - ht = HookTest(name="k").save() - ht.delete() + ht = HookTest(name="k").save_async() + ht.delete_async() assert "pre_save" in HOOKS_CALLED assert "post_save" in HOOKS_CALLED assert "post_create" in HOOKS_CALLED diff --git a/test/test_indexing.py b/test/test_indexing.py index 0b5e8fba..9cf90b60 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -4,24 +4,23 @@ from neomodel import ( IntegerProperty, StringProperty, - StructuredNode, + StructuredNodeAsync, UniqueProperty, - install_labels, ) -from neomodel.core import db +from neomodel._async.core import adb from neomodel.exceptions import ConstraintValidationFailed -class Human(StructuredNode): +class Human(StructuredNodeAsync): name = StringProperty(unique_index=True) age = IntegerProperty(index=True) def test_unique_error(): - install_labels(Human) - Human(name="j1m", age=13).save() + adb.install_labels_async(Human) + Human(name="j1m", age=13).save_async() try: - Human(name="j1m", age=14).save() + Human(name="j1m", age=14).save_async() except UniqueProperty as e: assert str(e).find("j1m") assert str(e).find("name") @@ -30,25 +29,25 @@ def test_unique_error(): @pytest.mark.skipif( - not db.edition_is_enterprise(), reason="Skipping test for community edition" + not adb.edition_is_enterprise(), reason="Skipping test for community edition" ) def test_existence_constraint_error(): - db.cypher_query( + adb.cypher_query_async( "CREATE CONSTRAINT test_existence_constraint FOR (n:Human) REQUIRE n.age IS NOT NULL" ) with raises(ConstraintValidationFailed, match=r"must have the property"): - Human(name="Scarlett").save() + Human(name="Scarlett").save_async() - db.cypher_query("DROP CONSTRAINT test_existence_constraint") + adb.cypher_query_async("DROP CONSTRAINT test_existence_constraint") def test_optional_properties_dont_get_indexed(): - Human(name="99", age=99).save() + Human(name="99", age=99).save_async() h = Human.nodes.get(age=99) assert h assert h.name == "99" - Human(name="98", age=98).save() + Human(name="98", age=98).save_async() h = Human.nodes.get(age=98) assert h assert h.name == "98" @@ -56,7 +55,7 @@ def test_optional_properties_dont_get_indexed(): def test_escaped_chars(): _name = "sarah:test" - Human(name=_name, age=3).save() + Human(name=_name, age=3).save_async() r = Human.nodes.filter(name=_name) assert r assert r[0].name == _name @@ -68,11 +67,11 @@ def test_does_not_exist(): def test_custom_label_name(): - class Giraffe(StructuredNode): + class Giraffe(StructuredNodeAsync): __label__ = "Giraffes" name = StringProperty(unique_index=True) - jim = Giraffe(name="timothy").save() + jim = Giraffe(name="timothy").save_async() node = Giraffe.nodes.get(name="timothy") assert node.name == jim.name diff --git a/test/test_issue112.py b/test/test_issue112.py index 3e379932..295fe239 100644 --- a/test/test_issue112.py +++ b/test/test_issue112.py @@ -1,13 +1,13 @@ -from neomodel import RelationshipTo, StructuredNode +from neomodel import RelationshipTo, StructuredNodeAsync -class SomeModel(StructuredNode): +class SomeModel(StructuredNodeAsync): test = RelationshipTo("SomeModel", "SELF") def test_len_relationship(): - t1 = SomeModel().save() - t2 = SomeModel().save() + t1 = SomeModel().save_async() + t2 = SomeModel().save_async() t1.test.connect(t2) l = len(t1.test.all()) diff --git a/test/test_issue283.py b/test/test_issue283.py index b7a63f8f..8f4bb29a 100644 --- a/test/test_issue283.py +++ b/test/test_issue283.py @@ -36,7 +36,7 @@ class PersonalRelationship(neomodel.StructuredRel): on_date = neomodel.DateTimeProperty(default_now=True) -class BasePerson(neomodel.StructuredNode): +class BasePerson(neomodel.StructuredNodeAsync): """ Base class for defining some basic sort of an actor. """ @@ -64,7 +64,7 @@ class PilotPerson(BasePerson): airplane = neomodel.StringProperty(required=True) -class BaseOtherPerson(neomodel.StructuredNode): +class BaseOtherPerson(neomodel.StructuredNodeAsync): """ An obviously "wrong" class of actor to befriend BasePersons with. """ @@ -88,15 +88,15 @@ def test_automatic_result_resolution(): """ # Create a few entities - A = TechnicalPerson.get_or_create( + A = TechnicalPerson.get_or_create_async( {"name": "Grumpy", "expertise": "Grumpiness"} )[0] - B = TechnicalPerson.get_or_create( - {"name": "Happy", "expertise": "Unicorns"} - )[0] - C = TechnicalPerson.get_or_create( - {"name": "Sleepy", "expertise": "Pillows"} - )[0] + B = TechnicalPerson.get_or_create_async({"name": "Happy", "expertise": "Unicorns"})[ + 0 + ] + C = TechnicalPerson.get_or_create_async({"name": "Sleepy", "expertise": "Pillows"})[ + 0 + ] # Add connections A.friends_with.connect(B) @@ -107,9 +107,9 @@ def test_automatic_result_resolution(): # TechnicalPerson (!NOT basePerson!) assert type(A.friends_with[0]) is TechnicalPerson - A.delete() - B.delete() - C.delete() + A.delete_async() + B.delete_async() + C.delete_async() def test_recursive_automatic_result_resolution(): @@ -120,21 +120,21 @@ def test_recursive_automatic_result_resolution(): """ # Create a few entities - A = TechnicalPerson.get_or_create( + A = TechnicalPerson.get_or_create_async( {"name": "Grumpier", "expertise": "Grumpiness"} )[0] - B = TechnicalPerson.get_or_create( + B = TechnicalPerson.get_or_create_async( {"name": "Happier", "expertise": "Grumpiness"} )[0] - C = TechnicalPerson.get_or_create( + C = TechnicalPerson.get_or_create_async( {"name": "Sleepier", "expertise": "Pillows"} )[0] - D = TechnicalPerson.get_or_create( + D = TechnicalPerson.get_or_create_async( {"name": "Sneezier", "expertise": "Pillows"} )[0] # Retrieve mixed results, both at the top level and nested - L, _ = neomodel.db.cypher_query( + L, _ = neomodel.adb.cypher_query( "MATCH (a:TechnicalPerson) " "WHERE a.expertise='Grumpiness' " "WITH collect(a) as Alpha " @@ -152,10 +152,10 @@ def test_recursive_automatic_result_resolution(): # Assert that primitive data types remain primitive data types assert issubclass(type(L[0][0][0][1][0][1][0][1][0][0]), basestring) - A.delete() - B.delete() - C.delete() - D.delete() + A.delete_async() + B.delete_async() + C.delete_async() + D.delete_async() def test_validation_with_inheritance_from_db(): @@ -166,21 +166,21 @@ def test_validation_with_inheritance_from_db(): # Create a few entities # Technical Persons - A = TechnicalPerson.get_or_create( + A = TechnicalPerson.get_or_create_async( {"name": "Grumpy", "expertise": "Grumpiness"} )[0] - B = TechnicalPerson.get_or_create( - {"name": "Happy", "expertise": "Unicorns"} - )[0] - C = TechnicalPerson.get_or_create( - {"name": "Sleepy", "expertise": "Pillows"} - )[0] + B = TechnicalPerson.get_or_create_async({"name": "Happy", "expertise": "Unicorns"})[ + 0 + ] + C = TechnicalPerson.get_or_create_async({"name": "Sleepy", "expertise": "Pillows"})[ + 0 + ] # Pilot Persons - D = PilotPerson.get_or_create( + D = PilotPerson.get_or_create_async( {"name": "Porco Rosso", "airplane": "Savoia-Marchetti"} )[0] - E = PilotPerson.get_or_create( + E = PilotPerson.get_or_create_async( {"name": "Jack Dalton", "airplane": "Beechcraft Model 18"} )[0] @@ -209,11 +209,11 @@ def test_validation_with_inheritance_from_db(): ) assert type(D.friends_with[0]) is PilotPerson - A.delete() - B.delete() - C.delete() - D.delete() - E.delete() + A.delete_async() + B.delete_async() + C.delete_async() + D.delete_async() + E.delete_async() def test_validation_enforcement_to_db(): @@ -223,26 +223,26 @@ def test_validation_enforcement_to_db(): # Create a few entities # Technical Persons - A = TechnicalPerson.get_or_create( + A = TechnicalPerson.get_or_create_async( {"name": "Grumpy", "expertise": "Grumpiness"} )[0] - B = TechnicalPerson.get_or_create( - {"name": "Happy", "expertise": "Unicorns"} - )[0] - C = TechnicalPerson.get_or_create( - {"name": "Sleepy", "expertise": "Pillows"} - )[0] + B = TechnicalPerson.get_or_create_async({"name": "Happy", "expertise": "Unicorns"})[ + 0 + ] + C = TechnicalPerson.get_or_create_async({"name": "Sleepy", "expertise": "Pillows"})[ + 0 + ] # Pilot Persons - D = PilotPerson.get_or_create( + D = PilotPerson.get_or_create_async( {"name": "Porco Rosso", "airplane": "Savoia-Marchetti"} )[0] - E = PilotPerson.get_or_create( + E = PilotPerson.get_or_create_async( {"name": "Jack Dalton", "airplane": "Beechcraft Model 18"} )[0] # Some Person - F = SomePerson(car_color="Blue").save() + F = SomePerson(car_color="Blue").save_async() # TechnicalPersons can befriend PilotPersons and vice-versa and that's fine A.friends_with.connect(B) @@ -257,12 +257,12 @@ def test_validation_enforcement_to_db(): with pytest.raises(ValueError): A.friends_with.connect(F) - A.delete() - B.delete() - C.delete() - D.delete() - E.delete() - F.delete() + A.delete_async() + B.delete_async() + C.delete_async() + D.delete_async() + E.delete_async() + F.delete_async() def test_failed_result_resolution(): @@ -276,23 +276,21 @@ class RandomPerson(BasePerson): randomness = neomodel.FloatProperty(default=random.random) # A Technical Person... - A = TechnicalPerson.get_or_create( + A = TechnicalPerson.get_or_create_async( {"name": "Grumpy", "expertise": "Grumpiness"} )[0] # A Random Person... - B = RandomPerson.get_or_create({"name": "Mad Hatter"})[0] + B = RandomPerson.get_or_create_async({"name": "Mad Hatter"})[0] A.friends_with.connect(B) # Simulate the condition where the definition of class RandomPerson is not # known yet. - del neomodel.db._NODE_CLASS_REGISTRY[ - frozenset(["RandomPerson", "BasePerson"]) - ] + del neomodel.adb._NODE_CLASS_REGISTRY[frozenset(["RandomPerson", "BasePerson"])] # Now try to instantiate a RandomPerson - A = TechnicalPerson.get_or_create( + A = TechnicalPerson.get_or_create_async( {"name": "Grumpy", "expertise": "Grumpiness"} )[0] with pytest.raises( @@ -302,8 +300,8 @@ class RandomPerson(BasePerson): for some_friend in A.friends_with: print(some_friend.name) - A.delete() - B.delete() + A.delete_async() + B.delete_async() def test_node_label_mismatch(): @@ -319,17 +317,17 @@ class UltraTechnicalPerson(SuperTechnicalPerson): ultraness = neomodel.FloatProperty(default=3.1415928) # Create a TechnicalPerson... - A = TechnicalPerson.get_or_create( + A = TechnicalPerson.get_or_create_async( {"name": "Grumpy", "expertise": "Grumpiness"} )[0] # ...that is connected to an UltraTechnicalPerson F = UltraTechnicalPerson( name="Chewbaka", expertise="Aarrr wgh ggwaaah" - ).save() + ).save_async() A.friends_with.connect(F) # Forget about the UltraTechnicalPerson - del neomodel.db._NODE_CLASS_REGISTRY[ + del neomodel.adb._NODE_CLASS_REGISTRY[ frozenset( [ "UltraTechnicalPerson", @@ -343,7 +341,7 @@ class UltraTechnicalPerson(SuperTechnicalPerson): # Recall a TechnicalPerson and enumerate its friends. # One of them is UltraTechnicalPerson which would be returned as a valid # node to a friends_with query but is currently unknown to the node class registry. - A = TechnicalPerson.get_or_create( + A = TechnicalPerson.get_or_create_async( {"name": "Grumpy", "expertise": "Grumpiness"} )[0] with pytest.raises(neomodel.exceptions.NodeClassNotDefined): @@ -375,20 +373,18 @@ def test_relationship_result_resolution(): A query returning a "Relationship" object can now instantiate it to a data model class """ # Test specific data - A = PilotPerson( - name="Zantford Granville", airplane="Gee Bee Model R" - ).save() - B = PilotPerson(name="Thomas Granville", airplane="Gee Bee Model R").save() - C = PilotPerson(name="Robert Granville", airplane="Gee Bee Model R").save() - D = PilotPerson(name="Mark Granville", airplane="Gee Bee Model R").save() - E = PilotPerson(name="Edward Granville", airplane="Gee Bee Model R").save() + A = PilotPerson(name="Zantford Granville", airplane="Gee Bee Model R").save_async() + B = PilotPerson(name="Thomas Granville", airplane="Gee Bee Model R").save_async() + C = PilotPerson(name="Robert Granville", airplane="Gee Bee Model R").save_async() + D = PilotPerson(name="Mark Granville", airplane="Gee Bee Model R").save_async() + E = PilotPerson(name="Edward Granville", airplane="Gee Bee Model R").save_async() A.friends_with.connect(B) B.friends_with.connect(C) C.friends_with.connect(D) D.friends_with.connect(E) - query_data = neomodel.db.cypher_query( + query_data = neomodel.adb.cypher_query( "MATCH (a:PilotPerson)-[r:FRIENDS_WITH]->(b:PilotPerson) " "WHERE a.airplane='Gee Bee Model R' and b.airplane='Gee Bee Model R' " "RETURN DISTINCT r", @@ -419,14 +415,14 @@ class ExtendedSomePerson(SomePerson): ) # Test specific data - A = ExtendedSomePerson(name="Michael Knight", car_color="Black").save() - B = ExtendedSomePerson(name="Luke Duke", car_color="Orange").save() - C = ExtendedSomePerson(name="Michael Schumacher", car_color="Red").save() + A = ExtendedSomePerson(name="Michael Knight", car_color="Black").save_async() + B = ExtendedSomePerson(name="Luke Duke", car_color="Orange").save_async() + C = ExtendedSomePerson(name="Michael Schumacher", car_color="Red").save_async() A.friends_with.connect(B) A.friends_with.connect(C) - query_data = neomodel.db.cypher_query( + query_data = neomodel.adb.cypher_query( "MATCH (:ExtendedSomePerson)-[r:FRIENDS_WITH]->(:ExtendedSomePerson) " "RETURN DISTINCT r", resolve_objects=True, @@ -462,13 +458,13 @@ def test_resolve_inexistent_relationship(): """ # Forget about the FRIENDS_WITH Relationship. - del neomodel.db._NODE_CLASS_REGISTRY[frozenset(["FRIENDS_WITH"])] + del neomodel.adb._NODE_CLASS_REGISTRY[frozenset(["FRIENDS_WITH"])] with pytest.raises( neomodel.RelationshipClassNotDefined, match=r"Relationship of type .* does not resolve to any of the known objects.*", ): - query_data = neomodel.db.cypher_query( + query_data = neomodel.adb.cypher_query( "MATCH (:ExtendedSomePerson)-[r:FRIENDS_WITH]->(:ExtendedSomePerson) " "RETURN DISTINCT r", resolve_objects=True, diff --git a/test/test_issue600.py b/test/test_issue600.py index 6851efd2..5d760bf0 100644 --- a/test/test_issue600.py +++ b/test/test_issue600.py @@ -30,7 +30,7 @@ class SubClass2(Class1): pass -class RelationshipDefinerSecondSibling(neomodel.StructuredNode): +class RelationshipDefinerSecondSibling(neomodel.StructuredNodeAsync): rel_1 = neomodel.Relationship( "RelationshipDefinerSecondSibling", "SOME_REL_LABEL", model=Class1 ) @@ -42,7 +42,7 @@ class RelationshipDefinerSecondSibling(neomodel.StructuredNode): ) -class RelationshipDefinerParentLast(neomodel.StructuredNode): +class RelationshipDefinerParentLast(neomodel.StructuredNodeAsync): rel_2 = neomodel.Relationship( "RelationshipDefinerParentLast", "SOME_REL_LABEL", model=SubClass1 ) @@ -57,9 +57,9 @@ class RelationshipDefinerParentLast(neomodel.StructuredNode): # Test cases def test_relationship_definer_second_sibling(): # Create a few entities - A = RelationshipDefinerSecondSibling.get_or_create({})[0] - B = RelationshipDefinerSecondSibling.get_or_create({})[0] - C = RelationshipDefinerSecondSibling.get_or_create({})[0] + A = RelationshipDefinerSecondSibling.get_or_create_async({})[0] + B = RelationshipDefinerSecondSibling.get_or_create_async({})[0] + C = RelationshipDefinerSecondSibling.get_or_create_async({})[0] # Add connections A.rel_1.connect(B) @@ -67,16 +67,16 @@ def test_relationship_definer_second_sibling(): C.rel_3.connect(A) # Clean up - A.delete() - B.delete() - C.delete() + A.delete_async() + B.delete_async() + C.delete_async() def test_relationship_definer_parent_last(): # Create a few entities - A = RelationshipDefinerParentLast.get_or_create({})[0] - B = RelationshipDefinerParentLast.get_or_create({})[0] - C = RelationshipDefinerParentLast.get_or_create({})[0] + A = RelationshipDefinerParentLast.get_or_create_async({})[0] + B = RelationshipDefinerParentLast.get_or_create_async({})[0] + C = RelationshipDefinerParentLast.get_or_create_async({})[0] # Add connections A.rel_1.connect(B) @@ -84,6 +84,6 @@ def test_relationship_definer_parent_last(): C.rel_3.connect(A) # Clean up - A.delete() - B.delete() - C.delete() + A.delete_async() + B.delete_async() + C.delete_async() diff --git a/test/test_label_drop.py b/test/test_label_drop.py index 389d19e0..5d3dc13a 100644 --- a/test/test_label_drop.py +++ b/test/test_label_drop.py @@ -1,27 +1,27 @@ from neo4j.exceptions import ClientError -from neomodel import StringProperty, StructuredNode, config -from neomodel.core import db, remove_all_labels +from neomodel import StringProperty, StructuredNodeAsync, config +from neomodel._async.core import adb config.AUTO_INSTALL_LABELS = True -class ConstraintAndIndex(StructuredNode): +class ConstraintAndIndex(StructuredNodeAsync): name = StringProperty(unique_index=True) last_name = StringProperty(index=True) def test_drop_labels(): - constraints_before = db.list_constraints() - indexes_before = db.list_indexes(exclude_token_lookup=True) + constraints_before = adb.list_constraints_async() + indexes_before = adb.list_indexes_async(exclude_token_lookup=True) assert len(constraints_before) > 0 assert len(indexes_before) > 0 - remove_all_labels() + adb.remove_all_labels_async() - constraints = db.list_constraints() - indexes = db.list_indexes(exclude_token_lookup=True) + constraints = adb.list_constraints_async() + indexes = adb.list_indexes_async(exclude_token_lookup=True) assert len(constraints) == 0 assert len(indexes) == 0 @@ -34,12 +34,12 @@ def test_drop_labels(): elif constraint["type"] == "NODE_KEY": constraint_type_clause = "NODE KEY" - db.cypher_query( + adb.cypher_query_async( f'CREATE CONSTRAINT {constraint["name"]} FOR (n:{constraint["labelsOrTypes"][0]}) REQUIRE n.{constraint["properties"][0]} IS {constraint_type_clause}' ) for index in indexes_before: try: - db.cypher_query( + adb.cypher_query_async( f'CREATE INDEX {index["name"]} FOR (n:{index["labelsOrTypes"][0]}) ON (n.{index["properties"][0]})' ) except ClientError: diff --git a/test/test_label_install.py b/test/test_label_install.py index 46f55467..7e1c3dc6 100644 --- a/test/test_label_install.py +++ b/test/test_label_install.py @@ -3,28 +3,26 @@ from neomodel import ( RelationshipTo, StringProperty, - StructuredNode, + StructuredNodeAsync, StructuredRel, UniqueIdProperty, config, - install_all_labels, - install_labels, ) -from neomodel.core import db, drop_constraints +from neomodel._async.core import adb from neomodel.exceptions import ConstraintValidationFailed, FeatureNotSupported config.AUTO_INSTALL_LABELS = False -class NodeWithIndex(StructuredNode): +class NodeWithIndex(StructuredNodeAsync): name = StringProperty(index=True) -class NodeWithConstraint(StructuredNode): +class NodeWithConstraint(StructuredNodeAsync): name = StringProperty(unique_index=True) -class NodeWithRelationship(StructuredNode): +class NodeWithRelationship(StructuredNodeAsync): ... @@ -32,18 +30,18 @@ class IndexedRelationship(StructuredRel): indexed_rel_prop = StringProperty(index=True) -class OtherNodeWithRelationship(StructuredNode): +class OtherNodeWithRelationship(StructuredNodeAsync): has_rel = RelationshipTo( NodeWithRelationship, "INDEXED_REL", model=IndexedRelationship ) -class AbstractNode(StructuredNode): +class AbstractNode(StructuredNodeAsync): __abstract_node__ = True name = StringProperty(unique_index=True) -class SomeNotUniqueNode(StructuredNode): +class SomeNotUniqueNode(StructuredNodeAsync): id_ = UniqueIdProperty(db_property="id") @@ -51,9 +49,9 @@ class SomeNotUniqueNode(StructuredNode): def test_labels_were_not_installed(): - bob = NodeWithConstraint(name="bob").save() - bob2 = NodeWithConstraint(name="bob").save() - bob3 = NodeWithConstraint(name="bob").save() + bob = NodeWithConstraint(name="bob").save_async() + bob2 = NodeWithConstraint(name="bob").save_async() + bob3 = NodeWithConstraint(name="bob").save_async() assert bob.element_id != bob3.element_id for n in NodeWithConstraint.nodes.all(): @@ -61,16 +59,16 @@ def test_labels_were_not_installed(): def test_install_all(): - drop_constraints() - install_labels(AbstractNode) + adb.drop_constraints_async() + adb.install_labels_async(AbstractNode) # run install all labels - install_all_labels() + adb.install_all_labels_async() - indexes = db.list_indexes() + indexes = adb.list_indexes_async() index_names = [index["name"] for index in indexes] assert "index_INDEXED_REL_indexed_rel_prop" in index_names - constraints = db.list_constraints() + constraints = adb.list_constraints_async() constraint_names = [constraint["name"] for constraint in constraints] assert "constraint_unique_NodeWithConstraint_name" in constraint_names assert "constraint_unique_SomeNotUniqueNode_id" in constraint_names @@ -83,43 +81,43 @@ def test_install_label_twice(capsys): expected_std_out = ( "{code: Neo.ClientError.Schema.EquivalentSchemaRuleAlreadyExists}" ) - install_labels(AbstractNode) - install_labels(AbstractNode) + adb.install_labels_async(AbstractNode) + adb.install_labels_async(AbstractNode) - install_labels(NodeWithIndex) - install_labels(NodeWithIndex, quiet=False) + adb.install_labels_async(NodeWithIndex) + adb.install_labels_async(NodeWithIndex, quiet=False) captured = capsys.readouterr() assert expected_std_out in captured.out - install_labels(NodeWithConstraint) - install_labels(NodeWithConstraint, quiet=False) + adb.install_labels_async(NodeWithConstraint) + adb.install_labels_async(NodeWithConstraint, quiet=False) captured = capsys.readouterr() assert expected_std_out in captured.out - install_labels(OtherNodeWithRelationship) - install_labels(OtherNodeWithRelationship, quiet=False) + adb.install_labels_async(OtherNodeWithRelationship) + adb.install_labels_async(OtherNodeWithRelationship, quiet=False) captured = capsys.readouterr() assert expected_std_out in captured.out - if db.version_is_higher_than("5.7"): + if adb.version_is_higher_than("5.7"): class UniqueIndexRelationship(StructuredRel): unique_index_rel_prop = StringProperty(unique_index=True) - class OtherNodeWithUniqueIndexRelationship(StructuredNode): + class OtherNodeWithUniqueIndexRelationship(StructuredNodeAsync): has_rel = RelationshipTo( NodeWithRelationship, "UNIQUE_INDEX_REL", model=UniqueIndexRelationship ) - install_labels(OtherNodeWithUniqueIndexRelationship) - install_labels(OtherNodeWithUniqueIndexRelationship, quiet=False) + adb.install_labels_async(OtherNodeWithUniqueIndexRelationship) + adb.install_labels_async(OtherNodeWithUniqueIndexRelationship, quiet=False) captured = capsys.readouterr() assert expected_std_out in captured.out def test_install_labels_db_property(capsys): - drop_constraints() - install_labels(SomeNotUniqueNode, quiet=False) + adb.drop_constraints_async() + adb.install_labels_async(SomeNotUniqueNode, quiet=False) captured = capsys.readouterr() assert "id" in captured.out # make sure that the id_ constraint doesn't exist @@ -131,19 +129,21 @@ def test_install_labels_db_property(capsys): _drop_constraints_for_label_and_property("SomeNotUniqueNode", "id") -@pytest.mark.skipif(db.version_is_higher_than("5.7"), reason="Not supported before 5.7") +@pytest.mark.skipif( + adb.version_is_higher_than("5.7"), reason="Not supported before 5.7" +) def test_relationship_unique_index_not_supported(): class UniqueIndexRelationship(StructuredRel): name = StringProperty(unique_index=True) - class TargetNodeForUniqueIndexRelationship(StructuredNode): + class TargetNodeForUniqueIndexRelationship(StructuredNodeAsync): pass with pytest.raises( FeatureNotSupported, match=r".*Please upgrade to Neo4j 5.7 or higher" ): - class NodeWithUniqueIndexRelationship(StructuredNode): + class NodeWithUniqueIndexRelationship(StructuredNodeAsync): has_rel = RelationshipTo( TargetNodeForUniqueIndexRelationship, "UNIQUE_INDEX_REL", @@ -151,25 +151,25 @@ class NodeWithUniqueIndexRelationship(StructuredNode): ) -@pytest.mark.skipif(not db.version_is_higher_than("5.7"), reason="Supported from 5.7") +@pytest.mark.skipif(not adb.version_is_higher_than("5.7"), reason="Supported from 5.7") def test_relationship_unique_index(): class UniqueIndexRelationshipBis(StructuredRel): name = StringProperty(unique_index=True) - class TargetNodeForUniqueIndexRelationship(StructuredNode): + class TargetNodeForUniqueIndexRelationship(StructuredNodeAsync): pass - class NodeWithUniqueIndexRelationship(StructuredNode): + class NodeWithUniqueIndexRelationship(StructuredNodeAsync): has_rel = RelationshipTo( TargetNodeForUniqueIndexRelationship, "UNIQUE_INDEX_REL_BIS", model=UniqueIndexRelationshipBis, ) - install_labels(UniqueIndexRelationshipBis) - node1 = NodeWithUniqueIndexRelationship().save() - node2 = TargetNodeForUniqueIndexRelationship().save() - node3 = TargetNodeForUniqueIndexRelationship().save() + adb.install_labels_async(UniqueIndexRelationshipBis) + node1 = NodeWithUniqueIndexRelationship().save_async() + node2 = TargetNodeForUniqueIndexRelationship().save_async() + node3 = TargetNodeForUniqueIndexRelationship().save_async() rel1 = node1.has_rel.connect(node2, {"name": "rel1"}) with pytest.raises( @@ -180,7 +180,7 @@ class NodeWithUniqueIndexRelationship(StructuredNode): def _drop_constraints_for_label_and_property(label: str = None, property: str = None): - results, meta = db.cypher_query("SHOW CONSTRAINTS") + results, meta = adb.cypher_query_async("SHOW CONSTRAINTS") results_as_dict = [dict(zip(meta, row)) for row in results] constraint_names = [ constraint @@ -188,6 +188,6 @@ def _drop_constraints_for_label_and_property(label: str = None, property: str = if constraint["labelsOrTypes"] == label and constraint["properties"] == property ] for constraint_name in constraint_names: - db.cypher_query(f"DROP CONSTRAINT {constraint_name}") + adb.cypher_query_async(f"DROP CONSTRAINT {constraint_name}") return constraint_names diff --git a/test/test_match_api.py b/test/test_match_api.py index 43c7a104..f5b43c90 100644 --- a/test/test_match_api.py +++ b/test/test_match_api.py @@ -10,7 +10,7 @@ RelationshipFrom, RelationshipTo, StringProperty, - StructuredNode, + StructuredNodeAsync, StructuredRel, ) from neomodel.exceptions import MultipleNodesReturned @@ -22,18 +22,18 @@ class SupplierRel(StructuredRel): courier = StringProperty() -class Supplier(StructuredNode): +class Supplier(StructuredNodeAsync): name = StringProperty() delivery_cost = IntegerProperty() coffees = RelationshipTo("Coffee", "COFFEE SUPPLIERS") -class Species(StructuredNode): +class Species(StructuredNodeAsync): name = StringProperty() coffees = RelationshipFrom("Coffee", "COFFEE SPECIES", model=StructuredRel) -class Coffee(StructuredNode): +class Coffee(StructuredNodeAsync): name = StringProperty(unique_index=True) price = IntegerProperty() suppliers = RelationshipFrom(Supplier, "COFFEE SUPPLIERS", model=SupplierRel) @@ -41,12 +41,12 @@ class Coffee(StructuredNode): id_ = IntegerProperty() -class Extension(StructuredNode): +class Extension(StructuredNodeAsync): extension = RelationshipTo("Extension", "extension") def test_filter_exclude_via_labels(): - Coffee(name="Java", price=99).save() + Coffee(name="Java", price=99).save_async() node_set = NodeSet(Coffee) qb = QueryBuilder(node_set).build_ast() @@ -60,7 +60,7 @@ def test_filter_exclude_via_labels(): assert results[0].name == "Java" # with filter and exclude - Coffee(name="Kenco", price=3).save() + Coffee(name="Kenco", price=3).save_async() node_set = node_set.filter(price__gt=2).exclude(price__gt=6, name="Java") qb = QueryBuilder(node_set).build_ast() @@ -72,8 +72,8 @@ def test_filter_exclude_via_labels(): def test_simple_has_via_label(): - nescafe = Coffee(name="Nescafe", price=99).save() - tesco = Supplier(name="Tesco", delivery_cost=2).save() + nescafe = Coffee(name="Nescafe", price=99).save_async() + tesco = Supplier(name="Tesco", delivery_cost=2).save_async() nescafe.suppliers.connect(tesco) ns = NodeSet(Coffee).has(suppliers=True) @@ -83,7 +83,7 @@ def test_simple_has_via_label(): assert len(results) == 1 assert results[0].name == "Nescafe" - Coffee(name="nespresso", price=99).save() + Coffee(name="nespresso", price=99).save_async() ns = NodeSet(Coffee).has(suppliers=False) qb = QueryBuilder(ns).build_ast() results = qb._execute() @@ -92,21 +92,21 @@ def test_simple_has_via_label(): def test_get(): - Coffee(name="1", price=3).save() + Coffee(name="1", price=3).save_async() assert Coffee.nodes.get(name="1") with raises(Coffee.DoesNotExist): Coffee.nodes.get(name="2") - Coffee(name="2", price=3).save() + Coffee(name="2", price=3).save_async() with raises(MultipleNodesReturned): Coffee.nodes.get(price=3) def test_simple_traverse_with_filter(): - nescafe = Coffee(name="Nescafe2", price=99).save() - tesco = Supplier(name="Sainsburys", delivery_cost=2).save() + nescafe = Coffee(name="Nescafe2", price=99).save_async() + tesco = Supplier(name="Sainsburys", delivery_cost=2).save_async() nescafe.suppliers.connect(tesco) qb = QueryBuilder(NodeSet(source=nescafe).suppliers.match(since__lt=datetime.now())) @@ -121,10 +121,10 @@ def test_simple_traverse_with_filter(): def test_double_traverse(): - nescafe = Coffee(name="Nescafe plus", price=99).save() - tesco = Supplier(name="Asda", delivery_cost=2).save() + nescafe = Coffee(name="Nescafe plus", price=99).save_async() + tesco = Supplier(name="Asda", delivery_cost=2).save_async() nescafe.suppliers.connect(tesco) - tesco.coffees.connect(Coffee(name="Decafe", price=2).save()) + tesco.coffees.connect(Coffee(name="Decafe", price=2).save_async()) ns = NodeSet(NodeSet(source=nescafe).suppliers.match()).coffees.match() qb = QueryBuilder(ns).build_ast() @@ -136,7 +136,7 @@ def test_double_traverse(): def test_count(): - Coffee(name="Nescafe Gold", price=99).save() + Coffee(name="Nescafe Gold", price=99).save_async() count = QueryBuilder(NodeSet(source=Coffee)).build_ast()._count() assert count > 0 @@ -144,7 +144,7 @@ def test_count(): def test_len_and_iter_and_bool(): iterations = 0 - Coffee(name="Icelands finest").save() + Coffee(name="Icelands finest").save_async() for c in Coffee.nodes: iterations += 1 @@ -159,9 +159,9 @@ def test_slice(): for c in Coffee.nodes: c.delete() - Coffee(name="Icelands finest").save() - Coffee(name="Britains finest").save() - Coffee(name="Japans finest").save() + Coffee(name="Icelands finest").save_async() + Coffee(name="Britains finest").save_async() + Coffee(name="Japans finest").save_async() assert len(list(Coffee.nodes.all()[1:])) == 2 assert len(list(Coffee.nodes.all()[:1])) == 1 @@ -173,9 +173,9 @@ def test_slice(): def test_issue_208(): # calls to match persist across queries. - b = Coffee(name="basics").save() - l = Supplier(name="lidl").save() - a = Supplier(name="aldi").save() + b = Coffee(name="basics").save_async() + l = Supplier(name="lidl").save_async() + a = Supplier(name="aldi").save_async() b.suppliers.connect(l, {"courier": "fedex"}) b.suppliers.connect(a, {"courier": "dhl"}) @@ -185,15 +185,15 @@ def test_issue_208(): def test_issue_589(): - node1 = Extension().save() - node2 = Extension().save() + node1 = Extension().save_async() + node2 = Extension().save_async() node1.extension.connect(node2) assert node2 in node1.extension.all() def test_contains(): - expensive = Coffee(price=1000, name="Pricey").save() - asda = Coffee(name="Asda", price=1).save() + expensive = Coffee(price=1000, name="Pricey").save_async() + asda = Coffee(name="Asda", price=1).save_async() assert expensive in Coffee.nodes.filter(price__gt=999) assert asda not in Coffee.nodes.filter(price__gt=999) @@ -211,9 +211,9 @@ def test_order_by(): for c in Coffee.nodes: c.delete() - c1 = Coffee(name="Icelands finest", price=5).save() - c2 = Coffee(name="Britains finest", price=10).save() - c3 = Coffee(name="Japans finest", price=35).save() + c1 = Coffee(name="Icelands finest", price=5).save_async() + c2 = Coffee(name="Britains finest", price=10).save_async() + c3 = Coffee(name="Japans finest", price=35).save_async() assert Coffee.nodes.order_by("price").all()[0].price == 5 assert Coffee.nodes.order_by("-price").all()[0].price == 35 @@ -236,7 +236,7 @@ def test_order_by(): Coffee.nodes.order_by("id") # Test order by on a relationship - l = Supplier(name="lidl2").save() + l = Supplier(name="lidl2").save_async() l.coffees.connect(c1) l.coffees.connect(c2) l.coffees.connect(c3) @@ -251,10 +251,10 @@ def test_extra_filters(): for c in Coffee.nodes: c.delete() - c1 = Coffee(name="Icelands finest", price=5, id_=1).save() - c2 = Coffee(name="Britains finest", price=10, id_=2).save() - c3 = Coffee(name="Japans finest", price=35, id_=3).save() - c4 = Coffee(name="US extra-fine", price=None, id_=4).save() + c1 = Coffee(name="Icelands finest", price=5, id_=1).save_async() + c2 = Coffee(name="Britains finest", price=10, id_=2).save_async() + c3 = Coffee(name="Japans finest", price=35, id_=3).save_async() + c4 = Coffee(name="US extra-fine", price=None, id_=4).save_async() coffees_5_10 = Coffee.nodes.filter(price__in=[10, 5]).all() assert len(coffees_5_10) == 2, "unexpected number of results" @@ -325,8 +325,8 @@ def test_empty_filters(): for c in Coffee.nodes: c.delete() - c1 = Coffee(name="Super", price=5, id_=1).save() - c2 = Coffee(name="Puper", price=10, id_=2).save() + c1 = Coffee(name="Super", price=5, id_=1).save_async() + c2 = Coffee(name="Puper", price=10, id_=2).save_async() empty_filter = Coffee.nodes.filter() @@ -351,12 +351,12 @@ def test_q_filters(): for c in Coffee.nodes: c.delete() - c1 = Coffee(name="Icelands finest", price=5, id_=1).save() - c2 = Coffee(name="Britains finest", price=10, id_=2).save() - c3 = Coffee(name="Japans finest", price=35, id_=3).save() - c4 = Coffee(name="US extra-fine", price=None, id_=4).save() - c5 = Coffee(name="Latte", price=35, id_=5).save() - c6 = Coffee(name="Cappuccino", price=35, id_=6).save() + c1 = Coffee(name="Icelands finest", price=5, id_=1).save_async() + c2 = Coffee(name="Britains finest", price=10, id_=2).save_async() + c3 = Coffee(name="Japans finest", price=35, id_=3).save_async() + c4 = Coffee(name="US extra-fine", price=None, id_=4).save_async() + c5 = Coffee(name="Latte", price=35, id_=5).save_async() + c6 = Coffee(name="Cappuccino", price=35, id_=6).save_async() coffees_5_10 = Coffee.nodes.filter(Q(price=10) | Q(price=5)).all() assert len(coffees_5_10) == 2, "unexpected number of results" @@ -437,12 +437,12 @@ def test_qbase(): def test_traversal_filter_left_hand_statement(): - nescafe = Coffee(name="Nescafe2", price=99).save() - nescafe_gold = Coffee(name="Nescafe gold", price=11).save() + nescafe = Coffee(name="Nescafe2", price=99).save_async() + nescafe_gold = Coffee(name="Nescafe gold", price=11).save_async() - tesco = Supplier(name="Sainsburys", delivery_cost=3).save() - biedronka = Supplier(name="Biedronka", delivery_cost=5).save() - lidl = Supplier(name="Lidl", delivery_cost=3).save() + tesco = Supplier(name="Sainsburys", delivery_cost=3).save_async() + biedronka = Supplier(name="Biedronka", delivery_cost=5).save_async() + lidl = Supplier(name="Lidl", delivery_cost=3).save_async() nescafe.suppliers.connect(tesco) nescafe_gold.suppliers.connect(biedronka) @@ -456,12 +456,12 @@ def test_traversal_filter_left_hand_statement(): def test_fetch_relations(): - arabica = Species(name="Arabica").save() - robusta = Species(name="Robusta").save() - nescafe = Coffee(name="Nescafe 1000", price=99).save() - nescafe_gold = Coffee(name="Nescafe 1001", price=11).save() + arabica = Species(name="Arabica").save_async() + robusta = Species(name="Robusta").save_async() + nescafe = Coffee(name="Nescafe 1000", price=99).save_async() + nescafe_gold = Coffee(name="Nescafe 1001", price=11).save_async() - tesco = Supplier(name="Sainsburys", delivery_cost=3).save() + tesco = Supplier(name="Sainsburys", delivery_cost=3).save_async() nescafe.suppliers.connect(tesco) nescafe_gold.suppliers.connect(tesco) nescafe.species.connect(arabica) diff --git a/test/test_migration_neo4j_5.py b/test/test_migration_neo4j_5.py index a4730329..7f36a619 100644 --- a/test/test_migration_neo4j_5.py +++ b/test/test_migration_neo4j_5.py @@ -4,13 +4,13 @@ IntegerProperty, RelationshipTo, StringProperty, - StructuredNode, + StructuredNodeAsync, StructuredRel, ) -from neomodel.core import db +from neomodel._async.core import adb -class Album(StructuredNode): +class Album(StructuredNodeAsync): name = StringProperty() @@ -18,7 +18,7 @@ class Released(StructuredRel): year = IntegerProperty() -class Band(StructuredNode): +class Band(StructuredNodeAsync): name = StringProperty() released = RelationshipTo(Album, relation_type="RELEASED", model=Released) @@ -35,7 +35,7 @@ def test_read_elements_id(): # Validate id properties # Behaviour is dependent on Neo4j version - if db.database_version.startswith("4"): + if adb.database_version.startswith("4"): # Nodes' ids assert lex_hives.id == int(lex_hives.element_id) assert lex_hives.id == the_hives.released.single().id diff --git a/test/test_models.py b/test/test_models.py index 827c705a..3e804e3f 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -8,15 +8,14 @@ DateProperty, IntegerProperty, StringProperty, - StructuredNode, + StructuredNodeAsync, StructuredRel, - install_labels, ) -from neomodel.core import db +from neomodel._async.core import adb from neomodel.exceptions import RequiredProperty, UniqueProperty -class User(StructuredNode): +class User(StructuredNodeAsync): email = StringProperty(unique_index=True, required=True) age = IntegerProperty(index=True) @@ -29,12 +28,12 @@ def email_alias(self, value): self.email = value -class NodeWithoutProperty(StructuredNode): +class NodeWithoutProperty(StructuredNodeAsync): pass def test_issue_233(): - class BaseIssue233(StructuredNode): + class BaseIssue233(StructuredNodeAsync): __abstract_node__ = True def __getitem__(self, item): @@ -43,7 +42,7 @@ def __getitem__(self, item): class Issue233(BaseIssue233): uid = StringProperty(unique_index=True, required=True) - i = Issue233(uid="testgetitem").save() + i = Issue233(uid="testgetitem").save_async() assert i["uid"] == "testgetitem" @@ -54,7 +53,7 @@ def test_issue_72(): def test_required(): try: - User(age=3).save() + User(age=3).save_async() except RequiredProperty: assert True else: @@ -70,7 +69,7 @@ def test_repr_and_str(): def test_get_and_get_or_none(): u = User(email="robin@test.com", age=3) - assert u.save() + assert u.save_async() rob = User.nodes.get(email="robin@test.com") assert rob.email == "robin@test.com" assert rob.age == 3 @@ -84,9 +83,9 @@ def test_get_and_get_or_none(): def test_first_and_first_or_none(): u = User(email="matt@test.com", age=24) - assert u.save() + assert u.save_async() u2 = User(email="tbrady@test.com", age=40) - assert u2.save() + assert u2.save_async() tbrady = User.nodes.order_by("-age").first() assert tbrady.email == "tbrady@test.com" assert tbrady.age == 40 @@ -103,12 +102,12 @@ def test_bare_init_without_save(): If a node model is initialised without being saved, accessing its `element_id` should return None. """ - assert(User().element_id is None) + assert User().element_id is None def test_save_to_model(): u = User(email="jim@test.com", age=3) - assert u.save() + assert u.save_async() assert u.element_id is not None assert u.email == "jim@test.com" assert u.age == 3 @@ -116,60 +115,60 @@ def test_save_to_model(): def test_save_node_without_properties(): n = NodeWithoutProperty() - assert n.save() + assert n.save_async() assert n.element_id is not None def test_unique(): - install_labels(User) - User(email="jim1@test.com", age=3).save() + adb.install_labels_async(User) + User(email="jim1@test.com", age=3).save_async() with raises(UniqueProperty): - User(email="jim1@test.com", age=3).save() + User(email="jim1@test.com", age=3).save_async() def test_update_unique(): - u = User(email="jimxx@test.com", age=3).save() - u.save() # this shouldn't fail + u = User(email="jimxx@test.com", age=3).save_async() + u.save_async() # this shouldn't fail def test_update(): - user = User(email="jim2@test.com", age=3).save() + user = User(email="jim2@test.com", age=3).save_async() assert user user.email = "jim2000@test.com" - user.save() + user.save_async() jim = User.nodes.get(email="jim2000@test.com") assert jim assert jim.email == "jim2000@test.com" def test_save_through_magic_property(): - user = User(email_alias="blah@test.com", age=8).save() + user = User(email_alias="blah@test.com", age=8).save_async() assert user.email_alias == "blah@test.com" user = User.nodes.get(email="blah@test.com") assert user.email == "blah@test.com" assert user.email_alias == "blah@test.com" - user1 = User(email="blah1@test.com", age=8).save() + user1 = User(email="blah1@test.com", age=8).save_async() assert user1.email_alias == "blah1@test.com" user1.email_alias = "blah2@test.com" - assert user1.save() + assert user1.save_async() user2 = User.nodes.get(email="blah2@test.com") assert user2 -class Customer2(StructuredNode): +class Customer2(StructuredNodeAsync): __label__ = "customers" email = StringProperty(unique_index=True, required=True) age = IntegerProperty(index=True) def test_not_updated_on_unique_error(): - install_labels(Customer2) - Customer2(email="jim@bob.com", age=7).save() - test = Customer2(email="jim1@bob.com", age=2).save() + adb.install_labels_async(Customer2) + Customer2(email="jim@bob.com", age=7).save_async() + test = Customer2(email="jim1@bob.com", age=2).save_async() test.email = "jim@bob.com" with raises(UniqueProperty): - test.save() + test.save_async() customers = Customer2.nodes.all() assert customers[0].email != customers[1].email assert Customer2.nodes.get(email="jim@bob.com").age == 7 @@ -181,18 +180,18 @@ class Customer3(Customer2): address = StringProperty() assert Customer3.__label__ == "Customer3" - c = Customer3(email="test@test.com").save() - assert "customers" in c.labels() - assert "Customer3" in c.labels() + c = Customer3(email="test@test.com").save_async() + assert "customers" in c.labels_async() + assert "Customer3" in c.labels_async() c = Customer2.nodes.get(email="test@test.com") assert isinstance(c, Customer2) - assert "customers" in c.labels() - assert "Customer3" in c.labels() + assert "customers" in c.labels_async() + assert "Customer3" in c.labels_async() def test_refresh(): - c = Customer2(email="my@email.com", age=16).save() + c = Customer2(email="my@email.com", age=16).save_async() c.my_custom_prop = "value" copy = Customer2.nodes.get(email="my@email.com") copy.age = 20 @@ -200,37 +199,37 @@ def test_refresh(): assert c.age == 16 - c.refresh() + c.refresh_async() assert c.age == 20 assert c.my_custom_prop == "value" c = Customer2.inflate(c.element_id) c.age = 30 - c.refresh() + c.refresh_async() assert c.age == 20 - if db.database_version.startswith("4"): + if adb.database_version.startswith("4"): c = Customer2.inflate(999) else: c = Customer2.inflate("4:xxxxxx:999") with raises(Customer2.DoesNotExist): - c.refresh() + c.refresh_async() def test_setting_value_to_none(): - c = Customer2(email="alice@bob.com", age=42).save() + c = Customer2(email="alice@bob.com", age=42).save_async() assert c.age is not None c.age = None - c.save() + c.save_async() copy = Customer2.nodes.get(email="alice@bob.com") assert copy.age is None def test_inheritance(): - class User(StructuredNode): + class User(StructuredNodeAsync): __abstract_node__ = True name = StringProperty(unique_index=True) @@ -239,20 +238,20 @@ class Shopper(User): def credit_account(self, amount): self.balance = self.balance + int(amount) - self.save() + self.save_async() - jim = Shopper(name="jimmy", balance=300).save() + jim = Shopper(name="jimmy", balance=300).save_async() jim.credit_account(50) assert Shopper.__label__ == "Shopper" assert jim.balance == 350 assert len(jim.inherited_labels()) == 1 - assert len(jim.labels()) == 1 - assert jim.labels()[0] == "Shopper" + assert len(jim.labels_async()) == 1 + assert jim.labels_async()[0] == "Shopper" def test_inherited_optional_labels(): - class BaseOptional(StructuredNode): + class BaseOptional(StructuredNodeAsync): __optional_labels__ = ["Alive"] name = StringProperty(unique_index=True) @@ -262,15 +261,15 @@ class ExtendedOptional(BaseOptional): def credit_account(self, amount): self.balance = self.balance + int(amount) - self.save() + self.save_async() - henry = ExtendedOptional(name="henry", balance=300).save() + henry = ExtendedOptional(name="henry", balance=300).save_async() henry.credit_account(50) assert ExtendedOptional.__label__ == "ExtendedOptional" assert henry.balance == 350 assert len(henry.inherited_labels()) == 2 - assert len(henry.labels()) == 2 + assert len(henry.labels_async()) == 2 assert set(henry.inherited_optional_labels()) == {"Alive", "RewardsMember"} @@ -287,41 +286,41 @@ def credit_account(self, amount): self.balance = self.balance + int(amount) self.save() - class Shopper2(StructuredNode, UserMixin, CreditMixin): + class Shopper2(StructuredNodeAsync, UserMixin, CreditMixin): pass - jim = Shopper2(name="jimmy", balance=300).save() + jim = Shopper2(name="jimmy", balance=300).save_async() jim.credit_account(50) assert Shopper2.__label__ == "Shopper2" assert jim.balance == 350 assert len(jim.inherited_labels()) == 1 - assert len(jim.labels()) == 1 - assert jim.labels()[0] == "Shopper2" + assert len(jim.labels_async()) == 1 + assert jim.labels_async()[0] == "Shopper2" def test_date_property(): - class DateTest(StructuredNode): + class DateTest(StructuredNodeAsync): birthdate = DateProperty() - user = DateTest(birthdate=datetime.now()).save() + user = DateTest(birthdate=datetime.now()).save_async() def test_reserved_property_keys(): error_match = r".*is not allowed as it conflicts with neomodel internals.*" with raises(ValueError, match=error_match): - class ReservedPropertiesDeletedNode(StructuredNode): + class ReservedPropertiesDeletedNode(StructuredNodeAsync): deleted = StringProperty() with raises(ValueError, match=error_match): - class ReservedPropertiesIdNode(StructuredNode): + class ReservedPropertiesIdNode(StructuredNodeAsync): id = StringProperty() with raises(ValueError, match=error_match): - class ReservedPropertiesElementIdNode(StructuredNode): + class ReservedPropertiesElementIdNode(StructuredNodeAsync): element_id = StringProperty() with raises(ValueError, match=error_match): diff --git a/test/test_multiprocessing.py b/test/test_multiprocessing.py index fb00675d..6d48de70 100644 --- a/test/test_multiprocessing.py +++ b/test/test_multiprocessing.py @@ -1,15 +1,15 @@ from multiprocessing.pool import ThreadPool as Pool -from neomodel import StringProperty, StructuredNode, db +from neomodel import StringProperty, StructuredNodeAsync, adb -class ThingyMaBob(StructuredNode): +class ThingyMaBob(StructuredNodeAsync): name = StringProperty(unique_index=True, required=True) def thing_create(name): name = str(name) - (thing,) = ThingyMaBob.get_or_create({"name": name}) + (thing,) = ThingyMaBob.get_or_create_async({"name": name}) return thing.name, name @@ -18,4 +18,4 @@ def test_concurrency(): results = p.map(thing_create, range(50)) for returned, sent in results: assert returned == sent - db.close_connection() + adb.close_connection() diff --git a/test/test_paths.py b/test/test_paths.py index 8c6fef28..6e20d949 100644 --- a/test/test_paths.py +++ b/test/test_paths.py @@ -1,60 +1,75 @@ -from neomodel import (StringProperty, StructuredNode, UniqueIdProperty, - db, RelationshipTo, IntegerProperty, NeomodelPath, StructuredRel) +from neomodel import ( + IntegerProperty, + NeomodelPath, + RelationshipTo, + StringProperty, + StructuredNodeAsync, + StructuredRel, + UniqueIdProperty, + adb, +) + class PersonLivesInCity(StructuredRel): """ Relationship with data that will be instantiated as "stand-alone" """ + some_num = IntegerProperty(index=True, default=12) -class CountryOfOrigin(StructuredNode): + +class CountryOfOrigin(StructuredNodeAsync): code = StringProperty(unique_index=True, required=True) -class CityOfResidence(StructuredNode): + +class CityOfResidence(StructuredNodeAsync): name = StringProperty(required=True) - country = RelationshipTo(CountryOfOrigin, 'FROM_COUNTRY') + country = RelationshipTo(CountryOfOrigin, "FROM_COUNTRY") -class PersonOfInterest(StructuredNode): + +class PersonOfInterest(StructuredNodeAsync): uid = UniqueIdProperty() name = StringProperty(unique_index=True) age = IntegerProperty(index=True, default=0) - country = RelationshipTo(CountryOfOrigin, 'IS_FROM') - city = RelationshipTo(CityOfResidence, 'LIVES_IN', model=PersonLivesInCity) + country = RelationshipTo(CountryOfOrigin, "IS_FROM") + city = RelationshipTo(CityOfResidence, "LIVES_IN", model=PersonLivesInCity) def test_path_instantiation(): """ - Neo4j driver paths should be instantiated as neomodel paths, with all of - their nodes and relationships resolved to their Python objects wherever + Neo4j driver paths should be instantiated as neomodel paths, with all of + their nodes and relationships resolved to their Python objects wherever such a mapping is available. """ - c1=CountryOfOrigin(code="GR").save() - c2=CountryOfOrigin(code="FR").save() - - ct1 = CityOfResidence(name="Athens", country = c1).save() - ct2 = CityOfResidence(name="Paris", country = c2).save() + c1 = CountryOfOrigin(code="GR").save_async() + c2 = CountryOfOrigin(code="FR").save_async() + ct1 = CityOfResidence(name="Athens", country=c1).save_async() + ct2 = CityOfResidence(name="Paris", country=c2).save_async() - p1 = PersonOfInterest(name="Bill", age=22).save() + p1 = PersonOfInterest(name="Bill", age=22).save_async() p1.country.connect(c1) p1.city.connect(ct1) - p2 = PersonOfInterest(name="Jean", age=28).save() + p2 = PersonOfInterest(name="Jean", age=28).save_async() p2.country.connect(c2) p2.city.connect(ct2) - p3 = PersonOfInterest(name="Bo", age=32).save() + p3 = PersonOfInterest(name="Bo", age=32).save_async() p3.country.connect(c1) p3.city.connect(ct2) - p4 = PersonOfInterest(name="Drop", age=16).save() + p4 = PersonOfInterest(name="Drop", age=16).save_async() p4.country.connect(c1) p4.city.connect(ct2) # Retrieve a single path - q = db.cypher_query("MATCH p=(:CityOfResidence)<-[:LIVES_IN]-(:PersonOfInterest)-[:IS_FROM]->(:CountryOfOrigin) RETURN p LIMIT 1", resolve_objects = True) + q = adb.cypher_query( + "MATCH p=(:CityOfResidence)<-[:LIVES_IN]-(:PersonOfInterest)-[:IS_FROM]->(:CountryOfOrigin) RETURN p LIMIT 1", + resolve_objects=True, + ) path_object = q[0][0][0] path_nodes = path_object.nodes @@ -68,12 +83,11 @@ def test_path_instantiation(): assert type(path_rels[0]) is PersonLivesInCity assert type(path_rels[1]) is StructuredRel - c1.delete() - c2.delete() - ct1.delete() - ct2.delete() - p1.delete() - p2.delete() - p3.delete() - p4.delete() - + c1.delete_async() + c2.delete_async() + ct1.delete_async() + ct2.delete_async() + p1.delete_async() + p2.delete_async() + p3.delete_async() + p4.delete_async() diff --git a/test/test_properties.py b/test/test_properties.py index 454ada26..e594cb60 100644 --- a/test/test_properties.py +++ b/test/test_properties.py @@ -3,7 +3,7 @@ from pytest import mark, raises from pytz import timezone -from neomodel import StructuredNode, config, db +from neomodel import StructuredNodeAsync, adb, config from neomodel.exceptions import ( DeflateError, InflateError, @@ -61,18 +61,18 @@ def test_string_property_exceeds_max_length(): def test_string_property_w_choice(): - class TestChoices(StructuredNode): + class TestChoices(StructuredNodeAsync): SEXES = {"F": "Female", "M": "Male", "O": "Other"} sex = StringProperty(required=True, choices=SEXES) try: - TestChoices(sex="Z").save() + TestChoices(sex="Z").save_async() except DeflateError as e: assert "choice" in str(e) else: assert False, "DeflateError not raised." - node = TestChoices(sex="M").save() + node = TestChoices(sex="M").save_async() assert node.get_sex_display() == "Male" @@ -186,22 +186,22 @@ def test_json(): def test_default_value(): - class DefaultTestValue(StructuredNode): + class DefaultTestValue(StructuredNodeAsync): name_xx = StringProperty(default="jim", index=True) a = DefaultTestValue() assert a.name_xx == "jim" - a.save() + a.save_async() def test_default_value_callable(): def uid_generator(): return "xx" - class DefaultTestValueTwo(StructuredNode): + class DefaultTestValueTwo(StructuredNodeAsync): uid = StringProperty(default=uid_generator, index=True) - a = DefaultTestValueTwo().save() + a = DefaultTestValueTwo().save_async() assert a.uid == "xx" @@ -214,27 +214,27 @@ def __str__(self): return Foo() - class DefaultTestValueThree(StructuredNode): + class DefaultTestValueThree(StructuredNodeAsync): uid = StringProperty(default=factory, index=True) x = DefaultTestValueThree() assert x.uid == "123" - x.save() + x.save_async() assert x.uid == "123" - x.refresh() + x.refresh_async() assert x.uid == "123" def test_independent_property_name(): - class TestDBNamePropertyNode(StructuredNode): + class TestDBNamePropertyNode(StructuredNodeAsync): name_ = StringProperty(db_property="name") x = TestDBNamePropertyNode() x.name_ = "jim" - x.save() + x.save_async() # check database property name on low level - results, meta = db.cypher_query("MATCH (n:TestDBNamePropertyNode) RETURN n") + results, meta = adb.cypher_query("MATCH (n:TestDBNamePropertyNode) RETURN n") node_properties = _get_node_properties(results[0][0]) assert node_properties["name"] == "jim" @@ -245,27 +245,27 @@ class TestDBNamePropertyNode(StructuredNode): assert TestDBNamePropertyNode.nodes.filter(name_="jim").all()[0].name_ == x.name_ assert TestDBNamePropertyNode.nodes.get(name_="jim").name_ == x.name_ - x.delete() + x.delete_async() def test_independent_property_name_get_or_create(): - class TestNode(StructuredNode): + class TestNode(StructuredNodeAsync): uid = UniqueIdProperty() name_ = StringProperty(db_property="name", required=True) # create the node - TestNode.get_or_create({"uid": 123, "name_": "jim"}) + TestNode.get_or_create_async({"uid": 123, "name_": "jim"}) # test that the node is retrieved correctly - x = TestNode.get_or_create({"uid": 123, "name_": "jim"})[0] + x = TestNode.get_or_create_async({"uid": 123, "name_": "jim"})[0] # check database property name on low level - results, meta = db.cypher_query("MATCH (n:TestNode) RETURN n") + results, meta = adb.cypher_query("MATCH (n:TestNode) RETURN n") node_properties = _get_node_properties(results[0][0]) assert node_properties["name"] == "jim" assert "name_" not in node_properties # delete node afterwards - x.delete() + x.delete_async() @mark.parametrize("normalized_class", (NormalizedProperty,)) @@ -338,14 +338,14 @@ def test_uid_property(): myuid = prop.default_value() assert len(myuid) - class CheckMyId(StructuredNode): + class CheckMyId(StructuredNodeAsync): uid = UniqueIdProperty() - cmid = CheckMyId().save() + cmid = CheckMyId().save_async() assert len(cmid.uid) -class ArrayProps(StructuredNode): +class ArrayProps(StructuredNodeAsync): uid = StringProperty(unique_index=True) untyped_arr = ArrayProperty() typed_arr = ArrayProperty(IntegerProperty()) @@ -353,20 +353,20 @@ class ArrayProps(StructuredNode): def test_array_properties(): # untyped - ap1 = ArrayProps(uid="1", untyped_arr=["Tim", "Bob"]).save() + ap1 = ArrayProps(uid="1", untyped_arr=["Tim", "Bob"]).save_async() assert "Tim" in ap1.untyped_arr ap1 = ArrayProps.nodes.get(uid="1") assert "Tim" in ap1.untyped_arr # typed try: - ArrayProps(uid="2", typed_arr=["a", "b"]).save() + ArrayProps(uid="2", typed_arr=["a", "b"]).save_async() except DeflateError as e: assert "unsaved node" in str(e) else: assert False, "DeflateError not raised." - ap2 = ArrayProps(uid="2", typed_arr=[1, 2]).save() + ap2 = ArrayProps(uid="2", typed_arr=[1, 2]).save_async() assert 1 in ap2.typed_arr ap2 = ArrayProps.nodes.get(uid="2") assert 2 in ap2.typed_arr @@ -378,16 +378,16 @@ def test_illegal_array_base_prop_raises(): def test_indexed_array(): - class IndexArray(StructuredNode): + class IndexArray(StructuredNodeAsync): ai = ArrayProperty(unique_index=True) - b = IndexArray(ai=[1, 2]).save() + b = IndexArray(ai=[1, 2]).save_async() c = IndexArray.nodes.get(ai=[1, 2]) assert b.element_id == c.element_id def test_unique_index_prop_not_required(): - class ConstrainedTestNode(StructuredNode): + class ConstrainedTestNode(StructuredNodeAsync): required_property = StringProperty(required=True) unique_property = StringProperty(unique_index=True) unique_required_property = StringProperty(unique_index=True, required=True) @@ -396,46 +396,46 @@ class ConstrainedTestNode(StructuredNode): # Create a node with a missing required property with raises(RequiredProperty): x = ConstrainedTestNode(required_property="required", unique_property="unique") - x.save() + x.save_async() # Create a node with a missing unique (but not required) property. x = ConstrainedTestNode() x.required_property = "required" x.unique_required_property = "unique and required" x.unconstrained_property = "no contraints" - x.save() + x.save_async() # check database property name on low level - results, meta = db.cypher_query("MATCH (n:ConstrainedTestNode) RETURN n") + results, meta = adb.cypher_query("MATCH (n:ConstrainedTestNode) RETURN n") node_properties = _get_node_properties(results[0][0]) assert node_properties["unique_required_property"] == "unique and required" # delete node afterwards - x.delete() + x.delete_async() def test_unique_index_prop_enforced(): - class UniqueNullableNameNode(StructuredNode): + class UniqueNullableNameNode(StructuredNodeAsync): name = StringProperty(unique_index=True) # Nameless x = UniqueNullableNameNode() - x.save() + x.save_async() y = UniqueNullableNameNode() - y.save() + y.save_async() # Named z = UniqueNullableNameNode(name="named") - z.save() + z.save_async() with raises(UniqueProperty): a = UniqueNullableNameNode(name="named") - a.save() + a.save_async() # Check nodes are in database - results, meta = db.cypher_query("MATCH (n:UniqueNullableNameNode) RETURN n") + results, meta = adb.cypher_query("MATCH (n:UniqueNullableNameNode) RETURN n") assert len(results) == 3 # Delete nodes afterwards - x.delete() - y.delete() - z.delete() + x.delete_async() + y.delete_async() + z.delete_async() diff --git a/test/test_relationship_models.py b/test/test_relationship_models.py index 82760c73..59669ae9 100644 --- a/test/test_relationship_models.py +++ b/test/test_relationship_models.py @@ -9,7 +9,7 @@ Relationship, RelationshipTo, StringProperty, - StructuredNode, + StructuredNodeAsync, StructuredRel, ) @@ -30,20 +30,20 @@ def post_save(self): HOOKS_CALLED["post_save"] += 1 -class Badger(StructuredNode): +class Badger(StructuredNodeAsync): name = StringProperty(unique_index=True) friend = Relationship("Badger", "FRIEND", model=FriendRel) hates = RelationshipTo("Stoat", "HATES", model=HatesRel) -class Stoat(StructuredNode): +class Stoat(StructuredNodeAsync): name = StringProperty(unique_index=True) hates = RelationshipTo("Badger", "HATES", model=HatesRel) def test_either_connect_with_rel_model(): - paul = Badger(name="Paul").save() - tom = Badger(name="Tom").save() + paul = Badger(name="Paul").save_async() + tom = Badger(name="Tom").save_async() # creating rels new_rel = tom.friend.disconnect(paul) @@ -64,8 +64,8 @@ def test_either_connect_with_rel_model(): def test_direction_connect_with_rel_model(): - paul = Badger(name="Paul the badger").save() - ian = Stoat(name="Ian the stoat").save() + paul = Badger(name="Paul the badger").save_async() + ian = Stoat(name="Ian the stoat").save_async() rel = ian.hates.connect(paul, {"reason": "thinks paul should bath more often"}) assert isinstance(rel.since, datetime) @@ -104,9 +104,9 @@ def test_direction_connect_with_rel_model(): def test_traversal_where_clause(): - phill = Badger(name="Phill the badger").save() - tim = Badger(name="Tim the badger").save() - bob = Badger(name="Bob the badger").save() + phill = Badger(name="Phill the badger").save_async() + tim = Badger(name="Tim the badger").save_async() + bob = Badger(name="Bob the badger").save_async() rel = tim.friend.connect(bob) now = datetime.now(pytz.utc) assert rel.since < now @@ -118,8 +118,8 @@ def test_traversal_where_clause(): def test_multiple_rels_exist_issue_223(): # check a badger can dislike a stoat for multiple reasons - phill = Badger(name="Phill").save() - ian = Stoat(name="Stoat").save() + phill = Badger(name="Phill").save_async() + ian = Stoat(name="Stoat").save_async() rel_a = phill.hates.connect(ian, {"reason": "a"}) rel_b = phill.hates.connect(ian, {"reason": "b"}) @@ -131,8 +131,8 @@ def test_multiple_rels_exist_issue_223(): def test_retrieve_all_rels(): - tom = Badger(name="tom").save() - ian = Stoat(name="ian").save() + tom = Badger(name="tom").save_async() + ian = Stoat(name="ian").save_async() rel_a = tom.hates.connect(ian, {"reason": "a"}) rel_b = tom.hates.connect(ian, {"reason": "b"}) @@ -147,8 +147,8 @@ def test_save_hook_on_rel_model(): HOOKS_CALLED["pre_save"] = 0 HOOKS_CALLED["post_save"] = 0 - paul = Badger(name="PaulB").save() - ian = Stoat(name="IanS").save() + paul = Badger(name="PaulB").save_async() + ian = Stoat(name="IanS").save_async() rel = ian.hates.connect(paul, {"reason": "yadda yadda"}) rel.save() diff --git a/test/test_relationships.py b/test/test_relationships.py index 92d75064..75c98c90 100644 --- a/test/test_relationships.py +++ b/test/test_relationships.py @@ -8,13 +8,13 @@ RelationshipFrom, RelationshipTo, StringProperty, - StructuredNode, + StructuredNodeAsync, StructuredRel, ) -from neomodel.core import db +from neomodel._async.core import adb -class PersonWithRels(StructuredNode): +class PersonWithRels(StructuredNodeAsync): name = StringProperty(unique_index=True) age = IntegerProperty(index=True) is_from = RelationshipTo("Country", "IS_FROM") @@ -28,7 +28,7 @@ def special_power(self): return "I have no powers" -class Country(StructuredNode): +class Country(StructuredNodeAsync): code = StringProperty(unique_index=True) inhabitant = RelationshipFrom(PersonWithRels, "IS_FROM") president = RelationshipTo(PersonWithRels, "PRESIDENT", cardinality=One) @@ -42,8 +42,8 @@ def special_power(self): def test_actions_on_deleted_node(): - u = PersonWithRels(name="Jim2", age=3).save() - u.delete() + u = PersonWithRels(name="Jim2", age=3).save_async() + u.delete_async() with raises(ValueError): u.is_from.connect(None) @@ -51,14 +51,14 @@ def test_actions_on_deleted_node(): u.is_from.get() with raises(ValueError): - u.save() + u.save_async() def test_bidirectional_relationships(): - u = PersonWithRels(name="Jim", age=3).save() + u = PersonWithRels(name="Jim", age=3).save_async() assert u - de = Country(code="DE").save() + de = Country(code="DE").save_async() assert de assert not u.is_from @@ -82,17 +82,17 @@ def test_bidirectional_relationships(): def test_either_direction_connect(): - rey = PersonWithRels(name="Rey", age=3).save() - sakis = PersonWithRels(name="Sakis", age=3).save() + rey = PersonWithRels(name="Rey", age=3).save_async() + sakis = PersonWithRels(name="Sakis", age=3).save_async() rey.knows.connect(sakis) assert rey.knows.is_connected(sakis) assert sakis.knows.is_connected(rey) sakis.knows.connect(rey) - result, _ = sakis.cypher( + result, _ = sakis.cypher_async( f"""MATCH (us), (them) - WHERE {db.get_id_method()}(us)=$self and {db.get_id_method()}(them)=$them + WHERE {adb.get_id_method()}(us)=$self and {adb.get_id_method()}(them)=$them MATCH (us)-[r:KNOWS]-(them) RETURN COUNT(r)""", {"them": rey.element_id}, ) @@ -106,10 +106,10 @@ def test_either_direction_connect(): def test_search_and_filter_and_exclude(): - fred = PersonWithRels(name="Fred", age=13).save() - zz = Country(code="ZZ").save() - zx = Country(code="ZX").save() - zt = Country(code="ZY").save() + fred = PersonWithRels(name="Fred", age=13).save_async() + zz = Country(code="ZZ").save_async() + zx = Country(code="ZX").save_async() + zt = Country(code="ZY").save_async() fred.is_from.connect(zz) fred.is_from.connect(zx) fred.is_from.connect(zt) @@ -130,21 +130,21 @@ def test_search_and_filter_and_exclude(): def test_custom_methods(): - u = PersonWithRels(name="Joe90", age=13).save() + u = PersonWithRels(name="Joe90", age=13).save_async() assert u.special_power() == "I have no powers" - u = SuperHero(name="Joe91", age=13, power="xxx").save() + u = SuperHero(name="Joe91", age=13, power="xxx").save_async() assert u.special_power() == "I have powers" assert u.special_name == "Joe91" def test_valid_reconnection(): - p = PersonWithRels(name="ElPresidente", age=93).save() + p = PersonWithRels(name="ElPresidente", age=93).save_async() assert p - pp = PersonWithRels(name="TheAdversary", age=33).save() + pp = PersonWithRels(name="TheAdversary", age=33).save_async() assert pp - c = Country(code="CU").save() + c = Country(code="CU").save_async() assert c c.president.connect(p) @@ -160,16 +160,16 @@ def test_valid_reconnection(): def test_valid_replace(): - brady = PersonWithRels(name="Tom Brady", age=40).save() + brady = PersonWithRels(name="Tom Brady", age=40).save_async() assert brady - gronk = PersonWithRels(name="Rob Gronkowski", age=28).save() + gronk = PersonWithRels(name="Rob Gronkowski", age=28).save_async() assert gronk - colbert = PersonWithRels(name="Stephen Colbert", age=53).save() + colbert = PersonWithRels(name="Stephen Colbert", age=53).save_async() assert colbert - hanks = PersonWithRels(name="Tom Hanks", age=61).save() + hanks = PersonWithRels(name="Tom Hanks", age=61).save_async() assert hanks brady.knows.connect(gronk) @@ -186,13 +186,13 @@ def test_valid_replace(): def test_props_relationship(): - u = PersonWithRels(name="Mar", age=20).save() + u = PersonWithRels(name="Mar", age=20).save_async() assert u - c = Country(code="AT").save() + c = Country(code="AT").save_async() assert c - c2 = Country(code="LA").save() + c2 = Country(code="LA").save_async() assert c2 with raises(NotImplementedError): diff --git a/test/test_relative_relationships.py b/test/test_relative_relationships.py index db78d038..81cca8fd 100644 --- a/test/test_relative_relationships.py +++ b/test/test_relative_relationships.py @@ -1,19 +1,18 @@ -from neomodel import RelationshipTo, StringProperty, StructuredNode +from neomodel import RelationshipTo, StringProperty, StructuredNodeAsync +from neomodel.test_relationships import Country -from .test_relationships import Country - -class Cat(StructuredNode): +class Cat(StructuredNodeAsync): name = StringProperty() # Relationship is defined using a relative class path is_from = RelationshipTo(".test_relationships.Country", "IS_FROM") def test_relative_relationship(): - a = Cat(name="snufkin").save() + a = Cat(name="snufkin").save_async() assert a - c = Country(code="MG").save() + c = Country(code="MG").save_async() assert c # connecting an instance of the class defined above diff --git a/test/test_scripts.py b/test/test_scripts.py index 66594489..cd182bb8 100644 --- a/test/test_scripts.py +++ b/test/test_scripts.py @@ -3,21 +3,21 @@ from neomodel import ( RelationshipTo, StringProperty, - StructuredNode, + StructuredNodeAsync, StructuredRel, config, - db, - install_labels, - util, ) +from neomodel._async.core import adb class ScriptsTestRel(StructuredRel): - some_unique_property = StringProperty(unique_index=db.version_is_higher_than("5.7")) + some_unique_property = StringProperty( + unique_index=adb.version_is_higher_than("5.7") + ) some_index_property = StringProperty(index=True) -class ScriptsTestNode(StructuredNode): +class ScriptsTestNode(StructuredNodeAsync): personal_id = StringProperty(unique_index=True) name = StringProperty(index=True) rel = RelationshipTo("ScriptsTestNode", "REL", model=ScriptsTestRel) @@ -34,26 +34,26 @@ def test_neomodel_install_labels(): assert result.returncode == 0 result = subprocess.run( - ["neomodel_install_labels", "test/test_scripts.py", "--db", db.url], + ["neomodel_install_labels", "test/test_scripts.py", "--db", adb.url], capture_output=True, text=True, check=False, ) assert result.returncode == 0 assert "Setting up indexes and constraints" in result.stdout - constraints = db.list_constraints() + constraints = adb.list_constraints_async() parsed_constraints = [ (element["type"], element["labelsOrTypes"], element["properties"]) for element in constraints ] assert ("UNIQUENESS", ["ScriptsTestNode"], ["personal_id"]) in parsed_constraints - if db.version_is_higher_than("5.7"): + if adb.version_is_higher_than("5.7"): assert ( "RELATIONSHIP_UNIQUENESS", ["REL"], ["some_unique_property"], ) in parsed_constraints - indexes = db.list_indexes() + indexes = adb.lise_indexes_async() parsed_indexes = [ (element["labelsOrTypes"], element["properties"]) for element in indexes ] @@ -81,8 +81,8 @@ def test_neomodel_remove_labels(): "Dropping unique constraint and index on label ScriptsTestNode" in result.stdout ) assert result.returncode == 0 - constraints = db.list_constraints() - indexes = db.list_indexes(exclude_token_lookup=True) + constraints = adb.list_constraints_async() + indexes = adb.lise_indexes_async(exclude_token_lookup=True) assert len(constraints) == 0 assert len(indexes) == 0 @@ -98,18 +98,18 @@ def test_neomodel_inspect_database(): assert "usage: neomodel_inspect_database" in result.stdout assert result.returncode == 0 - util.clear_neo4j_database(db) - install_labels(ScriptsTestNode) - install_labels(ScriptsTestRel) + adb.clear_neo4j_database_async() + adb.install_labels_async(ScriptsTestNode) + adb.install_labels_async(ScriptsTestRel) # Create a few nodes and a rel, with indexes and constraints - node1 = ScriptsTestNode(personal_id="1", name="test").save() - node2 = ScriptsTestNode(personal_id="2", name="test").save() + node1 = ScriptsTestNode(personal_id="1", name="test").save_async() + node2 = ScriptsTestNode(personal_id="2", name="test").save_async() node1.rel.connect(node2, {"some_unique_property": "1", "some_index_property": "2"}) # Create a node with all the parsable property types # Also create a node with no properties - db.cypher_query( + adb.cypher_query( """ CREATE (:EveryPropertyTypeNode { string_property: "Hello World", @@ -142,7 +142,7 @@ def test_neomodel_inspect_database(): # Check that all the expected lines are here file_path = ( "test/data/neomodel_inspect_database_output.txt" - if db.version_is_higher_than("5.7") + if adb.version_is_higher_than("5.7") else "test/data/neomodel_inspect_database_output_pre_5_7.txt" ) with open(file_path, "r") as f: diff --git a/test/test_transactions.py b/test/test_transactions.py index 0481e2a7..f7780c6d 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -3,17 +3,11 @@ from neo4j.exceptions import ClientError, TransactionError from pytest import raises -from neomodel import ( - StringProperty, - StructuredNode, - UniqueProperty, - config, - db, - install_labels, -) +from neomodel import StringProperty, StructuredNodeAsync, UniqueProperty +from neomodel._async.core import adb -class APerson(StructuredNode): +class APerson(StructuredNodeAsync): name = StringProperty(unique_index=True) @@ -21,29 +15,29 @@ def test_rollback_and_commit_transaction(): for p in APerson.nodes: p.delete() - APerson(name="Roger").save() + APerson(name="Roger").save_async() - db.begin() - APerson(name="Terry S").save() - db.rollback() + adb.begin() + APerson(name="Terry S").save_async() + adb.rollback() assert len(APerson.nodes) == 1 - db.begin() - APerson(name="Terry S").save() - db.commit() + adb.begin() + APerson(name="Terry S").save_async() + adb.commit() assert len(APerson.nodes) == 2 -@db.transaction +@adb.transaction def in_a_tx(*names): for n in names: - APerson(name=n).save() + APerson(name=n).save_async() def test_transaction_decorator(): - install_labels(APerson) + adb.install_labels_async(APerson) for p in APerson.nodes: p.delete() @@ -59,64 +53,64 @@ def test_transaction_decorator(): def test_transaction_as_a_context(): - with db.transaction: - APerson(name="Tim").save() + with adb.transaction: + APerson(name="Tim").save_async() assert APerson.nodes.filter(name="Tim") with raises(UniqueProperty): - with db.transaction: - APerson(name="Tim").save() + with adb.transaction: + APerson(name="Tim").save_async() def test_query_inside_transaction(): for p in APerson.nodes: p.delete() - with db.transaction: - APerson(name="Alice").save() - APerson(name="Bob").save() + with adb.transaction: + APerson(name="Alice").save_async() + APerson(name="Bob").save_async() assert len([p.name for p in APerson.nodes]) == 2 def test_read_transaction(): - APerson(name="Johnny").save() + APerson(name="Johnny").save_async() - with db.read_transaction: + with adb.read_transaction: people = APerson.nodes.all() assert people with raises(TransactionError): - with db.read_transaction: + with adb.read_transaction: with raises(ClientError) as e: - APerson(name="Gina").save() + APerson(name="Gina").save_async() assert e.value.code == "Neo.ClientError.Statement.AccessMode" def test_write_transaction(): - with db.write_transaction: - APerson(name="Amelia").save() + with adb.write_transaction: + APerson(name="Amelia").save_async() amelia = APerson.nodes.get(name="Amelia") assert amelia def double_transaction(): - db.begin() + adb.begin() with raises(SystemError, match=r"Transaction in progress"): - db.begin() + adb.begin() - db.rollback() + adb.rollback() -@db.transaction.with_bookmark +@adb.transaction.with_bookmark def in_a_tx(*names): for n in names: - APerson(name=n).save() + APerson(name=n).save_async() -def test_bookmark_transaction_decorator(skip_neo4j_before_330): +def test_bookmark_transaction_decorator(): for p in APerson.nodes: p.delete() @@ -132,34 +126,34 @@ def test_bookmark_transaction_decorator(skip_neo4j_before_330): assert "Jane" not in [p.name for p in APerson.nodes] -def test_bookmark_transaction_as_a_context(skip_neo4j_before_330): - with db.transaction as transaction: - APerson(name="Tanya").save() +def test_bookmark_transaction_as_a_context(): + with adb.transaction as transaction: + APerson(name="Tanya").save_async() assert isinstance(transaction.last_bookmark, Bookmarks) assert APerson.nodes.filter(name="Tanya") with raises(UniqueProperty): - with db.transaction as transaction: - APerson(name="Tanya").save() + with adb.transaction as transaction: + APerson(name="Tanya").save_async() assert not hasattr(transaction, "last_bookmark") @pytest.fixture def spy_on_db_begin(monkeypatch): spy_calls = [] - original_begin = db.begin + original_begin = adb.begin def begin_spy(*args, **kwargs): spy_calls.append((args, kwargs)) return original_begin(*args, **kwargs) - monkeypatch.setattr(db, "begin", begin_spy) + monkeypatch.setattr(adb, "begin", begin_spy) return spy_calls -def test_bookmark_passed_in_to_context(skip_neo4j_before_330, spy_on_db_begin): - transaction = db.transaction +def test_bookmark_passed_in_to_context(spy_on_db_begin): + transaction = adb.transaction with transaction: pass @@ -175,13 +169,13 @@ def test_bookmark_passed_in_to_context(skip_neo4j_before_330, spy_on_db_begin): ) -def test_query_inside_bookmark_transaction(skip_neo4j_before_330): +def test_query_inside_bookmark_transaction(): for p in APerson.nodes: p.delete() - with db.transaction as transaction: - APerson(name="Alice").save() - APerson(name="Bob").save() + with adb.transaction as transaction: + APerson(name="Alice").save_async() + APerson(name="Bob").save_async() assert len([p.name for p in APerson.nodes]) == 2 From 6bcb97ab930f51e6164e04b128163becd7c58bcb Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Thu, 23 Nov 2023 10:52:09 +0100 Subject: [PATCH 02/73] Fix session scope fixtures ; and a first test --- neomodel/_async/core.py | 12 ++++++------ test/async_/conftest.py | 17 +++++++++++++++-- test/{ => async_}/test_cypher.py | 0 3 files changed, 21 insertions(+), 8 deletions(-) rename test/{ => async_}/test_cypher.py (100%) diff --git a/neomodel/_async/core.py b/neomodel/_async/core.py index 1df0e771..606125b6 100644 --- a/neomodel/_async/core.py +++ b/neomodel/_async/core.py @@ -64,7 +64,7 @@ async def wrapper(self, *args, **kwargs): elif config.DATABASE_URL: await _db.set_connection_async(url=config.DATABASE_URL) - return func(self, *args, **kwargs) + return await func(self, *args, **kwargs) return wrapper @@ -115,7 +115,7 @@ async def set_connection_async(self, url: str = None, driver: AsyncDriver = None # Getting the information about the database version requires a connection to the database self._database_version = None self._database_edition = None - self._update_database_version_async() + await self._update_database_version_async() def _parse_driver_from_url(self, url: str) -> None: """Parse the driver information from the given URL and initialize the driver. @@ -437,7 +437,7 @@ async def cypher_query_async( ) else: # Otherwise create a new session in a with to dispose of it after it has been run - with await self.driver.session( + async with self.driver.session( database=self._database_name, impersonated_user=self.impersonated_user ) as session: results, meta = await self._run_cypher_query_async( @@ -464,7 +464,7 @@ async def _run_cypher_query_async( # Retrieve the data start = time.time() response: AsyncResult = await session.run(query, params) - results, meta = [list(r.values()) for r in response], response.keys() + results, meta = [list(r.values()) async for r in response], response.keys() end = time.time() if resolve_objects: @@ -1303,7 +1303,7 @@ async def create_or_update_async(cls, *props, **kwargs): # fetch and build instance for each result results = await adb.cypher_query_async(query, params) - return [cls.inflate(r[0]) for r in results[0]] + return [cls.inflate(r[0]) async for r in results[0]] async def cypher_async(self, query, params=None): """ @@ -1372,7 +1372,7 @@ async def get_or_create_async(cls, *props, **kwargs): # fetch and build instance for each result results = await adb.cypher_query_async(query, params) - return [cls.inflate(r[0]) for r in results[0]] + return [cls.inflate(r[0]) async for r in results[0]] @classmethod def inflate(cls, node): diff --git a/test/async_/conftest.py b/test/async_/conftest.py index 587559fa..28db8673 100644 --- a/test/async_/conftest.py +++ b/test/async_/conftest.py @@ -1,6 +1,8 @@ +import asyncio import os import warnings +import pytest import pytest_asyncio from neomodel import config @@ -8,6 +10,7 @@ @pytest_asyncio.fixture(scope="session", autouse=True) +@pytest.mark.asyncio async def setup_neo4j_session(request): """ Provides initial connection to the database and sets up the rest of the test suite @@ -31,8 +34,8 @@ async def setup_neo4j_session(request): raise SystemError( "Please note: The database seems to be populated.\n\tEither delete all nodes and edges manually, or set the --resetdb parameter when calling pytest\n\n\tpytest --resetdb." ) - else: - await adb.clear_neo4j_database_async(clear_constraints=True, clear_indexes=True) + + await adb.clear_neo4j_database_async(clear_constraints=True, clear_indexes=True) await adb.cypher_query_async( "CREATE OR REPLACE USER troygreene SET PASSWORD 'foobarbaz' CHANGE NOT REQUIRED" @@ -43,6 +46,16 @@ async def setup_neo4j_session(request): @pytest_asyncio.fixture(scope="session", autouse=True) +@pytest.mark.asyncio async def cleanup(): yield await adb.close_connection_async() + + +@pytest.fixture(scope="session") +def event_loop(): + """Overrides pytest default function scoped event loop""" + policy = asyncio.get_event_loop_policy() + loop = policy.new_event_loop() + yield loop + loop.close() diff --git a/test/test_cypher.py b/test/async_/test_cypher.py similarity index 100% rename from test/test_cypher.py rename to test/async_/test_cypher.py From 3d30cd1d191b6f8d042428dfbd99bc2be568417a Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Fri, 8 Dec 2023 10:27:48 +0100 Subject: [PATCH 03/73] First test working --- neomodel/_sync/__init__.py | 1 + neomodel/_sync/core.py | 1493 ++++++++++++++++++++++++++ pyproject.toml | 1 + test/_async_compat/__init__.py | 13 + test/_async_compat/mark_decorator.py | 18 + test/async_/conftest.py | 5 +- test/async_/test_cypher.py | 13 +- test/sync/conftest.py | 62 ++ test/sync/test_cypher.py | 155 +++ 9 files changed, 1753 insertions(+), 8 deletions(-) create mode 100644 neomodel/_sync/__init__.py create mode 100644 neomodel/_sync/core.py create mode 100644 test/_async_compat/__init__.py create mode 100644 test/_async_compat/mark_decorator.py create mode 100644 test/sync/conftest.py create mode 100644 test/sync/test_cypher.py diff --git a/neomodel/_sync/__init__.py b/neomodel/_sync/__init__.py new file mode 100644 index 00000000..95bbd58a --- /dev/null +++ b/neomodel/_sync/__init__.py @@ -0,0 +1 @@ +# from neomodel._async.core import adb diff --git a/neomodel/_sync/core.py b/neomodel/_sync/core.py new file mode 100644 index 00000000..e4a9990b --- /dev/null +++ b/neomodel/_sync/core.py @@ -0,0 +1,1493 @@ +import logging +import os +import sys +import time +import warnings +from itertools import combinations +from threading import local +from typing import Optional, Sequence, Tuple +from urllib.parse import quote, unquote, urlparse + +from neo4j import ( + DEFAULT_DATABASE, + Driver, + GraphDatabase, + Result, + Session, + Transaction, + basic_auth, +) +from neo4j.api import Bookmarks +from neo4j.exceptions import ClientError, ServiceUnavailable, SessionExpired +from neo4j.graph import Node, Path, Relationship + +from neomodel import config +from neomodel.exceptions import ( + ConstraintValidationFailed, + DoesNotExist, + FeatureNotSupported, + NodeClassAlreadyDefined, + NodeClassNotDefined, + RelationshipClassNotDefined, + UniqueProperty, +) +from neomodel.hooks import hooks +from neomodel.properties import Property, PropertyManager +from neomodel.util import ( + _get_node_properties, + _UnsavedNode, + classproperty, + deprecated, + version_tag_to_integer, +) + +logger = logging.getLogger(__name__) + +RULE_ALREADY_EXISTS = "Neo.ClientError.Schema.EquivalentSchemaRuleAlreadyExists" +INDEX_ALREADY_EXISTS = "Neo.ClientError.Schema.IndexAlreadyExists" +CONSTRAINT_ALREADY_EXISTS = "Neo.ClientError.Schema.ConstraintAlreadyExists" +STREAMING_WARNING = "streaming is not supported by bolt, please remove the kwarg" + + +# make sure the connection url has been set prior to executing the wrapped function +def ensure_connection(func): + def wrapper(self, *args, **kwargs): + # Sort out where to find url + if hasattr(self, "db"): + _db = self.db + else: + _db = self + + if not _db.driver: + if hasattr(config, "DRIVER") and config.DRIVER: + _db.set_connection_async(driver=config.DRIVER) + elif config.DATABASE_URL: + _db.set_connection_async(url=config.DATABASE_URL) + + return func(self, *args, **kwargs) + + return wrapper + + +class Database(local): + """ + A singleton object via which all operations from neomodel to the Neo4j backend are handled with. + """ + + _NODE_CLASS_REGISTRY = {} + + def __init__(self): + self._active_transaction = None + self.url = None + self.driver = None + self._session = None + self._pid = None + self._database_name = DEFAULT_DATABASE + self.protocol_version = None + self._database_version = None + self._database_edition = None + self.impersonated_user = None + + def set_connection_async(self, url: str = None, driver: Driver = None): + """ + Sets the connection up and relevant internal. This can be done using a Neo4j URL or a driver instance. + + Args: + url (str): Optionally, Neo4j URL in the form protocol://username:password@hostname:port/dbname. + When provided, a Neo4j driver instance will be created by neomodel. + + driver (neo4j.Driver): Optionally, a pre-created driver instance. + When provided, neomodel will not create a driver instance but use this one instead. + """ + if driver: + self.driver = driver + if hasattr(config, "DATABASE_NAME") and config.DATABASE_NAME: + self._database_name = config.DATABASE_NAME + elif url: + self._parse_driver_from_url(url=url) + + self._pid = os.getpid() + self._active_transaction = None + # Set to default database if it hasn't been set before + if self._database_name is None: + self._database_name = DEFAULT_DATABASE + + # Getting the information about the database version requires a connection to the database + self._database_version = None + self._database_edition = None + self._update_database_version_async() + + def _parse_driver_from_url(self, url: str) -> None: + """Parse the driver information from the given URL and initialize the driver. + + Args: + url (str): The URL to parse. + + Raises: + ValueError: If the URL format is not as expected. + + Returns: + None - Sets the driver and database_name as class properties + """ + p_start = url.replace(":", "", 1).find(":") + 2 + p_end = url.rfind("@") + password = url[p_start:p_end] + url = url.replace(password, quote(password)) + parsed_url = urlparse(url) + + valid_schemas = [ + "bolt", + "bolt+s", + "bolt+ssc", + "bolt+routing", + "neo4j", + "neo4j+s", + "neo4j+ssc", + ] + + if parsed_url.netloc.find("@") > -1 and parsed_url.scheme in valid_schemas: + credentials, hostname = parsed_url.netloc.rsplit("@", 1) + username, password = credentials.split(":") + password = unquote(password) + database_name = parsed_url.path.strip("/") + else: + raise ValueError( + f"Expecting url format: bolt://user:password@localhost:7687 got {url}" + ) + + options = { + "auth": basic_auth(username, password), + "connection_acquisition_timeout": config.CONNECTION_ACQUISITION_TIMEOUT, + "connection_timeout": config.CONNECTION_TIMEOUT, + "keep_alive": config.KEEP_ALIVE, + "max_connection_lifetime": config.MAX_CONNECTION_LIFETIME, + "max_connection_pool_size": config.MAX_CONNECTION_POOL_SIZE, + "max_transaction_retry_time": config.MAX_TRANSACTION_RETRY_TIME, + "resolver": config.RESOLVER, + "user_agent": config.USER_AGENT, + } + + if "+s" not in parsed_url.scheme: + options["encrypted"] = config.ENCRYPTED + options["trusted_certificates"] = config.TRUSTED_CERTIFICATES + + self.driver = GraphDatabase.driver( + parsed_url.scheme + "://" + hostname, **options + ) + self.url = url + # The database name can be provided through the url or the config + if database_name == "": + if hasattr(config, "DATABASE_NAME") and config.DATABASE_NAME: + self._database_name = config.DATABASE_NAME + else: + self._database_name = database_name + + def close_connection_async(self): + """ + Closes the currently open driver. + The driver should always be closed at the end of the application's lifecyle. + """ + self._database_version = None + self._database_edition = None + self._database_name = None + self.driver.close() + self.driver = None + + @property + def database_version(self): + if self._database_version is None: + self._update_database_version_async() + + return self._database_version + + @property + def database_edition(self): + if self._database_edition is None: + self._update_database_version_async() + + return self._database_edition + + @property + def transaction(self): + """ + Returns the current transaction object + """ + return TransactionProxyAsync(self) + + @property + def write_transaction(self): + return TransactionProxyAsync(self, access_mode="WRITE") + + @property + def read_transaction(self): + return TransactionProxyAsync(self, access_mode="READ") + + def impersonate(self, user: str) -> "ImpersonationHandler": + """All queries executed within this context manager will be executed as impersonated user + + Args: + user (str): User to impersonate + + Returns: + ImpersonationHandler: Context manager to set/unset the user to impersonate + """ + if self.database_edition != "enterprise": + raise FeatureNotSupported( + "Impersonation is only available in Neo4j Enterprise edition" + ) + return ImpersonationHandler(self, impersonated_user=user) + + @ensure_connection + def begin_async(self, access_mode=None, **parameters): + """ + Begins a new transaction. Raises SystemError if a transaction is already active. + """ + if ( + hasattr(self, "_active_transaction") + and self._active_transaction is not None + ): + raise SystemError("Transaction in progress") + self._session: Session = self.driver.session( + default_access_mode=access_mode, + database=self._database_name, + impersonated_user=self.impersonated_user, + **parameters, + ) + self._active_transaction: Transaction = self._session.begin_transaction() + + @ensure_connection + def commit_async(self): + """ + Commits the current transaction and closes its session + + :return: last_bookmarks + """ + try: + self._active_transaction.commit() + last_bookmarks: Bookmarks = self._session.last_bookmarks() + finally: + # In case when something went wrong during + # committing changes to the database + # we have to close an active transaction and session. + self._active_transaction.close() + self._session.close() + self._active_transaction = None + self._session = None + + return last_bookmarks + + @ensure_connection + def rollback_async(self): + """ + Rolls back the current transaction and closes its session + """ + try: + self._active_transaction.rollback() + finally: + # In case when something went wrong during changes rollback, + # we have to close an active transaction and session + self._active_transaction.close() + self._session.close() + self._active_transaction = None + self._session = None + + def _update_database_version_async(self): + """ + Updates the database server information when it is required + """ + try: + results = self.cypher_query_async( + "CALL dbms.components() yield versions, edition return versions[0], edition" + ) + self._database_version = results[0][0][0] + self._database_edition = results[0][0][1] + except ServiceUnavailable: + # The database server is not running yet + pass + + def _object_resolution(self, object_to_resolve): + """ + Performs in place automatic object resolution on a result + returned by cypher_query. + + The function operates recursively in order to be able to resolve Nodes + within nested list structures and Path objects. Not meant to be called + directly, used primarily by _result_resolution. + + :param object_to_resolve: A result as returned by cypher_query. + :type Any: + + :return: An instantiated object. + """ + # Below is the original comment that came with the code extracted in + # this method. It is not very clear but I decided to keep it just in + # case + # + # + # For some reason, while the type of `a_result_attribute[1]` + # as reported by the neo4j driver is `Node` for Node-type data + # retrieved from the database. + # When the retrieved data are Relationship-Type, + # the returned type is `abc.[REL_LABEL]` which is however + # a descendant of Relationship. + # Consequently, the type checking was changed for both + # Node, Relationship objects + if isinstance(object_to_resolve, Node): + return self._NODE_CLASS_REGISTRY[ + frozenset(object_to_resolve.labels) + ].inflate(object_to_resolve) + + if isinstance(object_to_resolve, Relationship): + rel_type = frozenset([object_to_resolve.type]) + return self._NODE_CLASS_REGISTRY[rel_type].inflate(object_to_resolve) + + if isinstance(object_to_resolve, Path): + from neomodel.path import NeomodelPath + + return NeomodelPath(object_to_resolve) + + if isinstance(object_to_resolve, list): + return self._result_resolution([object_to_resolve]) + + return object_to_resolve + + def _result_resolution(self, result_list): + """ + Performs in place automatic object resolution on a set of results + returned by cypher_query. + + The function operates recursively in order to be able to resolve Nodes + within nested list structures. Not meant to be called directly, + used primarily by cypher_query. + + :param result_list: A list of results as returned by cypher_query. + :type list: + + :return: A list of instantiated objects. + """ + + # Object resolution occurs in-place + for a_result_item in enumerate(result_list): + for a_result_attribute in enumerate(a_result_item[1]): + try: + # Primitive types should remain primitive types, + # Nodes to be resolved to native objects + resolved_object = a_result_attribute[1] + + resolved_object = self._object_resolution(resolved_object) + + result_list[a_result_item[0]][ + a_result_attribute[0] + ] = resolved_object + + except KeyError as exc: + # Not being able to match the label set of a node with a known object results + # in a KeyError in the internal dictionary used for resolution. If it is impossible + # to match, then raise an exception with more details about the error. + if isinstance(a_result_attribute[1], Node): + raise NodeClassNotDefined( + a_result_attribute[1], self._NODE_CLASS_REGISTRY + ) from exc + + if isinstance(a_result_attribute[1], Relationship): + raise RelationshipClassNotDefined( + a_result_attribute[1], self._NODE_CLASS_REGISTRY + ) from exc + + return result_list + + @ensure_connection + def cypher_query_async( + self, + query, + params=None, + handle_unique=True, + retry_on_session_expire=False, + resolve_objects=False, + ) -> (list[list], Tuple[str, ...]): + """ + Runs a query on the database and returns a list of results and their headers. + + :param query: A CYPHER query + :type: str + :param params: Dictionary of parameters + :type: dict + :param handle_unique: Whether or not to raise UniqueProperty exception on Cypher's ConstraintValidation errors + :type: bool + :param retry_on_session_expire: Whether or not to attempt the same query again if the transaction has expired. + If you use neomodel with your own driver, you must catch SessionExpired exceptions yourself and retry with a new driver instance. + :type: bool + :param resolve_objects: Whether to attempt to resolve the returned nodes to data model objects automatically + :type: bool + + :return: A tuple containing a list of results and a tuple of headers. + """ + + if self._active_transaction: + # Use current session is a transaction is currently active + results, meta = self._run_cypher_query_async( + self._active_transaction, + query, + params, + handle_unique, + retry_on_session_expire, + resolve_objects, + ) + else: + # Otherwise create a new session in a with to dispose of it after it has been run + with self.driver.session( + database=self._database_name, impersonated_user=self.impersonated_user + ) as session: + results, meta = self._run_cypher_query_async( + session, + query, + params, + handle_unique, + retry_on_session_expire, + resolve_objects, + ) + + return results, meta + + def _run_cypher_query_async( + self, + session: Session, + query, + params, + handle_unique, + retry_on_session_expire, + resolve_objects, + ) -> (list[list], Tuple[str, ...]): + try: + # Retrieve the data + start = time.time() + response: Result = session.run(query, params) + results, meta = [list(r.values()) for r in response], response.keys() + end = time.time() + + if resolve_objects: + # Do any automatic resolution required + results = self._result_resolution(results) + + except ClientError as e: + if e.code == "Neo.ClientError.Schema.ConstraintValidationFailed": + if "already exists with label" in e.message and handle_unique: + raise UniqueProperty(e.message) from e + + raise ConstraintValidationFailed(e.message) from e + exc_info = sys.exc_info() + raise exc_info[1].with_traceback(exc_info[2]) + except SessionExpired: + if retry_on_session_expire: + self.set_connection_async(url=self.url) + return self.cypher_query_async( + query=query, + params=params, + handle_unique=handle_unique, + retry_on_session_expire=False, + ) + raise + + tte = end - start + if os.environ.get("NEOMODEL_CYPHER_DEBUG", False) and tte > float( + os.environ.get("NEOMODEL_SLOW_QUERIES", 0) + ): + logger.debug( + "query: " + + query + + "\nparams: " + + repr(params) + + f"\ntook: {tte:.2g}s\n" + ) + + return results, meta + + def get_id_method(self) -> str: + if self.database_version.startswith("4"): + return "id" + else: + return "elementId" + + def list_indexes_async(self, exclude_token_lookup=False) -> Sequence[dict]: + """Returns all indexes existing in the database + + Arguments: + exclude_token_lookup[bool]: Exclude automatically create token lookup indexes + + Returns: + Sequence[dict]: List of dictionaries, each entry being an index definition + """ + indexes, meta_indexes = self.cypher_query_async("SHOW INDEXES") + indexes_as_dict = [dict(zip(meta_indexes, row)) for row in indexes] + + if exclude_token_lookup: + indexes_as_dict = [ + obj for obj in indexes_as_dict if obj["type"] != "LOOKUP" + ] + + return indexes_as_dict + + def list_constraints_async(self) -> Sequence[dict]: + """Returns all constraints existing in the database + + Returns: + Sequence[dict]: List of dictionaries, each entry being a constraint definition + """ + constraints, meta_constraints = self.cypher_query_async("SHOW CONSTRAINTS") + constraints_as_dict = [dict(zip(meta_constraints, row)) for row in constraints] + + return constraints_as_dict + + def version_is_higher_than(self, version_tag: str) -> bool: + """Returns true if the database version is higher or equal to a given tag + + Args: + version_tag (str): The version to compare against + + Returns: + bool: True if the database version is higher or equal to the given version + """ + return version_tag_to_integer(self.database_version) >= version_tag_to_integer( + version_tag + ) + + def edition_is_enterprise(self) -> bool: + """Returns true if the database edition is enterprise + + Returns: + bool: True if the database edition is enterprise + """ + return self.database_edition == "enterprise" + + def change_neo4j_password_async(self, user, new_password): + self.cypher_query_async(f"ALTER USER {user} SET PASSWORD '{new_password}'") + + def clear_neo4j_database_async(self, clear_constraints=False, clear_indexes=False): + self.cypher_query_async( + """ + MATCH (a) + CALL { WITH a DETACH DELETE a } + IN TRANSACTIONS OF 5000 rows + """ + ) + if clear_constraints: + drop_constraints_async() + if clear_indexes: + drop_indexes_async() + + def drop_constraints_async(self, quiet=True, stdout=None): + """ + Discover and drop all constraints. + + :type: bool + :return: None + """ + if not stdout or stdout is None: + stdout = sys.stdout + + results, meta = self.cypher_query_async("SHOW CONSTRAINTS") + + results_as_dict = [dict(zip(meta, row)) for row in results] + for constraint in results_as_dict: + self.cypher_query_async("DROP CONSTRAINT " + constraint["name"]) + if not quiet: + stdout.write( + ( + " - Dropping unique constraint and index" + f" on label {constraint['labelsOrTypes'][0]}" + f" with property {constraint['properties'][0]}.\n" + ) + ) + if not quiet: + stdout.write("\n") + + def drop_indexes_async(self, quiet=True, stdout=None): + """ + Discover and drop all indexes, except the automatically created token lookup indexes. + + :type: bool + :return: None + """ + if not stdout or stdout is None: + stdout = sys.stdout + + indexes = self.list_indexes_async(exclude_token_lookup=True) + for index in indexes: + self.cypher_query_async("DROP INDEX " + index["name"]) + if not quiet: + stdout.write( + f' - Dropping index on labels {",".join(index["labelsOrTypes"])} with properties {",".join(index["properties"])}.\n' + ) + if not quiet: + stdout.write("\n") + + def remove_all_labels_async(self, stdout=None): + """ + Calls functions for dropping constraints and indexes. + + :param stdout: output stream + :return: None + """ + + if not stdout: + stdout = sys.stdout + + stdout.write("Dropping constraints...\n") + self.drop_constraints_async(quiet=False, stdout=stdout) + + stdout.write("Dropping indexes...\n") + self.drop_indexes_async(quiet=False, stdout=stdout) + + def install_all_labels_async(self, stdout=None): + """ + Discover all subclasses of StructuredNode in your application and execute install_labels on each. + Note: code must be loaded (imported) in order for a class to be discovered. + + :param stdout: output stream + :return: None + """ + + if not stdout or stdout is None: + stdout = sys.stdout + + def subsub(cls): # recursively return all subclasses + subclasses = cls.__subclasses__() + if not subclasses: # base case: no more subclasses + return [] + return subclasses + [g for s in cls.__subclasses__() for g in subsub(s)] + + stdout.write("Setting up indexes and constraints...\n\n") + + i = 0 + for cls in subsub(StructuredNodeAsync): + stdout.write(f"Found {cls.__module__}.{cls.__name__}\n") + install_labels_async(cls, quiet=False, stdout=stdout) + i += 1 + + if i: + stdout.write("\n") + + stdout.write(f"Finished {i} classes.\n") + + def install_labels_async(self, cls, quiet=True, stdout=None): + """ + Setup labels with indexes and constraints for a given class + + :param cls: StructuredNode class + :type: class + :param quiet: (default true) enable standard output + :param stdout: stdout stream + :type: bool + :return: None + """ + if not stdout or stdout is None: + stdout = sys.stdout + + if not hasattr(cls, "__label__"): + if not quiet: + stdout.write( + f" ! Skipping class {cls.__module__}.{cls.__name__} is abstract\n" + ) + return + + for name, property in cls.defined_properties(aliases=False, rels=False).items(): + self._install_node_async(cls, name, property, quiet, stdout) + + for _, relationship in cls.defined_properties( + aliases=False, rels=True, properties=False + ).items(): + self._install_relationship_async(cls, relationship, quiet, stdout) + + def _create_node_index_async(self, label: str, property_name: str, stdout): + try: + self.cypher_query_async( + f"CREATE INDEX index_{label}_{property_name} FOR (n:{label}) ON (n.{property_name}); " + ) + except ClientError as e: + if e.code in ( + RULE_ALREADY_EXISTS, + INDEX_ALREADY_EXISTS, + ): + stdout.write(f"{str(e)}\n") + else: + raise + + def _create_node_constraint_async(self, label: str, property_name: str, stdout): + try: + self.cypher_query_async( + f"""CREATE CONSTRAINT constraint_unique_{label}_{property_name} + FOR (n:{label}) REQUIRE n.{property_name} IS UNIQUE""" + ) + except ClientError as e: + if e.code in ( + RULE_ALREADY_EXISTS, + CONSTRAINT_ALREADY_EXISTS, + ): + stdout.write(f"{str(e)}\n") + else: + raise + + def _create_relationship_index_async( + self, relationship_type: str, property_name: str, stdout + ): + try: + self.cypher_query_async( + f"CREATE INDEX index_{relationship_type}_{property_name} FOR ()-[r:{relationship_type}]-() ON (r.{property_name}); " + ) + except ClientError as e: + if e.code in ( + RULE_ALREADY_EXISTS, + INDEX_ALREADY_EXISTS, + ): + stdout.write(f"{str(e)}\n") + else: + raise + + def _create_relationship_constraint_async( + self, relationship_type: str, property_name: str, stdout + ): + if self.version_is_higher_than("5.7"): + try: + self.cypher_query_async( + f"""CREATE CONSTRAINT constraint_unique_{relationship_type}_{property_name} + FOR ()-[r:{relationship_type}]-() REQUIRE r.{property_name} IS UNIQUE""" + ) + except ClientError as e: + if e.code in ( + RULE_ALREADY_EXISTS, + CONSTRAINT_ALREADY_EXISTS, + ): + stdout.write(f"{str(e)}\n") + else: + raise + else: + raise FeatureNotSupported( + f"Unique indexes on relationships are not supported in Neo4j version {self.database_version}. Please upgrade to Neo4j 5.7 or higher." + ) + + def _install_node_async(self, cls, name, property, quiet, stdout): + # Create indexes and constraints for node property + db_property = property.db_property or name + if property.index: + if not quiet: + stdout.write( + f" + Creating node index {name} on label {cls.__label__} for class {cls.__module__}.{cls.__name__}\n" + ) + self._create_node_index_async( + label=cls.__label__, property_name=db_property, stdout=stdout + ) + + elif property.unique_index: + if not quiet: + stdout.write( + f" + Creating node unique constraint for {name} on label {cls.__label__} for class {cls.__module__}.{cls.__name__}\n" + ) + self._create_node_constraint_async( + label=cls.__label__, property_name=db_property, stdout=stdout + ) + + def _install_relationship_async(self, cls, relationship, quiet, stdout): + # Create indexes and constraints for relationship property + relationship_cls = relationship.definition["model"] + if relationship_cls is not None: + relationship_type = relationship.definition["relation_type"] + for prop_name, property in relationship_cls.defined_properties( + aliases=False, rels=False + ).items(): + db_property = property.db_property or prop_name + if property.index: + if not quiet: + stdout.write( + f" + Creating relationship index {prop_name} on relationship type {relationship_type} for relationship model {cls.__module__}.{relationship_cls.__name__}\n" + ) + self._create_relationship_index_async( + relationship_type=relationship_type, + property_name=db_property, + stdout=stdout, + ) + elif property.unique_index: + if not quiet: + stdout.write( + f" + Creating relationship unique constraint for {prop_name} on relationship type {relationship_type} for relationship model {cls.__module__}.{relationship_cls.__name__}\n" + ) + self._create_relationship_constraint_async( + relationship_type=relationship_type, + property_name=db_property, + stdout=stdout, + ) + + +# Create a singleton instance of the database object +adb = Database() + + +# Deprecated methods +def change_neo4j_password_async(db, user, new_password): + deprecated( + """ + This method has been moved to the Database singleton (db for sync, adb for async). + Please use adb.change_neo4j_password_async(user, new_password) instead. + This direct call will be removed in an upcoming version. + """ + ) + db.change_neo4j_password_async(user, new_password) + + +def clear_neo4j_database_async(db, clear_constraints=False, clear_indexes=False): + deprecated( + """ + This method has been moved to the Database singleton (db for sync, adb for async). + Please use adb.clear_neo4j_database_async(clear_constraints, clear_indexes) instead. + This direct call will be removed in an upcoming version. + """ + ) + db.clear_neo4j_database_async(clear_constraints, clear_indexes) + + +def drop_constraints_async(quiet=True, stdout=None): + deprecated( + """ + This method has been moved to the Database singleton (db for sync, adb for async). + Please use adb.drop_constraints_async(quiet, stdout) instead. + This direct call will be removed in an upcoming version. + """ + ) + adb.drop_constraints_async(quiet, stdout) + + +def drop_indexes_async(quiet=True, stdout=None): + deprecated( + """ + This method has been moved to the Database singleton (db for sync, adb for async). + Please use adb.drop_indexes_async(quiet, stdout) instead. + This direct call will be removed in an upcoming version. + """ + ) + adb.drop_indexes_async(quiet, stdout) + + +def remove_all_labels_async(stdout=None): + deprecated( + """ + This method has been moved to the Database singleton (db for sync, adb for async). + Please use adb.remove_all_labels_async(stdout) instead. + This direct call will be removed in an upcoming version. + """ + ) + adb.remove_all_labels_async(stdout) + + +def install_labels_async(cls, quiet=True, stdout=None): + deprecated( + """ + This method has been moved to the Database singleton (db for sync, adb for async). + Please use adb.install_labels_async(cls, quiet, stdout) instead. + This direct call will be removed in an upcoming version. + """ + ) + adb.install_labels_async(cls, quiet, stdout) + + +def install_all_labels_async(stdout=None): + deprecated( + """ + This method has been moved to the Database singleton (db for sync, adb for async). + Please use adb.install_all_labels_async(stdout) instead. + This direct call will be removed in an upcoming version. + """ + ) + adb.install_all_labels_async(stdout) + + +class TransactionProxyAsync: + bookmarks: Optional[Bookmarks] = None + + def __init__(self, db: Database, access_mode=None): + self.db = db + self.access_mode = access_mode + + @ensure_connection + def __enter__(self): + self.db.begin_async(access_mode=self.access_mode, bookmarks=self.bookmarks) + self.bookmarks = None + return self + + def __exit__(self, exc_type, exc_value, traceback): + if exc_value: + self.db.rollback_async() + + if ( + exc_type is ClientError + and exc_value.code == "Neo.ClientError.Schema.ConstraintValidationFailed" + ): + raise UniqueProperty(exc_value.message) + + if not exc_value: + self.last_bookmark = self.db.commit_async() + + def __call__(self, func): + def wrapper(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return wrapper + + @property + def with_bookmark(self): + return BookmarkingTransactionProxyAsync(self.db, self.access_mode) + + +class ImpersonationHandler: + def __init__(self, db: Database, impersonated_user: str): + self.db = db + self.impersonated_user = impersonated_user + + def __enter__(self): + self.db.impersonated_user = self.impersonated_user + return self + + def __exit__(self, exception_type, exception_value, exception_traceback): + self.db.impersonated_user = None + + print("\nException type:", exception_type) + print("\nException value:", exception_value) + print("\nTraceback:", exception_traceback) + + def __call__(self, func): + def wrapper(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return wrapper + + +class BookmarkingTransactionProxyAsync(TransactionProxyAsync): + def __call__(self, func): + def wrapper(*args, **kwargs): + self.bookmarks = kwargs.pop("bookmarks", None) + + with self: + result = func(*args, **kwargs) + self.last_bookmark = None + + return result, self.last_bookmark + + return wrapper + + +# TODO : Either deprecate auto_install_labels +# Or make it work with async +class NodeMeta(type): + def __new__(mcs, name, bases, namespace): + namespace["DoesNotExist"] = type(name + "DoesNotExist", (DoesNotExist,), {}) + cls = super().__new__(mcs, name, bases, namespace) + cls.DoesNotExist._model_class = cls + + if hasattr(cls, "__abstract_node__"): + delattr(cls, "__abstract_node__") + else: + if "deleted" in namespace: + raise ValueError( + "Property name 'deleted' is not allowed as it conflicts with neomodel internals." + ) + elif "id" in namespace: + raise ValueError( + """ + Property name 'id' is not allowed as it conflicts with neomodel internals. + Consider using 'uid' or 'identifier' as id is also a Neo4j internal. + """ + ) + elif "element_id" in namespace: + raise ValueError( + """ + Property name 'element_id' is not allowed as it conflicts with neomodel internals. + Consider using 'uid' or 'identifier' as element_id is also a Neo4j internal. + """ + ) + for key, value in ( + (x, y) for x, y in namespace.items() if isinstance(y, Property) + ): + value.name, value.owner = key, cls + if hasattr(value, "setup") and callable(value.setup): + value.setup() + + # cache various groups of properies + cls.__required_properties__ = tuple( + name + for name, property in cls.defined_properties( + aliases=False, rels=False + ).items() + if property.required or property.unique_index + ) + cls.__all_properties__ = tuple( + cls.defined_properties(aliases=False, rels=False).items() + ) + cls.__all_aliases__ = tuple( + cls.defined_properties(properties=False, rels=False).items() + ) + cls.__all_relationships__ = tuple( + cls.defined_properties(aliases=False, properties=False).items() + ) + + cls.__label__ = namespace.get("__label__", name) + cls.__optional_labels__ = namespace.get("__optional_labels__", []) + + # TODO : See previous TODO comment + # if config.AUTO_INSTALL_LABELS: + # await install_labels_async(cls, quiet=False) + + build_class_registry(cls) + + return cls + + +def build_class_registry(cls): + base_label_set = frozenset(cls.inherited_labels()) + optional_label_set = set(cls.inherited_optional_labels()) + + # Construct all possible combinations of labels + optional labels + possible_label_combinations = [ + frozenset(set(x).union(base_label_set)) + for i in range(1, len(optional_label_set) + 1) + for x in combinations(optional_label_set, i) + ] + possible_label_combinations.append(base_label_set) + + for label_set in possible_label_combinations: + if label_set not in adb._NODE_CLASS_REGISTRY: + adb._NODE_CLASS_REGISTRY[label_set] = cls + else: + raise NodeClassAlreadyDefined(cls, adb._NODE_CLASS_REGISTRY) + + +NodeBase = NodeMeta("NodeBase", (PropertyManager,), {"__abstract_node__": True}) + + +class StructuredNodeAsync(NodeBase): + """ + Base class for all node definitions to inherit from. + + If you want to create your own abstract classes set: + __abstract_node__ = True + """ + + # static properties + + __abstract_node__ = True + + # magic methods + + def __init__(self, *args, **kwargs): + if "deleted" in kwargs: + raise ValueError("deleted property is reserved for neomodel") + + for key, val in self.__all_relationships__: + self.__dict__[key] = val.build_manager(self, key) + + super().__init__(*args, **kwargs) + + def __eq__(self, other): + if not isinstance(other, (StructuredNodeAsync,)): + return False + if hasattr(self, "element_id") and hasattr(other, "element_id"): + return self.element_id == other.element_id + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def __repr__(self): + return f"<{self.__class__.__name__}: {self}>" + + def __str__(self): + return repr(self.__properties__) + + # dynamic properties + + @classproperty + def nodes(cls): + """ + Returns a NodeSet object representing all nodes of the classes label + :return: NodeSet + :rtype: NodeSet + """ + from neomodel.match import NodeSet + + return NodeSet(cls) + + @property + def element_id(self): + if hasattr(self, "element_id_property"): + return ( + int(self.element_id_property) + if adb.database_version.startswith("4") + else self.element_id_property + ) + return None + + # Version 4.4 support - id is deprecated in version 5.x + @property + def id(self): + try: + return int(self.element_id_property) + except (TypeError, ValueError): + raise ValueError( + "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()." + ) + + # methods + + @classmethod + def _build_merge_query( + cls, merge_params, update_existing=False, lazy=False, relationship=None + ): + """ + Get a tuple of a CYPHER query and a params dict for the specified MERGE query. + + :param merge_params: The target node match parameters, each node must have a "create" key and optional "update". + :type merge_params: list of dict + :param update_existing: True to update properties of existing nodes, default False to keep existing values. + :type update_existing: bool + :rtype: tuple + """ + query_params = dict(merge_params=merge_params) + n_merge_labels = ":".join(cls.inherited_labels()) + n_merge_prm = ", ".join( + ( + f"{getattr(cls, p).db_property or p}: params.create.{getattr(cls, p).db_property or p}" + for p in cls.__required_properties__ + ) + ) + n_merge = f"n:{n_merge_labels} { {n_merge_prm}} " + if relationship is None: + # create "simple" unwind query + query = f"UNWIND $merge_params as params\n MERGE ({n_merge})\n " + else: + # validate relationship + if not isinstance(relationship.source, StructuredNodeAsync): + raise ValueError( + f"relationship source [{repr(relationship.source)}] is not a StructuredNode" + ) + relation_type = relationship.definition.get("relation_type") + if not relation_type: + raise ValueError( + "No relation_type is specified on provided relationship" + ) + + from neomodel.match import _rel_helper + + query_params["source_id"] = relationship.source.element_id + query = f"MATCH (source:{relationship.source.__label__}) WHERE {adb.get_id_method()}(source) = $source_id\n " + query += "WITH source\n UNWIND $merge_params as params \n " + query += "MERGE " + query += _rel_helper( + lhs="source", + rhs=n_merge, + ident=None, + relation_type=relation_type, + direction=relationship.definition["direction"], + ) + + query += "ON CREATE SET n = params.create\n " + # if update_existing, write properties on match as well + if update_existing is True: + query += "ON MATCH SET n += params.update\n" + + # close query + if lazy: + query += f"RETURN {adb.get_id_method()}(n)" + else: + query += "RETURN n" + + return query, query_params + + @classmethod + def create_async(cls, *props, **kwargs): + """ + Call to CREATE with parameters map. A new instance will be created and saved. + + :param props: dict of properties to create the nodes. + :type props: tuple + :param lazy: False by default, specify True to get nodes with id only without the parameters. + :type: bool + :rtype: list + """ + + if "streaming" in kwargs: + warnings.warn( + STREAMING_WARNING, + category=DeprecationWarning, + stacklevel=1, + ) + + lazy = kwargs.get("lazy", False) + # create mapped query + query = f"CREATE (n:{':'.join(cls.inherited_labels())} $create_params)" + + # close query + if lazy: + query += f" RETURN {adb.get_id_method()}(n)" + else: + query += " RETURN n" + + results = [] + for item in [ + cls.deflate(p, obj=_UnsavedNode(), skip_empty=True) for p in props + ]: + node, _ = adb.cypher_query_async(query, {"create_params": item}) + results.extend(node[0]) + + nodes = [cls.inflate(node) for node in results] + + if not lazy and hasattr(cls, "post_create"): + for node in nodes: + node.post_create() + + return nodes + + @classmethod + def create_or_update_async(cls, *props, **kwargs): + """ + Call to MERGE with parameters map. A new instance will be created and saved if does not already exists, + this is an atomic operation. If an instance already exists all optional properties specified will be updated. + + Note that the post_create hook isn't called after create_or_update + + :param props: List of dict arguments to get or create the entities with. + :type props: tuple + :param relationship: Optional, relationship to get/create on when new entity is created. + :param lazy: False by default, specify True to get nodes with id only without the parameters. + :rtype: list + """ + lazy = kwargs.get("lazy", False) + relationship = kwargs.get("relationship") + + # build merge query, make sure to update only explicitly specified properties + create_or_update_params = [] + for specified, deflated in [ + (p, cls.deflate(p, skip_empty=True)) for p in props + ]: + create_or_update_params.append( + { + "create": deflated, + "update": dict( + (k, v) for k, v in deflated.items() if k in specified + ), + } + ) + query, params = cls._build_merge_query( + create_or_update_params, + update_existing=True, + relationship=relationship, + lazy=lazy, + ) + + if "streaming" in kwargs: + warnings.warn( + STREAMING_WARNING, + category=DeprecationWarning, + stacklevel=1, + ) + + # fetch and build instance for each result + results = adb.cypher_query_async(query, params) + return [cls.inflate(r[0]) for r in results[0]] + + def cypher_async(self, query, params=None): + """ + Execute a cypher query with the param 'self' pre-populated with the nodes neo4j id. + + :param query: cypher query string + :type: string + :param params: query parameters + :type: dict + :return: list containing query results + :rtype: list + """ + self._pre_action_check("cypher") + params = params or {} + params.update({"self": self.element_id}) + return adb.cypher_query_async(query, params) + + @hooks + def delete_async(self): + """ + Delete a node and its relationships + + :return: True + """ + self._pre_action_check("delete") + self.cypher_async( + f"MATCH (self) WHERE {adb.get_id_method()}(self)=$self DETACH DELETE self" + ) + delattr(self, "element_id_property") + self.deleted = True + return True + + @classmethod + def get_or_create_async(cls, *props, **kwargs): + """ + Call to MERGE with parameters map. A new instance will be created and saved if does not already exist, + this is an atomic operation. + Parameters must contain all required properties, any non required properties with defaults will be generated. + + Note that the post_create hook isn't called after get_or_create + + :param props: Arguments to get_or_create as tuple of dict with property names and values to get or create + the entities with. + :type props: tuple + :param relationship: Optional, relationship to get/create on when new entity is created. + :param lazy: False by default, specify True to get nodes with id only without the parameters. + :rtype: list + """ + lazy = kwargs.get("lazy", False) + relationship = kwargs.get("relationship") + + # build merge query + get_or_create_params = [ + {"create": cls.deflate(p, skip_empty=True)} for p in props + ] + query, params = cls._build_merge_query( + get_or_create_params, relationship=relationship, lazy=lazy + ) + + if "streaming" in kwargs: + warnings.warn( + STREAMING_WARNING, + category=DeprecationWarning, + stacklevel=1, + ) + + # fetch and build instance for each result + results = adb.cypher_query_async(query, params) + return [cls.inflate(r[0]) for r in results[0]] + + @classmethod + def inflate(cls, node): + """ + Inflate a raw neo4j_driver node to a neomodel node + :param node: + :return: node object + """ + # support lazy loading + if isinstance(node, str) or isinstance(node, int): + snode = cls() + snode.element_id_property = node + else: + node_properties = _get_node_properties(node) + props = {} + for key, prop in cls.__all_properties__: + # map property name from database to object property + db_property = prop.db_property or key + + if db_property in node_properties: + props[key] = prop.inflate(node_properties[db_property], node) + elif prop.has_default: + props[key] = prop.default_value() + else: + props[key] = None + + snode = cls(**props) + snode.element_id_property = node.element_id + + return snode + + @classmethod + def inherited_labels(cls): + """ + Return list of labels from nodes class hierarchy. + + :return: list + """ + return [ + scls.__label__ + for scls in cls.mro() + if hasattr(scls, "__label__") and not hasattr(scls, "__abstract_node__") + ] + + @classmethod + def inherited_optional_labels(cls): + """ + Return list of optional labels from nodes class hierarchy. + + :return: list + :rtype: list + """ + return [ + label + for scls in cls.mro() + for label in getattr(scls, "__optional_labels__", []) + if not hasattr(scls, "__abstract_node__") + ] + + def labels_async(self): + """ + Returns list of labels tied to the node from neo4j. + + :return: list of labels + :rtype: list + """ + self._pre_action_check("labels") + return self.cypher_async( + f"MATCH (n) WHERE {adb.get_id_method()}(n)=$self " "RETURN labels(n)" + )[0][0][0] + + def _pre_action_check(self, action): + if hasattr(self, "deleted") and self.deleted: + raise ValueError( + f"{self.__class__.__name__}.{action}() attempted on deleted node" + ) + if not hasattr(self, "element_id"): + raise ValueError( + f"{self.__class__.__name__}.{action}() attempted on unsaved node" + ) + + def refresh_async(self): + """ + Reload the node from neo4j + """ + self._pre_action_check("refresh") + if hasattr(self, "element_id"): + request = self.cypher_async( + f"MATCH (n) WHERE {adb.get_id_method()}(n)=$self RETURN n" + )[0] + if not request or not request[0]: + raise self.__class__.DoesNotExist("Can't refresh non existent node") + node = self.inflate(request[0][0]) + for key, val in node.__properties__.items(): + setattr(self, key, val) + else: + raise ValueError("Can't refresh unsaved node") + + @hooks + def save_async(self): + """ + Save the node to neo4j or raise an exception + + :return: the node instance + """ + + # create or update instance node + if hasattr(self, "element_id_property"): + # update + params = self.deflate(self.__properties__, self) + query = f"MATCH (n) WHERE {adb.get_id_method()}(n)=$self\n" + + if params: + query += "SET " + query += ",\n".join([f"n.{key} = ${key}" for key in params]) + query += "\n" + if self.inherited_labels(): + query += "\n".join( + [f"SET n:`{label}`" for label in self.inherited_labels()] + ) + self.cypher_async(query, params) + elif hasattr(self, "deleted") and self.deleted: + raise ValueError( + f"{self.__class__.__name__}.save() attempted on deleted node" + ) + else: # create + result = self.create_async(self.__properties__) + created_node = result[0] + self.element_id_property = created_node.element_id + return self diff --git a/pyproject.toml b/pyproject.toml index 83eaad9e..d0e9df49 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ changelog = "https://github.com/neo4j-contrib/neomodel/releases" [project.optional-dependencies] dev = [ + "unasync", "pytest>=7.1", "pytest-asyncio", "pytest-cov>=4.0", diff --git a/test/_async_compat/__init__.py b/test/_async_compat/__init__.py new file mode 100644 index 00000000..d5053965 --- /dev/null +++ b/test/_async_compat/__init__.py @@ -0,0 +1,13 @@ +from .mark_decorator import ( + AsyncTestDecorators, + TestDecorators, + mark_async_test, + mark_sync_test, +) + +__all__ = [ + "AsyncTestDecorators", + "mark_async_test", + "mark_sync_test", + "TestDecorators", +] diff --git a/test/_async_compat/mark_decorator.py b/test/_async_compat/mark_decorator.py new file mode 100644 index 00000000..1195baa5 --- /dev/null +++ b/test/_async_compat/mark_decorator.py @@ -0,0 +1,18 @@ +import pytest + +mark_async_test = pytest.mark.asyncio + + +def mark_sync_test(f): + return f + + +class AsyncTestDecorators: + mark_async_only_test = mark_async_test + + +class TestDecorators: + @staticmethod + def mark_async_only_test(f): + skip_decorator = pytest.mark.skip("Async only test") + return skip_decorator(f) diff --git a/test/async_/conftest.py b/test/async_/conftest.py index 28db8673..88156e7b 100644 --- a/test/async_/conftest.py +++ b/test/async_/conftest.py @@ -1,6 +1,7 @@ import asyncio import os import warnings +from test._async_compat import mark_async_test import pytest import pytest_asyncio @@ -10,7 +11,7 @@ @pytest_asyncio.fixture(scope="session", autouse=True) -@pytest.mark.asyncio +@mark_async_test async def setup_neo4j_session(request): """ Provides initial connection to the database and sets up the rest of the test suite @@ -46,7 +47,7 @@ async def setup_neo4j_session(request): @pytest_asyncio.fixture(scope="session", autouse=True) -@pytest.mark.asyncio +@mark_async_test async def cleanup(): yield await adb.close_connection_async() diff --git a/test/async_/test_cypher.py b/test/async_/test_cypher.py index ef07bdc8..1b7934c9 100644 --- a/test/async_/test_cypher.py +++ b/test/async_/test_cypher.py @@ -1,4 +1,5 @@ import builtins +from test._async_compat import mark_async_test import pytest from neo4j.exceptions import ClientError as CypherError @@ -36,7 +37,7 @@ def mocked_import(name, *args, **kwargs): monkeypatch.setattr(builtins, "__import__", mocked_import) -@pytest.mark.asyncio +@mark_async_test async def test_cypher_async(): """ test result format is backward compatible with earlier versions of neomodel @@ -59,7 +60,7 @@ async def test_cypher_async(): assert "a" in meta and "b" in meta -@pytest.mark.asyncio +@mark_async_test async def test_cypher_syntax_error_async(): jim = await User2(email="jim1@test.com").save_async() try: @@ -73,7 +74,7 @@ async def test_cypher_syntax_error_async(): assert False, "CypherError not raised." -@pytest.mark.asyncio +@mark_async_test @pytest.mark.parametrize("hide_available_pkg", ["pandas"], indirect=True) async def test_pandas_not_installed_async(hide_available_pkg): with pytest.raises(ImportError): @@ -88,7 +89,7 @@ async def test_pandas_not_installed_async(hide_available_pkg): ) -@pytest.mark.asyncio +@mark_async_test async def test_pandas_integration_async(): from neomodel.integration.pandas import to_dataframe, to_series @@ -127,7 +128,7 @@ async def test_pandas_integration_async(): assert df["name"].tolist() == ["jimla", "jimlo"] -@pytest.mark.asyncio +@mark_async_test @pytest.mark.parametrize("hide_available_pkg", ["numpy"], indirect=True) async def test_numpy_not_installed_async(hide_available_pkg): with pytest.raises(ImportError): @@ -142,7 +143,7 @@ async def test_numpy_not_installed_async(hide_available_pkg): ) -@pytest.mark.asyncio +@mark_async_test async def test_numpy_integration_async(): from neomodel.integration.numpy import to_ndarray diff --git a/test/sync/conftest.py b/test/sync/conftest.py new file mode 100644 index 00000000..6ccd8d0d --- /dev/null +++ b/test/sync/conftest.py @@ -0,0 +1,62 @@ +import asyncio +import os +import warnings +from test._async_compat import mark_async_test + +import pytest +import pytest_asyncio + +from neomodel import config +from neomodel._async.core import adb + + +@pytest_asyncio.fixture(scope="session", autouse=True) +@mark_async_test +def setup_neo4j_session(request): + """ + Provides initial connection to the database and sets up the rest of the test suite + + :param request: The request object. Please see `_ + :type Request object: For more information please see `_ + """ + + warnings.simplefilter("default") + + config.DATABASE_URL = os.environ.get( + "NEO4J_BOLT_URL", "bolt://neo4j:foobarbaz@localhost:7687" + ) + config.AUTO_INSTALL_LABELS = True + + # Clear the database if required + database_is_populated, _ = adb.cypher_query_async( + "MATCH (a) return count(a)>0 as database_is_populated" + ) + if database_is_populated[0][0] and not request.config.getoption("resetdb"): + raise SystemError( + "Please note: The database seems to be populated.\n\tEither delete all nodes and edges manually, or set the --resetdb parameter when calling pytest\n\n\tpytest --resetdb." + ) + + adb.clear_neo4j_database_async(clear_constraints=True, clear_indexes=True) + + adb.cypher_query_async( + "CREATE OR REPLACE USER troygreene SET PASSWORD 'foobarbaz' CHANGE NOT REQUIRED" + ) + if adb.database_edition == "enterprise": + adb.cypher_query_async("GRANT ROLE publisher TO troygreene") + adb.cypher_query_async("GRANT IMPERSONATE (troygreene) ON DBMS TO admin") + + +@pytest_asyncio.fixture(scope="session", autouse=True) +@mark_async_test +def cleanup(): + yield + adb.close_connection_async() + + +@pytest.fixture(scope="session") +def event_loop(): + """Overrides pytest default function scoped event loop""" + policy = asyncio.get_event_loop_policy() + loop = policy.new_event_loop() + yield loop + loop.close() diff --git a/test/sync/test_cypher.py b/test/sync/test_cypher.py new file mode 100644 index 00000000..5bd4afb9 --- /dev/null +++ b/test/sync/test_cypher.py @@ -0,0 +1,155 @@ +import builtins +from test._async_compat import mark_async_test + +import pytest +from neo4j.exceptions import ClientError as CypherError +from numpy import ndarray +from pandas import DataFrame, Series + +from neomodel import StringProperty, StructuredNodeAsync +from neomodel._async.core import adb + + +class User2(StructuredNodeAsync): + name = StringProperty() + email = StringProperty() + + +class UserPandas(StructuredNodeAsync): + name = StringProperty() + email = StringProperty() + + +class UserNP(StructuredNodeAsync): + name = StringProperty() + email = StringProperty() + + +@pytest.fixture +def hide_available_pkg(monkeypatch, request): + import_orig = builtins.__import__ + + def mocked_import(name, *args, **kwargs): + if name == request.param: + raise ImportError() + return import_orig(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", mocked_import) + + +@mark_async_test +def test_cypher_async(): + """ + test result format is backward compatible with earlier versions of neomodel + """ + + jim = User2(email="jim1@test.com").save_async() + data, meta = jim.cypher_async( + f"MATCH (a) WHERE {adb.get_id_method()}(a)=$self RETURN a.email" + ) + assert data[0][0] == "jim1@test.com" + assert "a.email" in meta + + data, meta = jim.cypher_async( + f""" + MATCH (a) WHERE {adb.get_id_method()}(a)=$self + MATCH (a)<-[:USER2]-(b) + RETURN a, b, 3 + """ + ) + assert "a" in meta and "b" in meta + + +@mark_async_test +def test_cypher_syntax_error_async(): + jim = User2(email="jim1@test.com").save_async() + try: + jim.cypher_async(f"MATCH a WHERE {adb.get_id_method()}(a)={ self} RETURN xx") + except CypherError as e: + assert hasattr(e, "message") + assert hasattr(e, "code") + else: + assert False, "CypherError not raised." + + +@mark_async_test +@pytest.mark.parametrize("hide_available_pkg", ["pandas"], indirect=True) +def test_pandas_not_installed_async(hide_available_pkg): + with pytest.raises(ImportError): + with pytest.warns( + UserWarning, + match="The neomodel.integration.pandas module expects pandas to be installed", + ): + from neomodel.integration.pandas import to_dataframe + + _ = to_dataframe(adb.cypher_query_async("MATCH (a) RETURN a.name AS name")) + + +@mark_async_test +def test_pandas_integration_async(): + from neomodel.integration.pandas import to_dataframe, to_series + + jimla = UserPandas(email="jimla@test.com", name="jimla").save_async() + jimlo = UserPandas(email="jimlo@test.com", name="jimlo").save_async() + + # Test to_dataframe + df = to_dataframe( + adb.cypher_query_async( + "MATCH (a:UserPandas) RETURN a.name AS name, a.email AS email" + ) + ) + + assert isinstance(df, DataFrame) + assert df.shape == (2, 2) + assert df["name"].tolist() == ["jimla", "jimlo"] + + # Also test passing an index and dtype to to_dataframe + df = to_dataframe( + adb.cypher_query_async( + "MATCH (a:UserPandas) RETURN a.name AS name, a.email AS email" + ), + index=df["email"], + dtype=str, + ) + + assert df.index.inferred_type == "string" + + # Next test to_series + series = to_series( + adb.cypher_query_async("MATCH (a:UserPandas) RETURN a.name AS name") + ) + + assert isinstance(series, Series) + assert series.shape == (2,) + assert df["name"].tolist() == ["jimla", "jimlo"] + + +@mark_async_test +@pytest.mark.parametrize("hide_available_pkg", ["numpy"], indirect=True) +def test_numpy_not_installed_async(hide_available_pkg): + with pytest.raises(ImportError): + with pytest.warns( + UserWarning, + match="The neomodel.integration.numpy module expects pandas to be installed", + ): + from neomodel.integration.numpy import to_ndarray + + _ = to_ndarray(adb.cypher_query_async("MATCH (a) RETURN a.name AS name")) + + +@mark_async_test +def test_numpy_integration_async(): + from neomodel.integration.numpy import to_ndarray + + jimly = UserNP(email="jimly@test.com", name="jimly").save_async() + jimlu = UserNP(email="jimlu@test.com", name="jimlu").save_async() + + array = to_ndarray( + adb.cypher_query_async( + "MATCH (a:UserNP) RETURN a.name AS name, a.email AS email" + ) + ) + + assert isinstance(array, ndarray) + assert array.shape == (2, 2) + assert array[0][0] == "jimly" From 0d92dab3e97461cb8c0f87d3f9e4ee6bd091cb4a Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Mon, 11 Dec 2023 13:16:19 +0100 Subject: [PATCH 04/73] Add auto unasync with pre commit --- .pre-commit-config.yaml | 15 +-- dev-scripts/make-unasync | 275 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 282 insertions(+), 8 deletions(-) create mode 100644 dev-scripts/make-unasync diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 81b274db..5dfa12a5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,11 +7,10 @@ repos: rev: 5.11.5 hooks: - id: isort - # - repo: local - # hooks: - # - id: pylint - # name: pylint - # entry: pylint neomodel/ - # language: system - # always_run: true - # pass_filenames: false \ No newline at end of file + - repo: local + hooks: + - id: unasync + name: unasync + entry: dev-scripts/make-unasync + language: system + files: "^(neomodel/_async|test/async_)/.*" \ No newline at end of file diff --git a/dev-scripts/make-unasync b/dev-scripts/make-unasync new file mode 100644 index 00000000..b51f67fd --- /dev/null +++ b/dev-scripts/make-unasync @@ -0,0 +1,275 @@ +#!/usr/bin/env python + +import collections +import errno +import os +import re +import sys +import tokenize as std_tokenize +from pathlib import Path + +import unasync + +ROOT_DIR = Path(__file__).parents[1].absolute() +ASYNC_DIR = ROOT_DIR / "neomodel" / "_async" +SYNC_DIR = ROOT_DIR / "neomodel" / "_sync" +ASYNC_INTEGRATION_TEST_DIR = ROOT_DIR / "test" / "async_" +SYNC_INTEGRATION_TEST_DIR = ROOT_DIR / "test" / "sync" +UNASYNC_SUFFIX = ".unasync" + +PY_FILE_EXTENSIONS = {".py"} + + +# copy from unasync for customization ----------------------------------------- +# https://github.com/python-trio/unasync +# License: MIT or Apache2 + + +Token = collections.namedtuple("Token", ["type", "string", "start", "end", "line"]) + + +def _makedirs_existok(dir): + try: + os.makedirs(dir) + except OSError as e: + if e.errno != errno.EEXIST: + raise + + +def _get_tokens(f): + if sys.version_info[0] == 2: + for tok in std_tokenize.generate_tokens(f.readline): + type_, string, start, end, line = tok + yield Token(type_, string, start, end, line) + else: + for tok in std_tokenize.tokenize(f.readline): + if tok.type == std_tokenize.ENCODING: + continue + yield tok + + +def _tokenize(f): + last_end = (1, 0) + for tok in _get_tokens(f): + if last_end[0] < tok.start[0]: + yield "", std_tokenize.STRING, " \\\n" + last_end = (tok.start[0], 0) + + space = "" + if tok.start > last_end: + assert tok.start[0] == last_end[0] + space = " " * (tok.start[1] - last_end[1]) + yield space, tok.type, tok.string + + last_end = tok.end + if tok.type in [std_tokenize.NEWLINE, std_tokenize.NL]: + last_end = (tok.end[0] + 1, 0) + + +def _untokenize(tokens): + return "".join(space + tokval for space, tokval in tokens) + + +# end of copy ----------------------------------------------------------------- + + +class CustomRule(unasync.Rule): + def __init__(self, *args, **kwargs): + super(CustomRule, self).__init__(*args, **kwargs) + self.out_files = [] + self.token_replacements = {} + + def _unasync_tokens(self, tokens): + # copy from unasync to fix handling of multiline strings + # https://github.com/python-trio/unasync + # License: MIT or Apache2 + + used_space = None + for space, toknum, tokval in tokens: + if tokval in ["async", "await"]: + # When removing async or await, we want to use the whitespace + # that was before async/await before the next token so that + # `print(await stuff)` becomes `print(stuff)` and not + # `print( stuff)` + used_space = space + else: + if toknum == std_tokenize.NAME: + tokval = self._unasync_name(tokval) + elif toknum == std_tokenize.STRING: + if tokval[0] == tokval[1] and len(tokval) > 2: + # multiline string (`"""..."""` or `'''...'''`) + left_quote, name, right_quote = ( + tokval[:3], + tokval[3:-3], + tokval[-3:], + ) + else: + # simple string (`"..."` or `'...'`) + left_quote, name, right_quote = ( + tokval[:1], + tokval[1:-1], + tokval[-1:], + ) + tokval = left_quote + self._unasync_string(name) + right_quote + if used_space is None: + used_space = space + yield (used_space, tokval) + used_space = None + + def _unasync_string(self, name): + start = 0 + end = 1 + out = "" + while end < len(name): + sub_name = name[start:end] + if sub_name.isidentifier(): + end += 1 + else: + if end == start + 1: + out += sub_name + start += 1 + end += 1 + else: + out += self._unasync_name(name[start : (end - 1)]) + start = end - 1 + + sub_name = name[start:] + if sub_name.isidentifier(): + out += self._unasync_name(name[start:]) + else: + out += sub_name + + # very boiled down unasync version that removes "async" and "await" + # substrings. + out = re.subn( + r"(^|\s+|(?<=\W))(?:async|await)\s+", r"\1", out, flags=re.MULTILINE + )[0] + # Convert doc-reference names from 'async-xyz' to 'xyz' + out = re.subn(r":ref:`async-", ":ref:`", out)[0] + return out + + def _unasync_prefix(self, name): + # Convert class names from 'AsyncXyz' to 'Xyz' + if len(name) > 5 and name.startswith("Async") and name[5].isupper(): + return name[5:] + # Convert class names from '_AsyncXyz' to '_Xyz' + elif len(name) > 6 and name.startswith("_Async") and name[6].isupper(): + return "_" + name[6:] + # Convert variable/method/function names from 'async_xyz' to 'xyz' + elif len(name) > 6 and name.startswith("async_"): + return name[6:] + return name + + def _unasync_name(self, name): + # copy from unasync to customize renaming rules + # https://github.com/python-trio/unasync + # License: MIT or Apache2 + if name in self.token_replacements: + return self.token_replacements[name] + return self._unasync_prefix(name) + + def _unasync_file(self, filepath): + # copy from unasync to append file suffix to out path + # https://github.com/python-trio/unasync + # License: MIT or Apache2 + with open(filepath, "rb") as f: + write_kwargs = {} + if sys.version_info[0] >= 3: + encoding, _ = std_tokenize.detect_encoding(f.readline) + write_kwargs["encoding"] = encoding + f.seek(0) + tokens = _tokenize(f) + tokens = self._unasync_tokens(tokens) + result = _untokenize(tokens) + outfile_path = filepath.replace(self.fromdir, self.todir) + outfile_path += UNASYNC_SUFFIX + self.out_files.append(outfile_path) + _makedirs_existok(os.path.dirname(outfile_path)) + with open(outfile_path, "w", **write_kwargs) as f: + print(result, file=f, end="") + + +def apply_unasync(files): + """Generate sync code from async code.""" + + additional_main_replacements = {"adb": "db"} + additional_test_replacements = { + "_async": "_sync", + "adb": "db", + "mark_async_test": "mark_sync_test", + } + rules = [ + CustomRule( + fromdir=str(ASYNC_DIR), + todir=str(SYNC_DIR), + additional_replacements=additional_main_replacements, + ), + CustomRule( + fromdir=str(ASYNC_INTEGRATION_TEST_DIR), + todir=str(SYNC_INTEGRATION_TEST_DIR), + additional_replacements=additional_test_replacements, + ), + ] + + if not files: + paths = list(ASYNC_DIR.rglob("*")) + paths += list(ASYNC_INTEGRATION_TEST_DIR.rglob("*")) + else: + paths = [ROOT_DIR / Path(f) for f in files] + filtered_paths = [] + for path in paths: + if path.suffix in PY_FILE_EXTENSIONS: + filtered_paths.append(path) + + unasync.unasync_files(map(str, filtered_paths), rules) + + return [Path(path) for rule in rules for path in rule.out_files] + + +def apply_changes(paths): + def files_equal(path1, path2): + with open(path1, "rb") as f1: + with open(path2, "rb") as f2: + data1 = f1.read(1024) + data2 = f2.read(1024) + while data1 or data2: + if data1 != data2: + changed_paths[path1] = path2 + return False + data1 = f1.read(1024) + data2 = f2.read(1024) + return True + + changed_paths = {} + + for in_path in paths: + out_path = Path(str(in_path)[: -len(UNASYNC_SUFFIX)]) + if not out_path.is_file(): + changed_paths[in_path] = out_path + continue + if not files_equal(in_path, out_path): + changed_paths[in_path] = out_path + continue + in_path.unlink() + + for in_path, out_path in changed_paths.items(): + in_path.replace(out_path) + + return list(changed_paths.values()) + + +def main(): + files = None + if len(sys.argv) >= 1: + files = sys.argv[1:] + paths = apply_unasync(files) + changed_paths = apply_changes(paths) + + if changed_paths: + for path in changed_paths: + print("Updated " + str(path)) + exit(1) + + +if __name__ == "__main__": + main() From 8039635d9ac5e3a242e85a3c41538d0ecf408791 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Mon, 11 Dec 2023 14:24:28 +0100 Subject: [PATCH 05/73] Remove async suffix for methods in async classes --- .gitignore | 1 - .pre-commit-config.yaml | 6 +- {dev-scripts => bin}/make-unasync | 2 +- neomodel/__init__.py | 14 +- neomodel/_async/core.py | 208 +++++++++--------- neomodel/_sync/core.py | 202 ++++++++--------- neomodel/match.py | 6 +- neomodel/relationship.py | 6 +- neomodel/scripts/neomodel_inspect_database.py | 24 +- neomodel/scripts/neomodel_install_labels.py | 4 +- neomodel/scripts/neomodel_remove_labels.py | 4 +- test/async_/conftest.py | 12 +- test/async_/test_cypher.py | 38 ++-- test/sync/conftest.py | 12 +- test/sync/test_cypher.py | 32 +-- test/test_alias.py | 4 +- test/test_batch.py | 28 +-- test/test_cardinality.py | 24 +- test/test_connection.py | 22 +- test/test_contrib/test_semi_structured.py | 6 +- test/test_contrib/test_spatial_properties.py | 6 +- test/test_database_management.py | 28 +-- test/test_driver_options.py | 12 +- test/test_hooks.py | 4 +- test/test_indexing.py | 20 +- test/test_issue112.py | 4 +- test/test_issue283.py | 142 +++++------- test/test_issue600.py | 24 +- test/test_label_drop.py | 14 +- test/test_label_install.py | 52 ++--- test/test_match_api.py | 100 ++++----- test/test_models.py | 86 ++++---- test/test_multiprocessing.py | 2 +- test/test_paths.py | 32 +-- test/test_properties.py | 52 ++--- test/test_relationship_models.py | 26 +-- test/test_relationships.py | 48 ++-- test/test_relative_relationships.py | 4 +- test/test_scripts.py | 18 +- test/test_transactions.py | 34 +-- 40 files changed, 654 insertions(+), 709 deletions(-) rename {dev-scripts => bin}/make-unasync (99%) mode change 100644 => 100755 diff --git a/.gitignore b/.gitignore index a79c277d..562969f3 100644 --- a/.gitignore +++ b/.gitignore @@ -15,7 +15,6 @@ development.env .ropeproject \#*\# .eggs -bin lib .vscode pyvenv.cfg diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5dfa12a5..4601807f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,8 +1,4 @@ repos: - - repo: https://github.com/psf/black - rev: 22.3.0 - hooks: - - id: black - repo: https://github.com/PyCQA/isort rev: 5.11.5 hooks: @@ -11,6 +7,6 @@ repos: hooks: - id: unasync name: unasync - entry: dev-scripts/make-unasync + entry: bin/make-unasync language: system files: "^(neomodel/_async|test/async_)/.*" \ No newline at end of file diff --git a/dev-scripts/make-unasync b/bin/make-unasync old mode 100644 new mode 100755 similarity index 99% rename from dev-scripts/make-unasync rename to bin/make-unasync index b51f67fd..9be3b8bc --- a/dev-scripts/make-unasync +++ b/bin/make-unasync @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 import collections import errno diff --git a/neomodel/__init__.py b/neomodel/__init__.py index 013104f6..5cf8ccdc 100644 --- a/neomodel/__init__.py +++ b/neomodel/__init__.py @@ -2,13 +2,13 @@ # TODO : Check imports here from neomodel._async.core import ( StructuredNodeAsync, - change_neo4j_password_async, - clear_neo4j_database_async, - drop_constraints_async, - drop_indexes_async, - install_all_labels_async, - install_labels_async, - remove_all_labels_async, + change_neo4j_password, + clear_neo4j_database, + drop_constraints, + drop_indexes, + install_all_labels, + install_labels, + remove_all_labels, ) from neomodel.cardinality import One, OneOrMore, ZeroOrMore, ZeroOrOne from neomodel.exceptions import * diff --git a/neomodel/_async/core.py b/neomodel/_async/core.py index 606125b6..7225816f 100644 --- a/neomodel/_async/core.py +++ b/neomodel/_async/core.py @@ -60,9 +60,9 @@ async def wrapper(self, *args, **kwargs): if not _db.driver: if hasattr(config, "DRIVER") and config.DRIVER: - await _db.set_connection_async(driver=config.DRIVER) + await _db.set_connection(driver=config.DRIVER) elif config.DATABASE_URL: - await _db.set_connection_async(url=config.DATABASE_URL) + await _db.set_connection(url=config.DATABASE_URL) return await func(self, *args, **kwargs) @@ -88,7 +88,7 @@ def __init__(self): self._database_edition = None self.impersonated_user = None - async def set_connection_async(self, url: str = None, driver: AsyncDriver = None): + async def set_connection(self, url: str = None, driver: AsyncDriver = None): """ Sets the connection up and relevant internal. This can be done using a Neo4j URL or a driver instance. @@ -115,7 +115,7 @@ async def set_connection_async(self, url: str = None, driver: AsyncDriver = None # Getting the information about the database version requires a connection to the database self._database_version = None self._database_edition = None - await self._update_database_version_async() + await self._update_database_version() def _parse_driver_from_url(self, url: str) -> None: """Parse the driver information from the given URL and initialize the driver. @@ -182,7 +182,7 @@ def _parse_driver_from_url(self, url: str) -> None: else: self._database_name = database_name - async def close_connection_async(self): + async def close_connection(self): """ Closes the currently open driver. The driver should always be closed at the end of the application's lifecyle. @@ -196,14 +196,14 @@ async def close_connection_async(self): @property def database_version(self): if self._database_version is None: - self._update_database_version_async() + self._update_database_version() return self._database_version @property def database_edition(self): if self._database_edition is None: - self._update_database_version_async() + self._update_database_version() return self._database_edition @@ -238,7 +238,7 @@ def impersonate(self, user: str) -> "ImpersonationHandler": return ImpersonationHandler(self, impersonated_user=user) @ensure_connection - async def begin_async(self, access_mode=None, **parameters): + async def begin(self, access_mode=None, **parameters): """ Begins a new transaction. Raises SystemError if a transaction is already active. """ @@ -258,7 +258,7 @@ async def begin_async(self, access_mode=None, **parameters): ) @ensure_connection - async def commit_async(self): + async def commit(self): """ Commits the current transaction and closes its session @@ -279,7 +279,7 @@ async def commit_async(self): return last_bookmarks @ensure_connection - async def rollback_async(self): + async def rollback(self): """ Rolls back the current transaction and closes its session """ @@ -293,12 +293,12 @@ async def rollback_async(self): self._active_transaction = None self._session = None - async def _update_database_version_async(self): + async def _update_database_version(self): """ Updates the database server information when it is required """ try: - results = await self.cypher_query_async( + results = await self.cypher_query( "CALL dbms.components() yield versions, edition return versions[0], edition" ) self._database_version = results[0][0][0] @@ -399,7 +399,7 @@ def _result_resolution(self, result_list): return result_list @ensure_connection - async def cypher_query_async( + async def cypher_query( self, query, params=None, @@ -427,7 +427,7 @@ async def cypher_query_async( if self._active_transaction: # Use current session is a transaction is currently active - results, meta = await self._run_cypher_query_async( + results, meta = await self._run_cypher_query( self._active_transaction, query, params, @@ -440,7 +440,7 @@ async def cypher_query_async( async with self.driver.session( database=self._database_name, impersonated_user=self.impersonated_user ) as session: - results, meta = await self._run_cypher_query_async( + results, meta = await self._run_cypher_query( session, query, params, @@ -451,7 +451,7 @@ async def cypher_query_async( return results, meta - async def _run_cypher_query_async( + async def _run_cypher_query( self, session: AsyncSession, query, @@ -481,8 +481,8 @@ async def _run_cypher_query_async( raise exc_info[1].with_traceback(exc_info[2]) except SessionExpired: if retry_on_session_expire: - await self.set_connection_async(url=self.url) - return await self.cypher_query_async( + await self.set_connection(url=self.url) + return await self.cypher_query( query=query, params=params, handle_unique=handle_unique, @@ -510,7 +510,7 @@ def get_id_method(self) -> str: else: return "elementId" - async def list_indexes_async(self, exclude_token_lookup=False) -> Sequence[dict]: + async def list_indexes(self, exclude_token_lookup=False) -> Sequence[dict]: """Returns all indexes existing in the database Arguments: @@ -519,7 +519,7 @@ async def list_indexes_async(self, exclude_token_lookup=False) -> Sequence[dict] Returns: Sequence[dict]: List of dictionaries, each entry being an index definition """ - indexes, meta_indexes = await self.cypher_query_async("SHOW INDEXES") + indexes, meta_indexes = await self.cypher_query("SHOW INDEXES") indexes_as_dict = [dict(zip(meta_indexes, row)) for row in indexes] if exclude_token_lookup: @@ -529,15 +529,13 @@ async def list_indexes_async(self, exclude_token_lookup=False) -> Sequence[dict] return indexes_as_dict - async def list_constraints_async(self) -> Sequence[dict]: + async def list_constraints(self) -> Sequence[dict]: """Returns all constraints existing in the database Returns: Sequence[dict]: List of dictionaries, each entry being a constraint definition """ - constraints, meta_constraints = await self.cypher_query_async( - "SHOW CONSTRAINTS" - ) + constraints, meta_constraints = await self.cypher_query("SHOW CONSTRAINTS") constraints_as_dict = [dict(zip(meta_constraints, row)) for row in constraints] return constraints_as_dict @@ -563,15 +561,11 @@ def edition_is_enterprise(self) -> bool: """ return self.database_edition == "enterprise" - async def change_neo4j_password_async(self, user, new_password): - await self.cypher_query_async( - f"ALTER USER {user} SET PASSWORD '{new_password}'" - ) + async def change_neo4j_password(self, user, new_password): + await self.cypher_query(f"ALTER USER {user} SET PASSWORD '{new_password}'") - async def clear_neo4j_database_async( - self, clear_constraints=False, clear_indexes=False - ): - await self.cypher_query_async( + async def clear_neo4j_database(self, clear_constraints=False, clear_indexes=False): + await self.cypher_query( """ MATCH (a) CALL { WITH a DETACH DELETE a } @@ -579,11 +573,11 @@ async def clear_neo4j_database_async( """ ) if clear_constraints: - await drop_constraints_async() + await drop_constraints() if clear_indexes: - await drop_indexes_async() + await drop_indexes() - async def drop_constraints_async(self, quiet=True, stdout=None): + async def drop_constraints(self, quiet=True, stdout=None): """ Discover and drop all constraints. @@ -593,11 +587,11 @@ async def drop_constraints_async(self, quiet=True, stdout=None): if not stdout or stdout is None: stdout = sys.stdout - results, meta = await self.cypher_query_async("SHOW CONSTRAINTS") + results, meta = await self.cypher_query("SHOW CONSTRAINTS") results_as_dict = [dict(zip(meta, row)) for row in results] for constraint in results_as_dict: - await self.cypher_query_async("DROP CONSTRAINT " + constraint["name"]) + await self.cypher_query("DROP CONSTRAINT " + constraint["name"]) if not quiet: stdout.write( ( @@ -609,7 +603,7 @@ async def drop_constraints_async(self, quiet=True, stdout=None): if not quiet: stdout.write("\n") - async def drop_indexes_async(self, quiet=True, stdout=None): + async def drop_indexes(self, quiet=True, stdout=None): """ Discover and drop all indexes, except the automatically created token lookup indexes. @@ -619,9 +613,9 @@ async def drop_indexes_async(self, quiet=True, stdout=None): if not stdout or stdout is None: stdout = sys.stdout - indexes = await self.list_indexes_async(exclude_token_lookup=True) + indexes = await self.list_indexes(exclude_token_lookup=True) for index in indexes: - await self.cypher_query_async("DROP INDEX " + index["name"]) + await self.cypher_query("DROP INDEX " + index["name"]) if not quiet: stdout.write( f' - Dropping index on labels {",".join(index["labelsOrTypes"])} with properties {",".join(index["properties"])}.\n' @@ -629,7 +623,7 @@ async def drop_indexes_async(self, quiet=True, stdout=None): if not quiet: stdout.write("\n") - async def remove_all_labels_async(self, stdout=None): + async def remove_all_labels(self, stdout=None): """ Calls functions for dropping constraints and indexes. @@ -641,12 +635,12 @@ async def remove_all_labels_async(self, stdout=None): stdout = sys.stdout stdout.write("Dropping constraints...\n") - await self.drop_constraints_async(quiet=False, stdout=stdout) + await self.drop_constraints(quiet=False, stdout=stdout) stdout.write("Dropping indexes...\n") - await self.drop_indexes_async(quiet=False, stdout=stdout) + await self.drop_indexes(quiet=False, stdout=stdout) - async def install_all_labels_async(self, stdout=None): + async def install_all_labels(self, stdout=None): """ Discover all subclasses of StructuredNode in your application and execute install_labels on each. Note: code must be loaded (imported) in order for a class to be discovered. @@ -669,7 +663,7 @@ def subsub(cls): # recursively return all subclasses i = 0 for cls in subsub(StructuredNodeAsync): stdout.write(f"Found {cls.__module__}.{cls.__name__}\n") - await install_labels_async(cls, quiet=False, stdout=stdout) + await install_labels(cls, quiet=False, stdout=stdout) i += 1 if i: @@ -677,7 +671,7 @@ def subsub(cls): # recursively return all subclasses stdout.write(f"Finished {i} classes.\n") - async def install_labels_async(self, cls, quiet=True, stdout=None): + async def install_labels(self, cls, quiet=True, stdout=None): """ Setup labels with indexes and constraints for a given class @@ -699,16 +693,16 @@ async def install_labels_async(self, cls, quiet=True, stdout=None): return for name, property in cls.defined_properties(aliases=False, rels=False).items(): - await self._install_node_async(cls, name, property, quiet, stdout) + await self._install_node(cls, name, property, quiet, stdout) for _, relationship in cls.defined_properties( aliases=False, rels=True, properties=False ).items(): - await self._install_relationship_async(cls, relationship, quiet, stdout) + await self._install_relationship(cls, relationship, quiet, stdout) - async def _create_node_index_async(self, label: str, property_name: str, stdout): + async def _create_node_index(self, label: str, property_name: str, stdout): try: - await self.cypher_query_async( + await self.cypher_query( f"CREATE INDEX index_{label}_{property_name} FOR (n:{label}) ON (n.{property_name}); " ) except ClientError as e: @@ -720,11 +714,9 @@ async def _create_node_index_async(self, label: str, property_name: str, stdout) else: raise - async def _create_node_constraint_async( - self, label: str, property_name: str, stdout - ): + async def _create_node_constraint(self, label: str, property_name: str, stdout): try: - await self.cypher_query_async( + await self.cypher_query( f"""CREATE CONSTRAINT constraint_unique_{label}_{property_name} FOR (n:{label}) REQUIRE n.{property_name} IS UNIQUE""" ) @@ -737,11 +729,11 @@ async def _create_node_constraint_async( else: raise - async def _create_relationship_index_async( + async def _create_relationship_index( self, relationship_type: str, property_name: str, stdout ): try: - await self.cypher_query_async( + await self.cypher_query( f"CREATE INDEX index_{relationship_type}_{property_name} FOR ()-[r:{relationship_type}]-() ON (r.{property_name}); " ) except ClientError as e: @@ -753,12 +745,12 @@ async def _create_relationship_index_async( else: raise - async def _create_relationship_constraint_async( + async def _create_relationship_constraint( self, relationship_type: str, property_name: str, stdout ): if self.version_is_higher_than("5.7"): try: - await self.cypher_query_async( + await self.cypher_query( f"""CREATE CONSTRAINT constraint_unique_{relationship_type}_{property_name} FOR ()-[r:{relationship_type}]-() REQUIRE r.{property_name} IS UNIQUE""" ) @@ -775,7 +767,7 @@ async def _create_relationship_constraint_async( f"Unique indexes on relationships are not supported in Neo4j version {self.database_version}. Please upgrade to Neo4j 5.7 or higher." ) - async def _install_node_async(self, cls, name, property, quiet, stdout): + async def _install_node(self, cls, name, property, quiet, stdout): # Create indexes and constraints for node property db_property = property.db_property or name if property.index: @@ -783,7 +775,7 @@ async def _install_node_async(self, cls, name, property, quiet, stdout): stdout.write( f" + Creating node index {name} on label {cls.__label__} for class {cls.__module__}.{cls.__name__}\n" ) - await self._create_node_index_async( + await self._create_node_index( label=cls.__label__, property_name=db_property, stdout=stdout ) @@ -792,11 +784,11 @@ async def _install_node_async(self, cls, name, property, quiet, stdout): stdout.write( f" + Creating node unique constraint for {name} on label {cls.__label__} for class {cls.__module__}.{cls.__name__}\n" ) - await self._create_node_constraint_async( + await self._create_node_constraint( label=cls.__label__, property_name=db_property, stdout=stdout ) - async def _install_relationship_async(self, cls, relationship, quiet, stdout): + async def _install_relationship(self, cls, relationship, quiet, stdout): # Create indexes and constraints for relationship property relationship_cls = relationship.definition["model"] if relationship_cls is not None: @@ -810,7 +802,7 @@ async def _install_relationship_async(self, cls, relationship, quiet, stdout): stdout.write( f" + Creating relationship index {prop_name} on relationship type {relationship_type} for relationship model {cls.__module__}.{relationship_cls.__name__}\n" ) - await self._create_relationship_index_async( + await self._create_relationship_index( relationship_type=relationship_type, property_name=db_property, stdout=stdout, @@ -820,7 +812,7 @@ async def _install_relationship_async(self, cls, relationship, quiet, stdout): stdout.write( f" + Creating relationship unique constraint for {prop_name} on relationship type {relationship_type} for relationship model {cls.__module__}.{relationship_cls.__name__}\n" ) - await self._create_relationship_constraint_async( + await self._create_relationship_constraint( relationship_type=relationship_type, property_name=db_property, stdout=stdout, @@ -832,81 +824,83 @@ async def _install_relationship_async(self, cls, relationship, quiet, stdout): # Deprecated methods -async def change_neo4j_password_async(db, user, new_password): +async def change_neo4j_password(db: AsyncDatabase, user, new_password): deprecated( """ This method has been moved to the Database singleton (db for sync, adb for async). - Please use adb.change_neo4j_password_async(user, new_password) instead. + Please use adb.change_neo4j_password(user, new_password) instead. This direct call will be removed in an upcoming version. """ ) - await db.change_neo4j_password_async(user, new_password) + await db.change_neo4j_password(user, new_password) -async def clear_neo4j_database_async(db, clear_constraints=False, clear_indexes=False): +async def clear_neo4j_database( + db: AsyncDatabase, clear_constraints=False, clear_indexes=False +): deprecated( """ This method has been moved to the Database singleton (db for sync, adb for async). - Please use adb.clear_neo4j_database_async(clear_constraints, clear_indexes) instead. + Please use adb.clear_neo4j_database(clear_constraints, clear_indexes) instead. This direct call will be removed in an upcoming version. """ ) - await db.clear_neo4j_database_async(clear_constraints, clear_indexes) + await db.clear_neo4j_database(clear_constraints, clear_indexes) -async def drop_constraints_async(quiet=True, stdout=None): +async def drop_constraints(quiet=True, stdout=None): deprecated( """ This method has been moved to the Database singleton (db for sync, adb for async). - Please use adb.drop_constraints_async(quiet, stdout) instead. + Please use adb.drop_constraints(quiet, stdout) instead. This direct call will be removed in an upcoming version. """ ) - await adb.drop_constraints_async(quiet, stdout) + await adb.drop_constraints(quiet, stdout) -async def drop_indexes_async(quiet=True, stdout=None): +async def drop_indexes(quiet=True, stdout=None): deprecated( """ This method has been moved to the Database singleton (db for sync, adb for async). - Please use adb.drop_indexes_async(quiet, stdout) instead. + Please use adb.drop_indexes(quiet, stdout) instead. This direct call will be removed in an upcoming version. """ ) - await adb.drop_indexes_async(quiet, stdout) + await adb.drop_indexes(quiet, stdout) -async def remove_all_labels_async(stdout=None): +async def remove_all_labels(stdout=None): deprecated( """ This method has been moved to the Database singleton (db for sync, adb for async). - Please use adb.remove_all_labels_async(stdout) instead. + Please use adb.remove_all_labels(stdout) instead. This direct call will be removed in an upcoming version. """ ) - await adb.remove_all_labels_async(stdout) + await adb.remove_all_labels(stdout) -async def install_labels_async(cls, quiet=True, stdout=None): +async def install_labels(cls, quiet=True, stdout=None): deprecated( """ This method has been moved to the Database singleton (db for sync, adb for async). - Please use adb.install_labels_async(cls, quiet, stdout) instead. + Please use adb.install_labels(cls, quiet, stdout) instead. This direct call will be removed in an upcoming version. """ ) - await adb.install_labels_async(cls, quiet, stdout) + await adb.install_labels(cls, quiet, stdout) -async def install_all_labels_async(stdout=None): +async def install_all_labels(stdout=None): deprecated( """ This method has been moved to the Database singleton (db for sync, adb for async). - Please use adb.install_all_labels_async(stdout) instead. + Please use adb.install_all_labels(stdout) instead. This direct call will be removed in an upcoming version. """ ) - await adb.install_all_labels_async(stdout) + await adb.install_all_labels(stdout) class TransactionProxyAsync: @@ -918,15 +912,13 @@ def __init__(self, db: AsyncDatabase, access_mode=None): @ensure_connection async def __enter__(self): - await self.db.begin_async( - access_mode=self.access_mode, bookmarks=self.bookmarks - ) + await self.db.begin(access_mode=self.access_mode, bookmarks=self.bookmarks) self.bookmarks = None return self async def __exit__(self, exc_type, exc_value, traceback): if exc_value: - await self.db.rollback_async() + await self.db.rollback() if ( exc_type is ClientError @@ -935,7 +927,7 @@ async def __exit__(self, exc_type, exc_value, traceback): raise UniqueProperty(exc_value.message) if not exc_value: - self.last_bookmark = await self.db.commit_async() + self.last_bookmark = await self.db.commit() def __call__(self, func): def wrapper(*args, **kwargs): @@ -1046,7 +1038,7 @@ def __new__(mcs, name, bases, namespace): # TODO : See previous TODO comment # if config.AUTO_INSTALL_LABELS: - # await install_labels_async(cls, quiet=False) + # await install_labels(cls, quiet=False) build_class_registry(cls) @@ -1214,7 +1206,7 @@ def _build_merge_query( return query, query_params @classmethod - async def create_async(cls, *props, **kwargs): + async def create(cls, *props, **kwargs): """ Call to CREATE with parameters map. A new instance will be created and saved. @@ -1246,7 +1238,7 @@ async def create_async(cls, *props, **kwargs): for item in [ cls.deflate(p, obj=_UnsavedNode(), skip_empty=True) for p in props ]: - node, _ = await adb.cypher_query_async(query, {"create_params": item}) + node, _ = await adb.cypher_query(query, {"create_params": item}) results.extend(node[0]) nodes = [cls.inflate(node) for node in results] @@ -1258,7 +1250,7 @@ async def create_async(cls, *props, **kwargs): return nodes @classmethod - async def create_or_update_async(cls, *props, **kwargs): + async def create_or_update(cls, *props, **kwargs): """ Call to MERGE with parameters map. A new instance will be created and saved if does not already exists, this is an atomic operation. If an instance already exists all optional properties specified will be updated. @@ -1302,10 +1294,10 @@ async def create_or_update_async(cls, *props, **kwargs): ) # fetch and build instance for each result - results = await adb.cypher_query_async(query, params) + results = await adb.cypher_query(query, params) return [cls.inflate(r[0]) async for r in results[0]] - async def cypher_async(self, query, params=None): + async def cypher(self, query, params=None): """ Execute a cypher query with the param 'self' pre-populated with the nodes neo4j id. @@ -1319,17 +1311,17 @@ async def cypher_async(self, query, params=None): self._pre_action_check("cypher") params = params or {} params.update({"self": self.element_id}) - return await adb.cypher_query_async(query, params) + return await adb.cypher_query(query, params) @hooks - async def delete_async(self): + async def delete(self): """ Delete a node and its relationships :return: True """ self._pre_action_check("delete") - await self.cypher_async( + await self.cypher( f"MATCH (self) WHERE {adb.get_id_method()}(self)=$self DETACH DELETE self" ) delattr(self, "element_id_property") @@ -1337,7 +1329,7 @@ async def delete_async(self): return True @classmethod - async def get_or_create_async(cls, *props, **kwargs): + async def get_or_create(cls, *props, **kwargs): """ Call to MERGE with parameters map. A new instance will be created and saved if does not already exist, this is an atomic operation. @@ -1371,7 +1363,7 @@ async def get_or_create_async(cls, *props, **kwargs): ) # fetch and build instance for each result - results = await adb.cypher_query_async(query, params) + results = await adb.cypher_query(query, params) return [cls.inflate(r[0]) async for r in results[0]] @classmethod @@ -1432,7 +1424,7 @@ def inherited_optional_labels(cls): if not hasattr(scls, "__abstract_node__") ] - async def labels_async(self): + async def labels(self): """ Returns list of labels tied to the node from neo4j. @@ -1440,7 +1432,7 @@ async def labels_async(self): :rtype: list """ self._pre_action_check("labels") - return await self.cypher_async( + return await self.cypher( f"MATCH (n) WHERE {adb.get_id_method()}(n)=$self " "RETURN labels(n)" )[0][0][0] @@ -1454,13 +1446,13 @@ def _pre_action_check(self, action): f"{self.__class__.__name__}.{action}() attempted on unsaved node" ) - async def refresh_async(self): + async def refresh(self): """ Reload the node from neo4j """ self._pre_action_check("refresh") if hasattr(self, "element_id"): - request = await self.cypher_async( + request = await self.cypher( f"MATCH (n) WHERE {adb.get_id_method()}(n)=$self RETURN n" )[0] if not request or not request[0]: @@ -1472,7 +1464,7 @@ async def refresh_async(self): raise ValueError("Can't refresh unsaved node") @hooks - async def save_async(self): + async def save(self): """ Save the node to neo4j or raise an exception @@ -1493,13 +1485,13 @@ async def save_async(self): query += "\n".join( [f"SET n:`{label}`" for label in self.inherited_labels()] ) - await self.cypher_async(query, params) + await self.cypher(query, params) elif hasattr(self, "deleted") and self.deleted: raise ValueError( f"{self.__class__.__name__}.save() attempted on deleted node" ) else: # create - result = await self.create_async(self.__properties__) + result = await self.create(self.__properties__) created_node = result[0] self.element_id_property = created_node.element_id return self diff --git a/neomodel/_sync/core.py b/neomodel/_sync/core.py index e4a9990b..7b408705 100644 --- a/neomodel/_sync/core.py +++ b/neomodel/_sync/core.py @@ -60,9 +60,9 @@ def wrapper(self, *args, **kwargs): if not _db.driver: if hasattr(config, "DRIVER") and config.DRIVER: - _db.set_connection_async(driver=config.DRIVER) + _db.set_connection(driver=config.DRIVER) elif config.DATABASE_URL: - _db.set_connection_async(url=config.DATABASE_URL) + _db.set_connection(url=config.DATABASE_URL) return func(self, *args, **kwargs) @@ -88,7 +88,7 @@ def __init__(self): self._database_edition = None self.impersonated_user = None - def set_connection_async(self, url: str = None, driver: Driver = None): + def set_connection(self, url: str = None, driver: Driver = None): """ Sets the connection up and relevant internal. This can be done using a Neo4j URL or a driver instance. @@ -115,7 +115,7 @@ def set_connection_async(self, url: str = None, driver: Driver = None): # Getting the information about the database version requires a connection to the database self._database_version = None self._database_edition = None - self._update_database_version_async() + self._update_database_version() def _parse_driver_from_url(self, url: str) -> None: """Parse the driver information from the given URL and initialize the driver. @@ -182,7 +182,7 @@ def _parse_driver_from_url(self, url: str) -> None: else: self._database_name = database_name - def close_connection_async(self): + def close_connection(self): """ Closes the currently open driver. The driver should always be closed at the end of the application's lifecyle. @@ -196,14 +196,14 @@ def close_connection_async(self): @property def database_version(self): if self._database_version is None: - self._update_database_version_async() + self._update_database_version() return self._database_version @property def database_edition(self): if self._database_edition is None: - self._update_database_version_async() + self._update_database_version() return self._database_edition @@ -238,7 +238,7 @@ def impersonate(self, user: str) -> "ImpersonationHandler": return ImpersonationHandler(self, impersonated_user=user) @ensure_connection - def begin_async(self, access_mode=None, **parameters): + def begin(self, access_mode=None, **parameters): """ Begins a new transaction. Raises SystemError if a transaction is already active. """ @@ -253,10 +253,12 @@ def begin_async(self, access_mode=None, **parameters): impersonated_user=self.impersonated_user, **parameters, ) - self._active_transaction: Transaction = self._session.begin_transaction() + self._active_transaction: Transaction = ( + self._session.begin_transaction() + ) @ensure_connection - def commit_async(self): + def commit(self): """ Commits the current transaction and closes its session @@ -277,7 +279,7 @@ def commit_async(self): return last_bookmarks @ensure_connection - def rollback_async(self): + def rollback(self): """ Rolls back the current transaction and closes its session """ @@ -291,12 +293,12 @@ def rollback_async(self): self._active_transaction = None self._session = None - def _update_database_version_async(self): + def _update_database_version(self): """ Updates the database server information when it is required """ try: - results = self.cypher_query_async( + results = self.cypher_query( "CALL dbms.components() yield versions, edition return versions[0], edition" ) self._database_version = results[0][0][0] @@ -397,7 +399,7 @@ def _result_resolution(self, result_list): return result_list @ensure_connection - def cypher_query_async( + def cypher_query( self, query, params=None, @@ -425,7 +427,7 @@ def cypher_query_async( if self._active_transaction: # Use current session is a transaction is currently active - results, meta = self._run_cypher_query_async( + results, meta = self._run_cypher_query( self._active_transaction, query, params, @@ -438,7 +440,7 @@ def cypher_query_async( with self.driver.session( database=self._database_name, impersonated_user=self.impersonated_user ) as session: - results, meta = self._run_cypher_query_async( + results, meta = self._run_cypher_query( session, query, params, @@ -449,7 +451,7 @@ def cypher_query_async( return results, meta - def _run_cypher_query_async( + def _run_cypher_query( self, session: Session, query, @@ -479,8 +481,8 @@ def _run_cypher_query_async( raise exc_info[1].with_traceback(exc_info[2]) except SessionExpired: if retry_on_session_expire: - self.set_connection_async(url=self.url) - return self.cypher_query_async( + self.set_connection(url=self.url) + return self.cypher_query( query=query, params=params, handle_unique=handle_unique, @@ -508,7 +510,7 @@ def get_id_method(self) -> str: else: return "elementId" - def list_indexes_async(self, exclude_token_lookup=False) -> Sequence[dict]: + def list_indexes(self, exclude_token_lookup=False) -> Sequence[dict]: """Returns all indexes existing in the database Arguments: @@ -517,7 +519,7 @@ def list_indexes_async(self, exclude_token_lookup=False) -> Sequence[dict]: Returns: Sequence[dict]: List of dictionaries, each entry being an index definition """ - indexes, meta_indexes = self.cypher_query_async("SHOW INDEXES") + indexes, meta_indexes = self.cypher_query("SHOW INDEXES") indexes_as_dict = [dict(zip(meta_indexes, row)) for row in indexes] if exclude_token_lookup: @@ -527,13 +529,13 @@ def list_indexes_async(self, exclude_token_lookup=False) -> Sequence[dict]: return indexes_as_dict - def list_constraints_async(self) -> Sequence[dict]: + def list_constraints(self) -> Sequence[dict]: """Returns all constraints existing in the database Returns: Sequence[dict]: List of dictionaries, each entry being a constraint definition """ - constraints, meta_constraints = self.cypher_query_async("SHOW CONSTRAINTS") + constraints, meta_constraints = self.cypher_query("SHOW CONSTRAINTS") constraints_as_dict = [dict(zip(meta_constraints, row)) for row in constraints] return constraints_as_dict @@ -559,11 +561,11 @@ def edition_is_enterprise(self) -> bool: """ return self.database_edition == "enterprise" - def change_neo4j_password_async(self, user, new_password): - self.cypher_query_async(f"ALTER USER {user} SET PASSWORD '{new_password}'") + def change_neo4j_password(self, user, new_password): + self.cypher_query(f"ALTER USER {user} SET PASSWORD '{new_password}'") - def clear_neo4j_database_async(self, clear_constraints=False, clear_indexes=False): - self.cypher_query_async( + def clear_neo4j_database(self, clear_constraints=False, clear_indexes=False): + self.cypher_query( """ MATCH (a) CALL { WITH a DETACH DELETE a } @@ -571,11 +573,11 @@ def clear_neo4j_database_async(self, clear_constraints=False, clear_indexes=Fals """ ) if clear_constraints: - drop_constraints_async() + drop_constraints() if clear_indexes: - drop_indexes_async() + drop_indexes() - def drop_constraints_async(self, quiet=True, stdout=None): + def drop_constraints(self, quiet=True, stdout=None): """ Discover and drop all constraints. @@ -585,11 +587,11 @@ def drop_constraints_async(self, quiet=True, stdout=None): if not stdout or stdout is None: stdout = sys.stdout - results, meta = self.cypher_query_async("SHOW CONSTRAINTS") + results, meta = self.cypher_query("SHOW CONSTRAINTS") results_as_dict = [dict(zip(meta, row)) for row in results] for constraint in results_as_dict: - self.cypher_query_async("DROP CONSTRAINT " + constraint["name"]) + self.cypher_query("DROP CONSTRAINT " + constraint["name"]) if not quiet: stdout.write( ( @@ -601,7 +603,7 @@ def drop_constraints_async(self, quiet=True, stdout=None): if not quiet: stdout.write("\n") - def drop_indexes_async(self, quiet=True, stdout=None): + def drop_indexes(self, quiet=True, stdout=None): """ Discover and drop all indexes, except the automatically created token lookup indexes. @@ -611,9 +613,9 @@ def drop_indexes_async(self, quiet=True, stdout=None): if not stdout or stdout is None: stdout = sys.stdout - indexes = self.list_indexes_async(exclude_token_lookup=True) + indexes = self.list_indexes(exclude_token_lookup=True) for index in indexes: - self.cypher_query_async("DROP INDEX " + index["name"]) + self.cypher_query("DROP INDEX " + index["name"]) if not quiet: stdout.write( f' - Dropping index on labels {",".join(index["labelsOrTypes"])} with properties {",".join(index["properties"])}.\n' @@ -621,7 +623,7 @@ def drop_indexes_async(self, quiet=True, stdout=None): if not quiet: stdout.write("\n") - def remove_all_labels_async(self, stdout=None): + def remove_all_labels(self, stdout=None): """ Calls functions for dropping constraints and indexes. @@ -633,12 +635,12 @@ def remove_all_labels_async(self, stdout=None): stdout = sys.stdout stdout.write("Dropping constraints...\n") - self.drop_constraints_async(quiet=False, stdout=stdout) + self.drop_constraints(quiet=False, stdout=stdout) stdout.write("Dropping indexes...\n") - self.drop_indexes_async(quiet=False, stdout=stdout) + self.drop_indexes(quiet=False, stdout=stdout) - def install_all_labels_async(self, stdout=None): + def install_all_labels(self, stdout=None): """ Discover all subclasses of StructuredNode in your application and execute install_labels on each. Note: code must be loaded (imported) in order for a class to be discovered. @@ -661,7 +663,7 @@ def subsub(cls): # recursively return all subclasses i = 0 for cls in subsub(StructuredNodeAsync): stdout.write(f"Found {cls.__module__}.{cls.__name__}\n") - install_labels_async(cls, quiet=False, stdout=stdout) + install_labels(cls, quiet=False, stdout=stdout) i += 1 if i: @@ -669,7 +671,7 @@ def subsub(cls): # recursively return all subclasses stdout.write(f"Finished {i} classes.\n") - def install_labels_async(self, cls, quiet=True, stdout=None): + def install_labels(self, cls, quiet=True, stdout=None): """ Setup labels with indexes and constraints for a given class @@ -691,16 +693,16 @@ def install_labels_async(self, cls, quiet=True, stdout=None): return for name, property in cls.defined_properties(aliases=False, rels=False).items(): - self._install_node_async(cls, name, property, quiet, stdout) + self._install_node(cls, name, property, quiet, stdout) for _, relationship in cls.defined_properties( aliases=False, rels=True, properties=False ).items(): - self._install_relationship_async(cls, relationship, quiet, stdout) + self._install_relationship(cls, relationship, quiet, stdout) - def _create_node_index_async(self, label: str, property_name: str, stdout): + def _create_node_index(self, label: str, property_name: str, stdout): try: - self.cypher_query_async( + self.cypher_query( f"CREATE INDEX index_{label}_{property_name} FOR (n:{label}) ON (n.{property_name}); " ) except ClientError as e: @@ -712,9 +714,9 @@ def _create_node_index_async(self, label: str, property_name: str, stdout): else: raise - def _create_node_constraint_async(self, label: str, property_name: str, stdout): + def _create_node_constraint(self, label: str, property_name: str, stdout): try: - self.cypher_query_async( + self.cypher_query( f"""CREATE CONSTRAINT constraint_unique_{label}_{property_name} FOR (n:{label}) REQUIRE n.{property_name} IS UNIQUE""" ) @@ -727,11 +729,11 @@ def _create_node_constraint_async(self, label: str, property_name: str, stdout): else: raise - def _create_relationship_index_async( + def _create_relationship_index( self, relationship_type: str, property_name: str, stdout ): try: - self.cypher_query_async( + self.cypher_query( f"CREATE INDEX index_{relationship_type}_{property_name} FOR ()-[r:{relationship_type}]-() ON (r.{property_name}); " ) except ClientError as e: @@ -743,12 +745,12 @@ def _create_relationship_index_async( else: raise - def _create_relationship_constraint_async( + def _create_relationship_constraint( self, relationship_type: str, property_name: str, stdout ): if self.version_is_higher_than("5.7"): try: - self.cypher_query_async( + self.cypher_query( f"""CREATE CONSTRAINT constraint_unique_{relationship_type}_{property_name} FOR ()-[r:{relationship_type}]-() REQUIRE r.{property_name} IS UNIQUE""" ) @@ -765,7 +767,7 @@ def _create_relationship_constraint_async( f"Unique indexes on relationships are not supported in Neo4j version {self.database_version}. Please upgrade to Neo4j 5.7 or higher." ) - def _install_node_async(self, cls, name, property, quiet, stdout): + def _install_node(self, cls, name, property, quiet, stdout): # Create indexes and constraints for node property db_property = property.db_property or name if property.index: @@ -773,7 +775,7 @@ def _install_node_async(self, cls, name, property, quiet, stdout): stdout.write( f" + Creating node index {name} on label {cls.__label__} for class {cls.__module__}.{cls.__name__}\n" ) - self._create_node_index_async( + self._create_node_index( label=cls.__label__, property_name=db_property, stdout=stdout ) @@ -782,11 +784,11 @@ def _install_node_async(self, cls, name, property, quiet, stdout): stdout.write( f" + Creating node unique constraint for {name} on label {cls.__label__} for class {cls.__module__}.{cls.__name__}\n" ) - self._create_node_constraint_async( + self._create_node_constraint( label=cls.__label__, property_name=db_property, stdout=stdout ) - def _install_relationship_async(self, cls, relationship, quiet, stdout): + def _install_relationship(self, cls, relationship, quiet, stdout): # Create indexes and constraints for relationship property relationship_cls = relationship.definition["model"] if relationship_cls is not None: @@ -800,7 +802,7 @@ def _install_relationship_async(self, cls, relationship, quiet, stdout): stdout.write( f" + Creating relationship index {prop_name} on relationship type {relationship_type} for relationship model {cls.__module__}.{relationship_cls.__name__}\n" ) - self._create_relationship_index_async( + self._create_relationship_index( relationship_type=relationship_type, property_name=db_property, stdout=stdout, @@ -810,7 +812,7 @@ def _install_relationship_async(self, cls, relationship, quiet, stdout): stdout.write( f" + Creating relationship unique constraint for {prop_name} on relationship type {relationship_type} for relationship model {cls.__module__}.{relationship_cls.__name__}\n" ) - self._create_relationship_constraint_async( + self._create_relationship_constraint( relationship_type=relationship_type, property_name=db_property, stdout=stdout, @@ -822,81 +824,83 @@ def _install_relationship_async(self, cls, relationship, quiet, stdout): # Deprecated methods -def change_neo4j_password_async(db, user, new_password): +def change_neo4j_password(db: Database, user, new_password): deprecated( """ This method has been moved to the Database singleton (db for sync, adb for async). - Please use adb.change_neo4j_password_async(user, new_password) instead. + Please use adb.change_neo4j_password(user, new_password) instead. This direct call will be removed in an upcoming version. """ ) - db.change_neo4j_password_async(user, new_password) + db.change_neo4j_password(user, new_password) -def clear_neo4j_database_async(db, clear_constraints=False, clear_indexes=False): +def clear_neo4j_database( + db: Database, clear_constraints=False, clear_indexes=False +): deprecated( """ This method has been moved to the Database singleton (db for sync, adb for async). - Please use adb.clear_neo4j_database_async(clear_constraints, clear_indexes) instead. + Please use adb.clear_neo4j_database(clear_constraints, clear_indexes) instead. This direct call will be removed in an upcoming version. """ ) - db.clear_neo4j_database_async(clear_constraints, clear_indexes) + db.clear_neo4j_database(clear_constraints, clear_indexes) -def drop_constraints_async(quiet=True, stdout=None): +def drop_constraints(quiet=True, stdout=None): deprecated( """ This method has been moved to the Database singleton (db for sync, adb for async). - Please use adb.drop_constraints_async(quiet, stdout) instead. + Please use adb.drop_constraints(quiet, stdout) instead. This direct call will be removed in an upcoming version. """ ) - adb.drop_constraints_async(quiet, stdout) + adb.drop_constraints(quiet, stdout) -def drop_indexes_async(quiet=True, stdout=None): +def drop_indexes(quiet=True, stdout=None): deprecated( """ This method has been moved to the Database singleton (db for sync, adb for async). - Please use adb.drop_indexes_async(quiet, stdout) instead. + Please use adb.drop_indexes(quiet, stdout) instead. This direct call will be removed in an upcoming version. """ ) - adb.drop_indexes_async(quiet, stdout) + adb.drop_indexes(quiet, stdout) -def remove_all_labels_async(stdout=None): +def remove_all_labels(stdout=None): deprecated( """ This method has been moved to the Database singleton (db for sync, adb for async). - Please use adb.remove_all_labels_async(stdout) instead. + Please use adb.remove_all_labels(stdout) instead. This direct call will be removed in an upcoming version. """ ) - adb.remove_all_labels_async(stdout) + adb.remove_all_labels(stdout) -def install_labels_async(cls, quiet=True, stdout=None): +def install_labels(cls, quiet=True, stdout=None): deprecated( """ This method has been moved to the Database singleton (db for sync, adb for async). - Please use adb.install_labels_async(cls, quiet, stdout) instead. + Please use adb.install_labels(cls, quiet, stdout) instead. This direct call will be removed in an upcoming version. """ ) - adb.install_labels_async(cls, quiet, stdout) + adb.install_labels(cls, quiet, stdout) -def install_all_labels_async(stdout=None): +def install_all_labels(stdout=None): deprecated( """ This method has been moved to the Database singleton (db for sync, adb for async). - Please use adb.install_all_labels_async(stdout) instead. + Please use adb.install_all_labels(stdout) instead. This direct call will be removed in an upcoming version. """ ) - adb.install_all_labels_async(stdout) + adb.install_all_labels(stdout) class TransactionProxyAsync: @@ -908,13 +912,13 @@ def __init__(self, db: Database, access_mode=None): @ensure_connection def __enter__(self): - self.db.begin_async(access_mode=self.access_mode, bookmarks=self.bookmarks) + self.db.begin(access_mode=self.access_mode, bookmarks=self.bookmarks) self.bookmarks = None return self def __exit__(self, exc_type, exc_value, traceback): if exc_value: - self.db.rollback_async() + self.db.rollback() if ( exc_type is ClientError @@ -923,7 +927,7 @@ def __exit__(self, exc_type, exc_value, traceback): raise UniqueProperty(exc_value.message) if not exc_value: - self.last_bookmark = self.db.commit_async() + self.last_bookmark = self.db.commit() def __call__(self, func): def wrapper(*args, **kwargs): @@ -1034,7 +1038,7 @@ def __new__(mcs, name, bases, namespace): # TODO : See previous TODO comment # if config.AUTO_INSTALL_LABELS: - # await install_labels_async(cls, quiet=False) + # await install_labels(cls, quiet=False) build_class_registry(cls) @@ -1202,7 +1206,7 @@ def _build_merge_query( return query, query_params @classmethod - def create_async(cls, *props, **kwargs): + def create(cls, *props, **kwargs): """ Call to CREATE with parameters map. A new instance will be created and saved. @@ -1234,7 +1238,7 @@ def create_async(cls, *props, **kwargs): for item in [ cls.deflate(p, obj=_UnsavedNode(), skip_empty=True) for p in props ]: - node, _ = adb.cypher_query_async(query, {"create_params": item}) + node, _ = adb.cypher_query(query, {"create_params": item}) results.extend(node[0]) nodes = [cls.inflate(node) for node in results] @@ -1246,7 +1250,7 @@ def create_async(cls, *props, **kwargs): return nodes @classmethod - def create_or_update_async(cls, *props, **kwargs): + def create_or_update(cls, *props, **kwargs): """ Call to MERGE with parameters map. A new instance will be created and saved if does not already exists, this is an atomic operation. If an instance already exists all optional properties specified will be updated. @@ -1290,10 +1294,10 @@ def create_or_update_async(cls, *props, **kwargs): ) # fetch and build instance for each result - results = adb.cypher_query_async(query, params) + results = adb.cypher_query(query, params) return [cls.inflate(r[0]) for r in results[0]] - def cypher_async(self, query, params=None): + def cypher(self, query, params=None): """ Execute a cypher query with the param 'self' pre-populated with the nodes neo4j id. @@ -1307,17 +1311,17 @@ def cypher_async(self, query, params=None): self._pre_action_check("cypher") params = params or {} params.update({"self": self.element_id}) - return adb.cypher_query_async(query, params) + return adb.cypher_query(query, params) @hooks - def delete_async(self): + def delete(self): """ Delete a node and its relationships :return: True """ self._pre_action_check("delete") - self.cypher_async( + self.cypher( f"MATCH (self) WHERE {adb.get_id_method()}(self)=$self DETACH DELETE self" ) delattr(self, "element_id_property") @@ -1325,7 +1329,7 @@ def delete_async(self): return True @classmethod - def get_or_create_async(cls, *props, **kwargs): + def get_or_create(cls, *props, **kwargs): """ Call to MERGE with parameters map. A new instance will be created and saved if does not already exist, this is an atomic operation. @@ -1359,7 +1363,7 @@ def get_or_create_async(cls, *props, **kwargs): ) # fetch and build instance for each result - results = adb.cypher_query_async(query, params) + results = adb.cypher_query(query, params) return [cls.inflate(r[0]) for r in results[0]] @classmethod @@ -1420,7 +1424,7 @@ def inherited_optional_labels(cls): if not hasattr(scls, "__abstract_node__") ] - def labels_async(self): + def labels(self): """ Returns list of labels tied to the node from neo4j. @@ -1428,7 +1432,7 @@ def labels_async(self): :rtype: list """ self._pre_action_check("labels") - return self.cypher_async( + return self.cypher( f"MATCH (n) WHERE {adb.get_id_method()}(n)=$self " "RETURN labels(n)" )[0][0][0] @@ -1442,13 +1446,13 @@ def _pre_action_check(self, action): f"{self.__class__.__name__}.{action}() attempted on unsaved node" ) - def refresh_async(self): + def refresh(self): """ Reload the node from neo4j """ self._pre_action_check("refresh") if hasattr(self, "element_id"): - request = self.cypher_async( + request = self.cypher( f"MATCH (n) WHERE {adb.get_id_method()}(n)=$self RETURN n" )[0] if not request or not request[0]: @@ -1460,7 +1464,7 @@ def refresh_async(self): raise ValueError("Can't refresh unsaved node") @hooks - def save_async(self): + def save(self): """ Save the node to neo4j or raise an exception @@ -1481,13 +1485,13 @@ def save_async(self): query += "\n".join( [f"SET n:`{label}`" for label in self.inherited_labels()] ) - self.cypher_async(query, params) + self.cypher(query, params) elif hasattr(self, "deleted") and self.deleted: raise ValueError( f"{self.__class__.__name__}.save() attempted on deleted node" ) else: # create - result = self.create_async(self.__properties__) + result = self.create(self.__properties__) created_node = result[0] self.element_id_property = created_node.element_id return self diff --git a/neomodel/match.py b/neomodel/match.py index 815b02ec..122e888f 100644 --- a/neomodel/match.py +++ b/neomodel/match.py @@ -664,7 +664,7 @@ def _count(self): # drop additional_return to avoid unexpected result self._ast.additional_return = None query = self.build_query() - results, _ = adb.cypher_query_async(query, self._query_params) + results, _ = adb.cypher_query(query, self._query_params) return int(results[0][0]) def _contains(self, node_element_id): @@ -691,9 +691,7 @@ def _execute(self, lazy=False): for item in self._ast.additional_return ] query = self.build_query() - results, _ = adb.cypher_query_async( - query, self._query_params, resolve_objects=True - ) + results, _ = adb.cypher_query(query, self._query_params, resolve_objects=True) # The following is not as elegant as it could be but had to be copied from the # version prior to cypher_query with the resolve_objects capability. # It seems that certain calls are only supposed to be focusing to the first diff --git a/neomodel/relationship.py b/neomodel/relationship.py index bea17660..d25990f0 100644 --- a/neomodel/relationship.py +++ b/neomodel/relationship.py @@ -114,7 +114,7 @@ def save(self): query += "".join([f" SET r.{key} = ${key}" for key in props]) props["self"] = self.element_id - adb.cypher_query_async(query, props) + adb.cypher_query(query, props) return self @@ -124,7 +124,7 @@ def start_node(self): :return: StructuredNode """ - test = adb.cypher_query_async( + test = adb.cypher_query( f""" MATCH (aNode) WHERE {adb.get_id_method()}(aNode)=$start_node_element_id @@ -141,7 +141,7 @@ def end_node(self): :return: StructuredNode """ - return adb.cypher_query_async( + return adb.cypher_query( f""" MATCH (aNode) WHERE {adb.get_id_method()}(aNode)=$end_node_element_id diff --git a/neomodel/scripts/neomodel_inspect_database.py b/neomodel/scripts/neomodel_inspect_database.py index c994dd0c..aa929bda 100644 --- a/neomodel/scripts/neomodel_inspect_database.py +++ b/neomodel/scripts/neomodel_inspect_database.py @@ -78,13 +78,13 @@ def get_properties_for_label(label): ORDER BY size(properties) DESC RETURN apoc.meta.cypher.types(properties(sampleNode)) AS properties LIMIT 1 """ - result, _ = adb.cypher_query_async(query) + result, _ = adb.cypher_query(query) if result is not None and len(result) > 0: return result[0][0] @staticmethod def get_constraints_for_label(label): - constraints, meta_constraints = adb.cypher_query_async( + constraints, meta_constraints = adb.cypher_query( f"SHOW CONSTRAINTS WHERE entityType='NODE' AND '{label}' IN labelsOrTypes AND type='UNIQUENESS'" ) constraints_as_dict = [dict(zip(meta_constraints, row)) for row in constraints] @@ -98,11 +98,11 @@ def get_constraints_for_label(label): @staticmethod def get_indexed_properties_for_label(label): if adb.version_is_higher_than("5.0"): - indexes, meta_indexes = adb.cypher_query_async( + indexes, meta_indexes = adb.cypher_query( f"SHOW INDEXES WHERE entityType='NODE' AND '{label}' IN labelsOrTypes AND type='RANGE' AND owningConstraint IS NULL" ) else: - indexes, meta_indexes = adb.cypher_query_async( + indexes, meta_indexes = adb.cypher_query( f"SHOW INDEXES WHERE entityType='NODE' AND '{label}' IN labelsOrTypes AND type='BTREE' AND uniqueness='NONUNIQUE'" ) indexes_as_dict = [dict(zip(meta_indexes, row)) for row in indexes] @@ -123,12 +123,12 @@ def outgoing_relationships(cls, start_label): ORDER BY size(properties) DESC RETURN rel_type, target_label, apoc.meta.cypher.types(properties(sampleRel)) AS properties LIMIT 1 """ - result, _ = adb.cypher_query_async(query) + result, _ = adb.cypher_query(query) return [(record[0], record[1], record[2]) for record in result] @staticmethod def get_constraints_for_type(rel_type): - constraints, meta_constraints = adb.cypher_query_async( + constraints, meta_constraints = adb.cypher_query( f"SHOW CONSTRAINTS WHERE entityType='RELATIONSHIP' AND '{rel_type}' IN labelsOrTypes AND type='RELATIONSHIP_UNIQUENESS'" ) constraints_as_dict = [dict(zip(meta_constraints, row)) for row in constraints] @@ -142,11 +142,11 @@ def get_constraints_for_type(rel_type): @staticmethod def get_indexed_properties_for_type(rel_type): if adb.version_is_higher_than("5.0"): - indexes, meta_indexes = adb.cypher_query_async( + indexes, meta_indexes = adb.cypher_query( f"SHOW INDEXES WHERE entityType='RELATIONSHIP' AND '{rel_type}' IN labelsOrTypes AND type='RANGE' AND owningConstraint IS NULL" ) else: - indexes, meta_indexes = adb.cypher_query_async( + indexes, meta_indexes = adb.cypher_query( f"SHOW INDEXES WHERE entityType='RELATIONSHIP' AND '{rel_type}' IN labelsOrTypes AND type='BTREE' AND uniqueness='NONUNIQUE'" ) indexes_as_dict = [dict(zip(meta_indexes, row)) for row in indexes] @@ -160,7 +160,7 @@ def get_indexed_properties_for_type(rel_type): @staticmethod def infer_cardinality(rel_type, start_label): range_start_query = f"MATCH (n:`{start_label}`) WHERE NOT EXISTS ((n)-[:`{rel_type}`]->()) WITH n LIMIT 1 RETURN count(n)" - result, _ = adb.cypher_query_async(range_start_query) + result, _ = adb.cypher_query(range_start_query) is_start_zero = result[0][0] > 0 range_end_query = f""" @@ -170,7 +170,7 @@ def infer_cardinality(rel_type, start_label): WITH n LIMIT 1 RETURN count(n) """ - result, _ = adb.cypher_query_async(range_end_query) + result, _ = adb.cypher_query(range_end_query) is_end_one = result[0][0] == 0 cardinality = "Zero" if is_start_zero else "One" @@ -184,7 +184,7 @@ def infer_cardinality(rel_type, start_label): def get_node_labels(): query = "CALL db.labels()" - result, _ = adb.cypher_query_async(query) + result, _ = adb.cypher_query(query) return [record[0] for record in result] @@ -268,7 +268,7 @@ def build_rel_type_definition(label, outgoing_relationships, defined_rel_types): def inspect_database(bolt_url): # Connect to the database print(f"Connecting to {bolt_url}") - adb.set_connection_async(bolt_url) + adb.set_connection(bolt_url) node_labels = get_node_labels() defined_rel_types = [] diff --git a/neomodel/scripts/neomodel_install_labels.py b/neomodel/scripts/neomodel_install_labels.py index 5b7f65a8..39a59c77 100755 --- a/neomodel/scripts/neomodel_install_labels.py +++ b/neomodel/scripts/neomodel_install_labels.py @@ -109,9 +109,9 @@ def main(): # Connect after to override any code in the module that may set the connection print(f"Connecting to {bolt_url}") - adb.set_connection_async(url=bolt_url) + adb.set_connection(url=bolt_url) - adb.install_all_labels_async() + adb.install_all_labels() if __name__ == "__main__": diff --git a/neomodel/scripts/neomodel_remove_labels.py b/neomodel/scripts/neomodel_remove_labels.py index 8eb7273b..25cf25bc 100755 --- a/neomodel/scripts/neomodel_remove_labels.py +++ b/neomodel/scripts/neomodel_remove_labels.py @@ -61,9 +61,9 @@ def main(): # Connect after to override any code in the module that may set the connection print(f"Connecting to {bolt_url}") - adb.set_connection_async(url=bolt_url) + adb.set_connection(url=bolt_url) - adb.remove_all_labels_async() + adb.remove_all_labels() if __name__ == "__main__": diff --git a/test/async_/conftest.py b/test/async_/conftest.py index 88156e7b..65250261 100644 --- a/test/async_/conftest.py +++ b/test/async_/conftest.py @@ -28,7 +28,7 @@ async def setup_neo4j_session(request): config.AUTO_INSTALL_LABELS = True # Clear the database if required - database_is_populated, _ = await adb.cypher_query_async( + database_is_populated, _ = await adb.cypher_query( "MATCH (a) return count(a)>0 as database_is_populated" ) if database_is_populated[0][0] and not request.config.getoption("resetdb"): @@ -36,21 +36,21 @@ async def setup_neo4j_session(request): "Please note: The database seems to be populated.\n\tEither delete all nodes and edges manually, or set the --resetdb parameter when calling pytest\n\n\tpytest --resetdb." ) - await adb.clear_neo4j_database_async(clear_constraints=True, clear_indexes=True) + await adb.clear_neo4j_database(clear_constraints=True, clear_indexes=True) - await adb.cypher_query_async( + await adb.cypher_query( "CREATE OR REPLACE USER troygreene SET PASSWORD 'foobarbaz' CHANGE NOT REQUIRED" ) if adb.database_edition == "enterprise": - await adb.cypher_query_async("GRANT ROLE publisher TO troygreene") - await adb.cypher_query_async("GRANT IMPERSONATE (troygreene) ON DBMS TO admin") + await adb.cypher_query("GRANT ROLE publisher TO troygreene") + await adb.cypher_query("GRANT IMPERSONATE (troygreene) ON DBMS TO admin") @pytest_asyncio.fixture(scope="session", autouse=True) @mark_async_test async def cleanup(): yield - await adb.close_connection_async() + await adb.close_connection() @pytest.fixture(scope="session") diff --git a/test/async_/test_cypher.py b/test/async_/test_cypher.py index 1b7934c9..190b4f9e 100644 --- a/test/async_/test_cypher.py +++ b/test/async_/test_cypher.py @@ -38,19 +38,19 @@ def mocked_import(name, *args, **kwargs): @mark_async_test -async def test_cypher_async(): +async def test_cypher(): """ test result format is backward compatible with earlier versions of neomodel """ - jim = await User2(email="jim1@test.com").save_async() - data, meta = await jim.cypher_async( + jim = await User2(email="jim1@test.com").save() + data, meta = await jim.cypher( f"MATCH (a) WHERE {adb.get_id_method()}(a)=$self RETURN a.email" ) assert data[0][0] == "jim1@test.com" assert "a.email" in meta - data, meta = await jim.cypher_async( + data, meta = await jim.cypher( f""" MATCH (a) WHERE {adb.get_id_method()}(a)=$self MATCH (a)<-[:USER2]-(b) @@ -62,11 +62,9 @@ async def test_cypher_async(): @mark_async_test async def test_cypher_syntax_error_async(): - jim = await User2(email="jim1@test.com").save_async() + jim = await User2(email="jim1@test.com").save() try: - await jim.cypher_async( - f"MATCH a WHERE {adb.get_id_method()}(a)={{self}} RETURN xx" - ) + await jim.cypher(f"MATCH a WHERE {adb.get_id_method()}(a)={{self}} RETURN xx") except CypherError as e: assert hasattr(e, "message") assert hasattr(e, "code") @@ -84,21 +82,19 @@ async def test_pandas_not_installed_async(hide_available_pkg): ): from neomodel.integration.pandas import to_dataframe - _ = to_dataframe( - await adb.cypher_query_async("MATCH (a) RETURN a.name AS name") - ) + _ = to_dataframe(await adb.cypher_query("MATCH (a) RETURN a.name AS name")) @mark_async_test async def test_pandas_integration_async(): from neomodel.integration.pandas import to_dataframe, to_series - jimla = await UserPandas(email="jimla@test.com", name="jimla").save_async() - jimlo = await UserPandas(email="jimlo@test.com", name="jimlo").save_async() + jimla = await UserPandas(email="jimla@test.com", name="jimla").save() + jimlo = await UserPandas(email="jimlo@test.com", name="jimlo").save() # Test to_dataframe df = to_dataframe( - await adb.cypher_query_async( + await adb.cypher_query( "MATCH (a:UserPandas) RETURN a.name AS name, a.email AS email" ) ) @@ -109,7 +105,7 @@ async def test_pandas_integration_async(): # Also test passing an index and dtype to to_dataframe df = to_dataframe( - await adb.cypher_query_async( + await adb.cypher_query( "MATCH (a:UserPandas) RETURN a.name AS name, a.email AS email" ), index=df["email"], @@ -120,7 +116,7 @@ async def test_pandas_integration_async(): # Next test to_series series = to_series( - await adb.cypher_query_async("MATCH (a:UserPandas) RETURN a.name AS name") + await adb.cypher_query("MATCH (a:UserPandas) RETURN a.name AS name") ) assert isinstance(series, Series) @@ -138,20 +134,18 @@ async def test_numpy_not_installed_async(hide_available_pkg): ): from neomodel.integration.numpy import to_ndarray - _ = to_ndarray( - await adb.cypher_query_async("MATCH (a) RETURN a.name AS name") - ) + _ = to_ndarray(await adb.cypher_query("MATCH (a) RETURN a.name AS name")) @mark_async_test async def test_numpy_integration_async(): from neomodel.integration.numpy import to_ndarray - jimly = await UserNP(email="jimly@test.com", name="jimly").save_async() - jimlu = await UserNP(email="jimlu@test.com", name="jimlu").save_async() + jimly = await UserNP(email="jimly@test.com", name="jimly").save() + jimlu = await UserNP(email="jimlu@test.com", name="jimlu").save() array = to_ndarray( - await adb.cypher_query_async( + await adb.cypher_query( "MATCH (a:UserNP) RETURN a.name AS name, a.email AS email" ) ) diff --git a/test/sync/conftest.py b/test/sync/conftest.py index 6ccd8d0d..12e1183f 100644 --- a/test/sync/conftest.py +++ b/test/sync/conftest.py @@ -28,7 +28,7 @@ def setup_neo4j_session(request): config.AUTO_INSTALL_LABELS = True # Clear the database if required - database_is_populated, _ = adb.cypher_query_async( + database_is_populated, _ = adb.cypher_query( "MATCH (a) return count(a)>0 as database_is_populated" ) if database_is_populated[0][0] and not request.config.getoption("resetdb"): @@ -36,21 +36,21 @@ def setup_neo4j_session(request): "Please note: The database seems to be populated.\n\tEither delete all nodes and edges manually, or set the --resetdb parameter when calling pytest\n\n\tpytest --resetdb." ) - adb.clear_neo4j_database_async(clear_constraints=True, clear_indexes=True) + adb.clear_neo4j_database(clear_constraints=True, clear_indexes=True) - adb.cypher_query_async( + adb.cypher_query( "CREATE OR REPLACE USER troygreene SET PASSWORD 'foobarbaz' CHANGE NOT REQUIRED" ) if adb.database_edition == "enterprise": - adb.cypher_query_async("GRANT ROLE publisher TO troygreene") - adb.cypher_query_async("GRANT IMPERSONATE (troygreene) ON DBMS TO admin") + adb.cypher_query("GRANT ROLE publisher TO troygreene") + adb.cypher_query("GRANT IMPERSONATE (troygreene) ON DBMS TO admin") @pytest_asyncio.fixture(scope="session", autouse=True) @mark_async_test def cleanup(): yield - adb.close_connection_async() + adb.close_connection() @pytest.fixture(scope="session") diff --git a/test/sync/test_cypher.py b/test/sync/test_cypher.py index 5bd4afb9..2875d640 100644 --- a/test/sync/test_cypher.py +++ b/test/sync/test_cypher.py @@ -38,19 +38,19 @@ def mocked_import(name, *args, **kwargs): @mark_async_test -def test_cypher_async(): +def test_cypher(): """ test result format is backward compatible with earlier versions of neomodel """ - jim = User2(email="jim1@test.com").save_async() - data, meta = jim.cypher_async( + jim = User2(email="jim1@test.com").save() + data, meta = jim.cypher( f"MATCH (a) WHERE {adb.get_id_method()}(a)=$self RETURN a.email" ) assert data[0][0] == "jim1@test.com" assert "a.email" in meta - data, meta = jim.cypher_async( + data, meta = jim.cypher( f""" MATCH (a) WHERE {adb.get_id_method()}(a)=$self MATCH (a)<-[:USER2]-(b) @@ -62,9 +62,9 @@ def test_cypher_async(): @mark_async_test def test_cypher_syntax_error_async(): - jim = User2(email="jim1@test.com").save_async() + jim = User2(email="jim1@test.com").save() try: - jim.cypher_async(f"MATCH a WHERE {adb.get_id_method()}(a)={ self} RETURN xx") + jim.cypher(f"MATCH a WHERE {adb.get_id_method()}(a)={ self} RETURN xx") except CypherError as e: assert hasattr(e, "message") assert hasattr(e, "code") @@ -82,19 +82,19 @@ def test_pandas_not_installed_async(hide_available_pkg): ): from neomodel.integration.pandas import to_dataframe - _ = to_dataframe(adb.cypher_query_async("MATCH (a) RETURN a.name AS name")) + _ = to_dataframe(adb.cypher_query("MATCH (a) RETURN a.name AS name")) @mark_async_test def test_pandas_integration_async(): from neomodel.integration.pandas import to_dataframe, to_series - jimla = UserPandas(email="jimla@test.com", name="jimla").save_async() - jimlo = UserPandas(email="jimlo@test.com", name="jimlo").save_async() + jimla = UserPandas(email="jimla@test.com", name="jimla").save() + jimlo = UserPandas(email="jimlo@test.com", name="jimlo").save() # Test to_dataframe df = to_dataframe( - adb.cypher_query_async( + adb.cypher_query( "MATCH (a:UserPandas) RETURN a.name AS name, a.email AS email" ) ) @@ -105,7 +105,7 @@ def test_pandas_integration_async(): # Also test passing an index and dtype to to_dataframe df = to_dataframe( - adb.cypher_query_async( + adb.cypher_query( "MATCH (a:UserPandas) RETURN a.name AS name, a.email AS email" ), index=df["email"], @@ -116,7 +116,7 @@ def test_pandas_integration_async(): # Next test to_series series = to_series( - adb.cypher_query_async("MATCH (a:UserPandas) RETURN a.name AS name") + adb.cypher_query("MATCH (a:UserPandas) RETURN a.name AS name") ) assert isinstance(series, Series) @@ -134,18 +134,18 @@ def test_numpy_not_installed_async(hide_available_pkg): ): from neomodel.integration.numpy import to_ndarray - _ = to_ndarray(adb.cypher_query_async("MATCH (a) RETURN a.name AS name")) + _ = to_ndarray(adb.cypher_query("MATCH (a) RETURN a.name AS name")) @mark_async_test def test_numpy_integration_async(): from neomodel.integration.numpy import to_ndarray - jimly = UserNP(email="jimly@test.com", name="jimly").save_async() - jimlu = UserNP(email="jimlu@test.com", name="jimlu").save_async() + jimly = UserNP(email="jimly@test.com", name="jimly").save() + jimlu = UserNP(email="jimlu@test.com", name="jimlu").save() array = to_ndarray( - adb.cypher_query_async( + adb.cypher_query( "MATCH (a:UserNP) RETURN a.name AS name, a.email AS email" ) ) diff --git a/test/test_alias.py b/test/test_alias.py index 78f39969..c0c4877f 100644 --- a/test/test_alias.py +++ b/test/test_alias.py @@ -13,13 +13,13 @@ class AliasTestNode(StructuredNodeAsync): def test_property_setup_hook(): - tim = AliasTestNode(long_name="tim").save_async() + tim = AliasTestNode(long_name="tim").save() assert AliasTestNode.setup_hook_called assert tim.name == "tim" def test_alias(): - jim = AliasTestNode(full_name="Jim").save_async() + jim = AliasTestNode(full_name="Jim").save() assert jim.name == "Jim" assert jim.full_name == "Jim" assert "full_name" not in AliasTestNode.deflate(jim.__properties__) diff --git a/test/test_batch.py b/test/test_batch.py index cf0b6a4a..3085de3d 100644 --- a/test/test_batch.py +++ b/test/test_batch.py @@ -21,15 +21,11 @@ class UniqueUser(StructuredNodeAsync): def test_unique_id_property_batch(): - users = UniqueUser.create_async( - {"name": "bob", "age": 2}, {"name": "ben", "age": 3} - ) + users = UniqueUser.create({"name": "bob", "age": 2}, {"name": "ben", "age": 3}) assert users[0].uid != users[1].uid - users = UniqueUser.get_or_create_async( - {"uid": users[0].uid}, {"name": "bill", "age": 4} - ) + users = UniqueUser.get_or_create({"uid": users[0].uid}, {"name": "bill", "age": 4}) assert users[0].name == "bob" assert users[1].uid @@ -41,7 +37,7 @@ class Customer(StructuredNodeAsync): def test_batch_create(): - users = Customer.create_async( + users = Customer.create( {"email": "jim1@aol.com", "age": 11}, {"email": "jim2@aol.com", "age": 7}, {"email": "jim3@aol.com", "age": 9}, @@ -56,7 +52,7 @@ def test_batch_create(): def test_batch_create_or_update(): - users = Customer.create_or_update_async( + users = Customer.create_or_update( {"email": "merge1@aol.com", "age": 11}, {"email": "merge2@aol.com"}, {"email": "merge3@aol.com", "age": 1}, @@ -66,7 +62,7 @@ def test_batch_create_or_update(): assert users[1] == users[3] assert Customer.nodes.get(email="merge1@aol.com").age == 11 - more_users = Customer.create_or_update_async( + more_users = Customer.create_or_update( {"email": "merge1@aol.com", "age": 22}, {"email": "merge4@aol.com", "age": None}, ) @@ -78,7 +74,7 @@ def test_batch_create_or_update(): def test_batch_validation(): # test validation in batch create with raises(DeflateError): - Customer.create_async( + Customer.create( {"email": "jim1@aol.com", "age": "x"}, ) @@ -87,12 +83,12 @@ def test_batch_index_violation(): for u in Customer.nodes.all(): u.delete() - users = Customer.create_async( + users = Customer.create( {"email": "jim6@aol.com", "age": 3}, ) assert users with raises(UniqueProperty): - Customer.create_async( + Customer.create( {"email": "jim6@aol.com", "age": 3}, {"email": "jim7@aol.com", "age": 5}, ) @@ -112,11 +108,11 @@ class Person(StructuredNodeAsync): def test_get_or_create_with_rel(): - bob = Person.get_or_create_async({"name": "Bob"})[0] - bobs_gizmo = Dog.get_or_create_async({"name": "Gizmo"}, relationship=bob.pets) + bob = Person.get_or_create({"name": "Bob"})[0] + bobs_gizmo = Dog.get_or_create({"name": "Gizmo"}, relationship=bob.pets) - tim = Person.get_or_create_async({"name": "Tim"})[0] - tims_gizmo = Dog.get_or_create_async({"name": "Gizmo"}, relationship=tim.pets) + tim = Person.get_or_create({"name": "Tim"})[0] + tims_gizmo = Dog.get_or_create({"name": "Gizmo"}, relationship=tim.pets) # not the same gizmo assert bobs_gizmo[0] != tims_gizmo[0] diff --git a/test/test_cardinality.py b/test/test_cardinality.py index 119aef6c..b21e652e 100644 --- a/test/test_cardinality.py +++ b/test/test_cardinality.py @@ -40,10 +40,10 @@ class ToothBrush(StructuredNodeAsync): def test_cardinality_zero_or_more(): - m = Monkey(name="tim").save_async() + m = Monkey(name="tim").save() assert m.dryers.all() == [] assert m.dryers.single() is None - h = HairDryer(version=1).save_async() + h = HairDryer(version=1).save() m.dryers.connect(h) assert len(m.dryers.all()) == 1 @@ -53,7 +53,7 @@ def test_cardinality_zero_or_more(): assert m.dryers.all() == [] assert m.dryers.single() is None - h2 = HairDryer(version=2).save_async() + h2 = HairDryer(version=2).save() m.dryers.connect(h) m.dryers.connect(h2) m.dryers.disconnect_all() @@ -62,16 +62,16 @@ def test_cardinality_zero_or_more(): def test_cardinality_zero_or_one(): - m = Monkey(name="bob").save_async() + m = Monkey(name="bob").save() assert m.driver.all() == [] assert m.driver.single() is None - h = ScrewDriver(version=1).save_async() + h = ScrewDriver(version=1).save() m.driver.connect(h) assert len(m.driver.all()) == 1 assert m.driver.single().version == 1 - j = ScrewDriver(version=2).save_async() + j = ScrewDriver(version=2).save() with raises(AttemptedCardinalityViolation): m.driver.connect(j) @@ -95,7 +95,7 @@ def test_cardinality_zero_or_one(): def test_cardinality_one_or_more(): - m = Monkey(name="jerry").save_async() + m = Monkey(name="jerry").save() with raises(CardinalityViolation): m.car.all() @@ -103,7 +103,7 @@ def test_cardinality_one_or_more(): with raises(CardinalityViolation): m.car.single() - c = Car(version=2).save_async() + c = Car(version=2).save() m.car.connect(c) assert m.car.single().version == 2 @@ -113,7 +113,7 @@ def test_cardinality_one_or_more(): with raises(AttemptedCardinalityViolation): m.car.disconnect(c) - d = Car(version=3).save_async() + d = Car(version=3).save() m.car.connect(d) cars = m.car.all() assert len(cars) == 2 @@ -124,7 +124,7 @@ def test_cardinality_one_or_more(): def test_cardinality_one(): - m = Monkey(name="jerry").save_async() + m = Monkey(name="jerry").save() with raises( CardinalityViolation, match=r"CardinalityViolation: Expected: .*, got: none." @@ -134,11 +134,11 @@ def test_cardinality_one(): with raises(CardinalityViolation): m.toothbrush.single() - b = ToothBrush(name="Jim").save_async() + b = ToothBrush(name="Jim").save() m.toothbrush.connect(b) assert m.toothbrush.single().name == "Jim" - x = ToothBrush(name="Jim").save_async + x = ToothBrush(name="Jim").save with raises(AttemptedCardinalityViolation): m.toothbrush.connect(x) diff --git a/test/test_connection.py b/test/test_connection.py index 26e74e07..fb5fad42 100644 --- a/test/test_connection.py +++ b/test/test_connection.py @@ -15,7 +15,7 @@ def setup_teardown(): # Teardown actions after tests have run # Reconnect to initial URL for potential subsequent tests adb.close_connection() - adb.set_connection_async(url=config.DATABASE_URL) + adb.set_connection(url=config.DATABASE_URL) @pytest.fixture(autouse=True, scope="session") @@ -43,26 +43,26 @@ class Pastry(StructuredNodeAsync): def test_set_connection_driver_works(): # Verify that current connection is up - assert Pastry(name="Chocolatine").save_async() + assert Pastry(name="Chocolatine").save() adb.close_connection() # Test connection using a driver - adb.set_connection_async( + adb.set_connection( driver=GraphDatabase().driver(NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) ) - assert Pastry(name="Croissant").save_async() + assert Pastry(name="Croissant").save() def test_config_driver_works(): # Verify that current connection is up - assert Pastry(name="Chausson aux pommes").save_async() + assert Pastry(name="Chausson aux pommes").save() adb.close_connection() # Test connection using a driver defined in config driver = GraphDatabase().driver(NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) config.DRIVER = driver - assert Pastry(name="Grignette").save_async() + assert Pastry(name="Grignette").save() # Clear config # No need to close connection - pytest teardown will do it @@ -79,7 +79,7 @@ def test_connect_to_non_default_database(): adb.close_connection() # Set database name in url - for url init only - adb.set_connection_async(url=f"{config.DATABASE_URL}/{database_name}") + adb.set_connection(url=f"{config.DATABASE_URL}/{database_name}") assert get_current_database_name() == "pastries" adb.close_connection() @@ -88,13 +88,13 @@ def test_connect_to_non_default_database(): config.DATABASE_NAME = database_name # url init - adb.set_connection_async(url=config.DATABASE_URL) + adb.set_connection(url=config.DATABASE_URL) assert get_current_database_name() == "pastries" adb.close_connection() # driver init - adb.set_connection_async( + adb.set_connection( driver=GraphDatabase().driver(NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) ) assert get_current_database_name() == "pastries" @@ -112,7 +112,7 @@ def test_wrong_url_format(url): ValueError, match=rf"Expecting url format: bolt://user:password@localhost:7687 got {url}", ): - adb.set_connection_async(url=url) + adb.set_connection(url=url) @pytest.mark.parametrize("protocol", ["neo4j+s", "neo4j+ssc", "bolt+s", "bolt+ssc"]) @@ -134,4 +134,4 @@ def _set_connection(protocol): AURA_TEST_DB_HOSTNAME = os.environ["AURA_TEST_DB_HOSTNAME"] database_url = f"{protocol}://{AURA_TEST_DB_USER}:{AURA_TEST_DB_PASSWORD}@{AURA_TEST_DB_HOSTNAME}" - adb.set_connection_async(url=database_url) + adb.set_connection(url=database_url) diff --git a/test/test_contrib/test_semi_structured.py b/test/test_contrib/test_semi_structured.py index c04530f8..fe73a2bd 100644 --- a/test/test_contrib/test_semi_structured.py +++ b/test/test_contrib/test_semi_structured.py @@ -13,13 +13,13 @@ class Dummy(SemiStructuredNode): def test_to_save_to_model_with_required_only(): u = UserProf(email="dummy@test.com") - assert u.save_async() + assert u.save() def test_save_to_model_with_extras(): u = UserProf(email="jim@test.com", age=3, bar=99) u.foo = True - assert u.save_async() + assert u.save() u = UserProf.nodes.get(age=3) assert u.foo is True assert u.bar == 99 @@ -27,4 +27,4 @@ def test_save_to_model_with_extras(): def test_save_empty_model(): dummy = Dummy() - assert dummy.save_async() + assert dummy.save() diff --git a/test/test_contrib/test_spatial_properties.py b/test/test_contrib/test_spatial_properties.py index e009527a..b7986b58 100644 --- a/test/test_contrib/test_spatial_properties.py +++ b/test/test_contrib/test_spatial_properties.py @@ -182,7 +182,7 @@ class LocalisableEntity(neomodel.StructuredNodeAsync): ) # Save an object - an_object = LocalisableEntity().save_async() + an_object = LocalisableEntity().save() coords = an_object.location.coords[0] # Retrieve it retrieved_object = LocalisableEntity.nodes.get(identifier=an_object.identifier) @@ -220,7 +220,7 @@ class AnotherLocalisableEntity(neomodel.StructuredNodeAsync): neomodel.contrib.spatial_properties.NeomodelPoint((0.0, 0.0)), neomodel.contrib.spatial_properties.NeomodelPoint((1.0, 0.0)), ] - ).save_async() + ).save() retrieved_object = AnotherLocalisableEntity.nodes.get( identifier=an_object.identifier @@ -255,7 +255,7 @@ class TestStorageRetrievalProperty(neomodel.StructuredNodeAsync): a_restaurant = TestStorageRetrievalProperty( description="Milliways", location=neomodel.contrib.spatial_properties.NeomodelPoint((0, 0)), - ).save_async() + ).save() a_property = TestStorageRetrievalProperty.nodes.get( location=neomodel.contrib.spatial_properties.NeomodelPoint((0, 0)) diff --git a/test/test_database_management.py b/test/test_database_management.py index e3f89f19..6da05818 100644 --- a/test/test_database_management.py +++ b/test/test_database_management.py @@ -26,28 +26,28 @@ class Venue(StructuredNodeAsync): def test_clear_database(): - venue = Venue(name="Royal Albert Hall", creator="Queen Victoria").save_async() - city = City(name="London").save_async() + venue = Venue(name="Royal Albert Hall", creator="Queen Victoria").save() + city = City(name="London").save() venue.in_city.connect(city) # Clear only the data - adb.clear_neo4j_database_async() + adb.clear_neo4j_database() database_is_populated, _ = adb.cypher_query( "MATCH (a) return count(a)>0 as database_is_populated" ) assert database_is_populated[0][0] is False - indexes = adb.lise_indexes_async(exclude_token_lookup=True) - constraints = adb.list_constraints_async() + indexes = adb.list_indexes(exclude_token_lookup=True) + constraints = adb.list_constraints() assert len(indexes) > 0 assert len(constraints) > 0 # Clear constraints and indexes too - adb.clear_neo4j_database_async(clear_constraints=True, clear_indexes=True) + adb.clear_neo4j_database(clear_constraints=True, clear_indexes=True) - indexes = adb.lise_indexes_async(exclude_token_lookup=True) - constraints = adb.list_constraints_async() + indexes = adb.list_indexes(exclude_token_lookup=True) + constraints = adb.list_constraints() assert len(indexes) == 0 assert len(constraints) == 0 @@ -58,19 +58,19 @@ def test_change_password(): prev_url = f"bolt://neo4j:{prev_password}@localhost:7687" new_url = f"bolt://neo4j:{new_password}@localhost:7687" - adb.change_neo4j_password_async("neo4j", new_password) + adb.change_neo4j_password("neo4j", new_password) adb.close_connection() - adb.set_connection_async(url=new_url) + adb.set_connection(url=new_url) adb.close_connection() with pytest.raises(AuthError): - adb.set_connection_async(url=prev_url) + adb.set_connection(url=prev_url) adb.close_connection() - adb.set_connection_async(url=new_url) - adb.change_neo4j_password_async("neo4j", prev_password) + adb.set_connection(url=new_url) + adb.change_neo4j_password("neo4j", prev_password) adb.close_connection() - adb.set_connection_async(url=prev_url) + adb.set_connection(url=prev_url) diff --git a/test/test_driver_options.py b/test/test_driver_options.py index d896d257..e2fba00f 100644 --- a/test/test_driver_options.py +++ b/test/test_driver_options.py @@ -11,7 +11,7 @@ ) def test_impersonate(): with adb.impersonate(user="troygreene"): - results, _ = adb.cypher_query_async("RETURN 'Doo Wacko !'") + results, _ = adb.cypher_query("RETURN 'Doo Wacko !'") assert results[0][0] == "Doo Wacko !" @@ -21,7 +21,7 @@ def test_impersonate(): def test_impersonate_unauthorized(): with adb.impersonate(user="unknownuser"): with raises(ClientError): - _ = adb.cypher_query_async("RETURN 'Gabagool'") + _ = adb.cypher_query("RETURN 'Gabagool'") @pytest.mark.skipif( @@ -30,14 +30,14 @@ def test_impersonate_unauthorized(): def test_impersonate_multiple_transactions(): with adb.impersonate(user="troygreene"): with adb.transaction: - results, _ = adb.cypher_query_async("RETURN 'Doo Wacko !'") + results, _ = adb.cypher_query("RETURN 'Doo Wacko !'") assert results[0][0] == "Doo Wacko !" with adb.transaction: - results, _ = adb.cypher_query_async("SHOW CURRENT USER") + results, _ = adb.cypher_query("SHOW CURRENT USER") assert results[0][0] == "troygreene" - results, _ = adb.cypher_query_async("SHOW CURRENT USER") + results, _ = adb.cypher_query("SHOW CURRENT USER") assert results[0][0] == "neo4j" @@ -47,4 +47,4 @@ def test_impersonate_multiple_transactions(): def test_impersonate_community(): with raises(FeatureNotSupported): with adb.impersonate(user="troygreene"): - _ = adb.cypher_query_async("RETURN 'Gabagoogoo'") + _ = adb.cypher_query("RETURN 'Gabagoogoo'") diff --git a/test/test_hooks.py b/test/test_hooks.py index 06c49247..c77e2845 100644 --- a/test/test_hooks.py +++ b/test/test_hooks.py @@ -23,8 +23,8 @@ def post_delete(self): def test_hooks(): - ht = HookTest(name="k").save_async() - ht.delete_async() + ht = HookTest(name="k").save() + ht.delete() assert "pre_save" in HOOKS_CALLED assert "post_save" in HOOKS_CALLED assert "post_create" in HOOKS_CALLED diff --git a/test/test_indexing.py b/test/test_indexing.py index 9cf90b60..9611df9a 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -17,10 +17,10 @@ class Human(StructuredNodeAsync): def test_unique_error(): - adb.install_labels_async(Human) - Human(name="j1m", age=13).save_async() + adb.install_labels(Human) + Human(name="j1m", age=13).save() try: - Human(name="j1m", age=14).save_async() + Human(name="j1m", age=14).save() except UniqueProperty as e: assert str(e).find("j1m") assert str(e).find("name") @@ -32,22 +32,22 @@ def test_unique_error(): not adb.edition_is_enterprise(), reason="Skipping test for community edition" ) def test_existence_constraint_error(): - adb.cypher_query_async( + adb.cypher_query( "CREATE CONSTRAINT test_existence_constraint FOR (n:Human) REQUIRE n.age IS NOT NULL" ) with raises(ConstraintValidationFailed, match=r"must have the property"): - Human(name="Scarlett").save_async() + Human(name="Scarlett").save() - adb.cypher_query_async("DROP CONSTRAINT test_existence_constraint") + adb.cypher_query("DROP CONSTRAINT test_existence_constraint") def test_optional_properties_dont_get_indexed(): - Human(name="99", age=99).save_async() + Human(name="99", age=99).save() h = Human.nodes.get(age=99) assert h assert h.name == "99" - Human(name="98", age=98).save_async() + Human(name="98", age=98).save() h = Human.nodes.get(age=98) assert h assert h.name == "98" @@ -55,7 +55,7 @@ def test_optional_properties_dont_get_indexed(): def test_escaped_chars(): _name = "sarah:test" - Human(name=_name, age=3).save_async() + Human(name=_name, age=3).save() r = Human.nodes.filter(name=_name) assert r assert r[0].name == _name @@ -71,7 +71,7 @@ class Giraffe(StructuredNodeAsync): __label__ = "Giraffes" name = StringProperty(unique_index=True) - jim = Giraffe(name="timothy").save_async() + jim = Giraffe(name="timothy").save() node = Giraffe.nodes.get(name="timothy") assert node.name == jim.name diff --git a/test/test_issue112.py b/test/test_issue112.py index 295fe239..dd569fdc 100644 --- a/test/test_issue112.py +++ b/test/test_issue112.py @@ -6,8 +6,8 @@ class SomeModel(StructuredNodeAsync): def test_len_relationship(): - t1 = SomeModel().save_async() - t2 = SomeModel().save_async() + t1 = SomeModel().save() + t2 = SomeModel().save() t1.test.connect(t2) l = len(t1.test.all()) diff --git a/test/test_issue283.py b/test/test_issue283.py index 8f4bb29a..d42d3fb8 100644 --- a/test/test_issue283.py +++ b/test/test_issue283.py @@ -88,15 +88,9 @@ def test_automatic_result_resolution(): """ # Create a few entities - A = TechnicalPerson.get_or_create_async( - {"name": "Grumpy", "expertise": "Grumpiness"} - )[0] - B = TechnicalPerson.get_or_create_async({"name": "Happy", "expertise": "Unicorns"})[ - 0 - ] - C = TechnicalPerson.get_or_create_async({"name": "Sleepy", "expertise": "Pillows"})[ - 0 - ] + A = TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"})[0] + B = TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"})[0] + C = TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"})[0] # Add connections A.friends_with.connect(B) @@ -107,9 +101,9 @@ def test_automatic_result_resolution(): # TechnicalPerson (!NOT basePerson!) assert type(A.friends_with[0]) is TechnicalPerson - A.delete_async() - B.delete_async() - C.delete_async() + A.delete() + B.delete() + C.delete() def test_recursive_automatic_result_resolution(): @@ -120,18 +114,12 @@ def test_recursive_automatic_result_resolution(): """ # Create a few entities - A = TechnicalPerson.get_or_create_async( - {"name": "Grumpier", "expertise": "Grumpiness"} - )[0] - B = TechnicalPerson.get_or_create_async( - {"name": "Happier", "expertise": "Grumpiness"} - )[0] - C = TechnicalPerson.get_or_create_async( - {"name": "Sleepier", "expertise": "Pillows"} - )[0] - D = TechnicalPerson.get_or_create_async( - {"name": "Sneezier", "expertise": "Pillows"} - )[0] + A = TechnicalPerson.get_or_create({"name": "Grumpier", "expertise": "Grumpiness"})[ + 0 + ] + B = TechnicalPerson.get_or_create({"name": "Happier", "expertise": "Grumpiness"})[0] + C = TechnicalPerson.get_or_create({"name": "Sleepier", "expertise": "Pillows"})[0] + D = TechnicalPerson.get_or_create({"name": "Sneezier", "expertise": "Pillows"})[0] # Retrieve mixed results, both at the top level and nested L, _ = neomodel.adb.cypher_query( @@ -152,10 +140,10 @@ def test_recursive_automatic_result_resolution(): # Assert that primitive data types remain primitive data types assert issubclass(type(L[0][0][0][1][0][1][0][1][0][0]), basestring) - A.delete_async() - B.delete_async() - C.delete_async() - D.delete_async() + A.delete() + B.delete() + C.delete() + D.delete() def test_validation_with_inheritance_from_db(): @@ -166,21 +154,15 @@ def test_validation_with_inheritance_from_db(): # Create a few entities # Technical Persons - A = TechnicalPerson.get_or_create_async( - {"name": "Grumpy", "expertise": "Grumpiness"} - )[0] - B = TechnicalPerson.get_or_create_async({"name": "Happy", "expertise": "Unicorns"})[ - 0 - ] - C = TechnicalPerson.get_or_create_async({"name": "Sleepy", "expertise": "Pillows"})[ - 0 - ] + A = TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"})[0] + B = TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"})[0] + C = TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"})[0] # Pilot Persons - D = PilotPerson.get_or_create_async( + D = PilotPerson.get_or_create( {"name": "Porco Rosso", "airplane": "Savoia-Marchetti"} )[0] - E = PilotPerson.get_or_create_async( + E = PilotPerson.get_or_create( {"name": "Jack Dalton", "airplane": "Beechcraft Model 18"} )[0] @@ -209,11 +191,11 @@ def test_validation_with_inheritance_from_db(): ) assert type(D.friends_with[0]) is PilotPerson - A.delete_async() - B.delete_async() - C.delete_async() - D.delete_async() - E.delete_async() + A.delete() + B.delete() + C.delete() + D.delete() + E.delete() def test_validation_enforcement_to_db(): @@ -223,26 +205,20 @@ def test_validation_enforcement_to_db(): # Create a few entities # Technical Persons - A = TechnicalPerson.get_or_create_async( - {"name": "Grumpy", "expertise": "Grumpiness"} - )[0] - B = TechnicalPerson.get_or_create_async({"name": "Happy", "expertise": "Unicorns"})[ - 0 - ] - C = TechnicalPerson.get_or_create_async({"name": "Sleepy", "expertise": "Pillows"})[ - 0 - ] + A = TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"})[0] + B = TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"})[0] + C = TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"})[0] # Pilot Persons - D = PilotPerson.get_or_create_async( + D = PilotPerson.get_or_create( {"name": "Porco Rosso", "airplane": "Savoia-Marchetti"} )[0] - E = PilotPerson.get_or_create_async( + E = PilotPerson.get_or_create( {"name": "Jack Dalton", "airplane": "Beechcraft Model 18"} )[0] # Some Person - F = SomePerson(car_color="Blue").save_async() + F = SomePerson(car_color="Blue").save() # TechnicalPersons can befriend PilotPersons and vice-versa and that's fine A.friends_with.connect(B) @@ -257,12 +233,12 @@ def test_validation_enforcement_to_db(): with pytest.raises(ValueError): A.friends_with.connect(F) - A.delete_async() - B.delete_async() - C.delete_async() - D.delete_async() - E.delete_async() - F.delete_async() + A.delete() + B.delete() + C.delete() + D.delete() + E.delete() + F.delete() def test_failed_result_resolution(): @@ -276,12 +252,10 @@ class RandomPerson(BasePerson): randomness = neomodel.FloatProperty(default=random.random) # A Technical Person... - A = TechnicalPerson.get_or_create_async( - {"name": "Grumpy", "expertise": "Grumpiness"} - )[0] + A = TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"})[0] # A Random Person... - B = RandomPerson.get_or_create_async({"name": "Mad Hatter"})[0] + B = RandomPerson.get_or_create({"name": "Mad Hatter"})[0] A.friends_with.connect(B) @@ -290,9 +264,7 @@ class RandomPerson(BasePerson): del neomodel.adb._NODE_CLASS_REGISTRY[frozenset(["RandomPerson", "BasePerson"])] # Now try to instantiate a RandomPerson - A = TechnicalPerson.get_or_create_async( - {"name": "Grumpy", "expertise": "Grumpiness"} - )[0] + A = TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"})[0] with pytest.raises( neomodel.exceptions.NodeClassNotDefined, match=r"Node with labels .* does not resolve to any of the known objects.*", @@ -300,8 +272,8 @@ class RandomPerson(BasePerson): for some_friend in A.friends_with: print(some_friend.name) - A.delete_async() - B.delete_async() + A.delete() + B.delete() def test_node_label_mismatch(): @@ -317,13 +289,9 @@ class UltraTechnicalPerson(SuperTechnicalPerson): ultraness = neomodel.FloatProperty(default=3.1415928) # Create a TechnicalPerson... - A = TechnicalPerson.get_or_create_async( - {"name": "Grumpy", "expertise": "Grumpiness"} - )[0] + A = TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"})[0] # ...that is connected to an UltraTechnicalPerson - F = UltraTechnicalPerson( - name="Chewbaka", expertise="Aarrr wgh ggwaaah" - ).save_async() + F = UltraTechnicalPerson(name="Chewbaka", expertise="Aarrr wgh ggwaaah").save() A.friends_with.connect(F) # Forget about the UltraTechnicalPerson @@ -341,9 +309,7 @@ class UltraTechnicalPerson(SuperTechnicalPerson): # Recall a TechnicalPerson and enumerate its friends. # One of them is UltraTechnicalPerson which would be returned as a valid # node to a friends_with query but is currently unknown to the node class registry. - A = TechnicalPerson.get_or_create_async( - {"name": "Grumpy", "expertise": "Grumpiness"} - )[0] + A = TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"})[0] with pytest.raises(neomodel.exceptions.NodeClassNotDefined): for some_friend in A.friends_with: print(some_friend.name) @@ -373,11 +339,11 @@ def test_relationship_result_resolution(): A query returning a "Relationship" object can now instantiate it to a data model class """ # Test specific data - A = PilotPerson(name="Zantford Granville", airplane="Gee Bee Model R").save_async() - B = PilotPerson(name="Thomas Granville", airplane="Gee Bee Model R").save_async() - C = PilotPerson(name="Robert Granville", airplane="Gee Bee Model R").save_async() - D = PilotPerson(name="Mark Granville", airplane="Gee Bee Model R").save_async() - E = PilotPerson(name="Edward Granville", airplane="Gee Bee Model R").save_async() + A = PilotPerson(name="Zantford Granville", airplane="Gee Bee Model R").save() + B = PilotPerson(name="Thomas Granville", airplane="Gee Bee Model R").save() + C = PilotPerson(name="Robert Granville", airplane="Gee Bee Model R").save() + D = PilotPerson(name="Mark Granville", airplane="Gee Bee Model R").save() + E = PilotPerson(name="Edward Granville", airplane="Gee Bee Model R").save() A.friends_with.connect(B) B.friends_with.connect(C) @@ -415,9 +381,9 @@ class ExtendedSomePerson(SomePerson): ) # Test specific data - A = ExtendedSomePerson(name="Michael Knight", car_color="Black").save_async() - B = ExtendedSomePerson(name="Luke Duke", car_color="Orange").save_async() - C = ExtendedSomePerson(name="Michael Schumacher", car_color="Red").save_async() + A = ExtendedSomePerson(name="Michael Knight", car_color="Black").save() + B = ExtendedSomePerson(name="Luke Duke", car_color="Orange").save() + C = ExtendedSomePerson(name="Michael Schumacher", car_color="Red").save() A.friends_with.connect(B) A.friends_with.connect(C) diff --git a/test/test_issue600.py b/test/test_issue600.py index 5d760bf0..c6d624a4 100644 --- a/test/test_issue600.py +++ b/test/test_issue600.py @@ -57,9 +57,9 @@ class RelationshipDefinerParentLast(neomodel.StructuredNodeAsync): # Test cases def test_relationship_definer_second_sibling(): # Create a few entities - A = RelationshipDefinerSecondSibling.get_or_create_async({})[0] - B = RelationshipDefinerSecondSibling.get_or_create_async({})[0] - C = RelationshipDefinerSecondSibling.get_or_create_async({})[0] + A = RelationshipDefinerSecondSibling.get_or_create({})[0] + B = RelationshipDefinerSecondSibling.get_or_create({})[0] + C = RelationshipDefinerSecondSibling.get_or_create({})[0] # Add connections A.rel_1.connect(B) @@ -67,16 +67,16 @@ def test_relationship_definer_second_sibling(): C.rel_3.connect(A) # Clean up - A.delete_async() - B.delete_async() - C.delete_async() + A.delete() + B.delete() + C.delete() def test_relationship_definer_parent_last(): # Create a few entities - A = RelationshipDefinerParentLast.get_or_create_async({})[0] - B = RelationshipDefinerParentLast.get_or_create_async({})[0] - C = RelationshipDefinerParentLast.get_or_create_async({})[0] + A = RelationshipDefinerParentLast.get_or_create({})[0] + B = RelationshipDefinerParentLast.get_or_create({})[0] + C = RelationshipDefinerParentLast.get_or_create({})[0] # Add connections A.rel_1.connect(B) @@ -84,6 +84,6 @@ def test_relationship_definer_parent_last(): C.rel_3.connect(A) # Clean up - A.delete_async() - B.delete_async() - C.delete_async() + A.delete() + B.delete() + C.delete() diff --git a/test/test_label_drop.py b/test/test_label_drop.py index 5d3dc13a..e62f5caf 100644 --- a/test/test_label_drop.py +++ b/test/test_label_drop.py @@ -12,16 +12,16 @@ class ConstraintAndIndex(StructuredNodeAsync): def test_drop_labels(): - constraints_before = adb.list_constraints_async() - indexes_before = adb.list_indexes_async(exclude_token_lookup=True) + constraints_before = adb.list_constraints() + indexes_before = adb.list_indexes(exclude_token_lookup=True) assert len(constraints_before) > 0 assert len(indexes_before) > 0 - adb.remove_all_labels_async() + adb.remove_all_labels() - constraints = adb.list_constraints_async() - indexes = adb.list_indexes_async(exclude_token_lookup=True) + constraints = adb.list_constraints() + indexes = adb.list_indexes(exclude_token_lookup=True) assert len(constraints) == 0 assert len(indexes) == 0 @@ -34,12 +34,12 @@ def test_drop_labels(): elif constraint["type"] == "NODE_KEY": constraint_type_clause = "NODE KEY" - adb.cypher_query_async( + adb.cypher_query( f'CREATE CONSTRAINT {constraint["name"]} FOR (n:{constraint["labelsOrTypes"][0]}) REQUIRE n.{constraint["properties"][0]} IS {constraint_type_clause}' ) for index in indexes_before: try: - adb.cypher_query_async( + adb.cypher_query( f'CREATE INDEX {index["name"]} FOR (n:{index["labelsOrTypes"][0]}) ON (n.{index["properties"][0]})' ) except ClientError: diff --git a/test/test_label_install.py b/test/test_label_install.py index 7e1c3dc6..abbff9a5 100644 --- a/test/test_label_install.py +++ b/test/test_label_install.py @@ -49,9 +49,9 @@ class SomeNotUniqueNode(StructuredNodeAsync): def test_labels_were_not_installed(): - bob = NodeWithConstraint(name="bob").save_async() - bob2 = NodeWithConstraint(name="bob").save_async() - bob3 = NodeWithConstraint(name="bob").save_async() + bob = NodeWithConstraint(name="bob").save() + bob2 = NodeWithConstraint(name="bob").save() + bob3 = NodeWithConstraint(name="bob").save() assert bob.element_id != bob3.element_id for n in NodeWithConstraint.nodes.all(): @@ -59,16 +59,16 @@ def test_labels_were_not_installed(): def test_install_all(): - adb.drop_constraints_async() - adb.install_labels_async(AbstractNode) + adb.drop_constraints() + adb.install_labels(AbstractNode) # run install all labels - adb.install_all_labels_async() + adb.install_all_labels() - indexes = adb.list_indexes_async() + indexes = adb.list_indexes() index_names = [index["name"] for index in indexes] assert "index_INDEXED_REL_indexed_rel_prop" in index_names - constraints = adb.list_constraints_async() + constraints = adb.list_constraints() constraint_names = [constraint["name"] for constraint in constraints] assert "constraint_unique_NodeWithConstraint_name" in constraint_names assert "constraint_unique_SomeNotUniqueNode_id" in constraint_names @@ -81,21 +81,21 @@ def test_install_label_twice(capsys): expected_std_out = ( "{code: Neo.ClientError.Schema.EquivalentSchemaRuleAlreadyExists}" ) - adb.install_labels_async(AbstractNode) - adb.install_labels_async(AbstractNode) + adb.install_labels(AbstractNode) + adb.install_labels(AbstractNode) - adb.install_labels_async(NodeWithIndex) - adb.install_labels_async(NodeWithIndex, quiet=False) + adb.install_labels(NodeWithIndex) + adb.install_labels(NodeWithIndex, quiet=False) captured = capsys.readouterr() assert expected_std_out in captured.out - adb.install_labels_async(NodeWithConstraint) - adb.install_labels_async(NodeWithConstraint, quiet=False) + adb.install_labels(NodeWithConstraint) + adb.install_labels(NodeWithConstraint, quiet=False) captured = capsys.readouterr() assert expected_std_out in captured.out - adb.install_labels_async(OtherNodeWithRelationship) - adb.install_labels_async(OtherNodeWithRelationship, quiet=False) + adb.install_labels(OtherNodeWithRelationship) + adb.install_labels(OtherNodeWithRelationship, quiet=False) captured = capsys.readouterr() assert expected_std_out in captured.out @@ -109,15 +109,15 @@ class OtherNodeWithUniqueIndexRelationship(StructuredNodeAsync): NodeWithRelationship, "UNIQUE_INDEX_REL", model=UniqueIndexRelationship ) - adb.install_labels_async(OtherNodeWithUniqueIndexRelationship) - adb.install_labels_async(OtherNodeWithUniqueIndexRelationship, quiet=False) + adb.install_labels(OtherNodeWithUniqueIndexRelationship) + adb.install_labels(OtherNodeWithUniqueIndexRelationship, quiet=False) captured = capsys.readouterr() assert expected_std_out in captured.out def test_install_labels_db_property(capsys): - adb.drop_constraints_async() - adb.install_labels_async(SomeNotUniqueNode, quiet=False) + adb.drop_constraints() + adb.install_labels(SomeNotUniqueNode, quiet=False) captured = capsys.readouterr() assert "id" in captured.out # make sure that the id_ constraint doesn't exist @@ -166,10 +166,10 @@ class NodeWithUniqueIndexRelationship(StructuredNodeAsync): model=UniqueIndexRelationshipBis, ) - adb.install_labels_async(UniqueIndexRelationshipBis) - node1 = NodeWithUniqueIndexRelationship().save_async() - node2 = TargetNodeForUniqueIndexRelationship().save_async() - node3 = TargetNodeForUniqueIndexRelationship().save_async() + adb.install_labels(UniqueIndexRelationshipBis) + node1 = NodeWithUniqueIndexRelationship().save() + node2 = TargetNodeForUniqueIndexRelationship().save() + node3 = TargetNodeForUniqueIndexRelationship().save() rel1 = node1.has_rel.connect(node2, {"name": "rel1"}) with pytest.raises( @@ -180,7 +180,7 @@ class NodeWithUniqueIndexRelationship(StructuredNodeAsync): def _drop_constraints_for_label_and_property(label: str = None, property: str = None): - results, meta = adb.cypher_query_async("SHOW CONSTRAINTS") + results, meta = adb.cypher_query("SHOW CONSTRAINTS") results_as_dict = [dict(zip(meta, row)) for row in results] constraint_names = [ constraint @@ -188,6 +188,6 @@ def _drop_constraints_for_label_and_property(label: str = None, property: str = if constraint["labelsOrTypes"] == label and constraint["properties"] == property ] for constraint_name in constraint_names: - adb.cypher_query_async(f"DROP CONSTRAINT {constraint_name}") + adb.cypher_query(f"DROP CONSTRAINT {constraint_name}") return constraint_names diff --git a/test/test_match_api.py b/test/test_match_api.py index f5b43c90..519899a8 100644 --- a/test/test_match_api.py +++ b/test/test_match_api.py @@ -46,7 +46,7 @@ class Extension(StructuredNodeAsync): def test_filter_exclude_via_labels(): - Coffee(name="Java", price=99).save_async() + Coffee(name="Java", price=99).save() node_set = NodeSet(Coffee) qb = QueryBuilder(node_set).build_ast() @@ -60,7 +60,7 @@ def test_filter_exclude_via_labels(): assert results[0].name == "Java" # with filter and exclude - Coffee(name="Kenco", price=3).save_async() + Coffee(name="Kenco", price=3).save() node_set = node_set.filter(price__gt=2).exclude(price__gt=6, name="Java") qb = QueryBuilder(node_set).build_ast() @@ -72,8 +72,8 @@ def test_filter_exclude_via_labels(): def test_simple_has_via_label(): - nescafe = Coffee(name="Nescafe", price=99).save_async() - tesco = Supplier(name="Tesco", delivery_cost=2).save_async() + nescafe = Coffee(name="Nescafe", price=99).save() + tesco = Supplier(name="Tesco", delivery_cost=2).save() nescafe.suppliers.connect(tesco) ns = NodeSet(Coffee).has(suppliers=True) @@ -83,7 +83,7 @@ def test_simple_has_via_label(): assert len(results) == 1 assert results[0].name == "Nescafe" - Coffee(name="nespresso", price=99).save_async() + Coffee(name="nespresso", price=99).save() ns = NodeSet(Coffee).has(suppliers=False) qb = QueryBuilder(ns).build_ast() results = qb._execute() @@ -92,21 +92,21 @@ def test_simple_has_via_label(): def test_get(): - Coffee(name="1", price=3).save_async() + Coffee(name="1", price=3).save() assert Coffee.nodes.get(name="1") with raises(Coffee.DoesNotExist): Coffee.nodes.get(name="2") - Coffee(name="2", price=3).save_async() + Coffee(name="2", price=3).save() with raises(MultipleNodesReturned): Coffee.nodes.get(price=3) def test_simple_traverse_with_filter(): - nescafe = Coffee(name="Nescafe2", price=99).save_async() - tesco = Supplier(name="Sainsburys", delivery_cost=2).save_async() + nescafe = Coffee(name="Nescafe2", price=99).save() + tesco = Supplier(name="Sainsburys", delivery_cost=2).save() nescafe.suppliers.connect(tesco) qb = QueryBuilder(NodeSet(source=nescafe).suppliers.match(since__lt=datetime.now())) @@ -121,10 +121,10 @@ def test_simple_traverse_with_filter(): def test_double_traverse(): - nescafe = Coffee(name="Nescafe plus", price=99).save_async() - tesco = Supplier(name="Asda", delivery_cost=2).save_async() + nescafe = Coffee(name="Nescafe plus", price=99).save() + tesco = Supplier(name="Asda", delivery_cost=2).save() nescafe.suppliers.connect(tesco) - tesco.coffees.connect(Coffee(name="Decafe", price=2).save_async()) + tesco.coffees.connect(Coffee(name="Decafe", price=2).save()) ns = NodeSet(NodeSet(source=nescafe).suppliers.match()).coffees.match() qb = QueryBuilder(ns).build_ast() @@ -136,7 +136,7 @@ def test_double_traverse(): def test_count(): - Coffee(name="Nescafe Gold", price=99).save_async() + Coffee(name="Nescafe Gold", price=99).save() count = QueryBuilder(NodeSet(source=Coffee)).build_ast()._count() assert count > 0 @@ -144,7 +144,7 @@ def test_count(): def test_len_and_iter_and_bool(): iterations = 0 - Coffee(name="Icelands finest").save_async() + Coffee(name="Icelands finest").save() for c in Coffee.nodes: iterations += 1 @@ -159,9 +159,9 @@ def test_slice(): for c in Coffee.nodes: c.delete() - Coffee(name="Icelands finest").save_async() - Coffee(name="Britains finest").save_async() - Coffee(name="Japans finest").save_async() + Coffee(name="Icelands finest").save() + Coffee(name="Britains finest").save() + Coffee(name="Japans finest").save() assert len(list(Coffee.nodes.all()[1:])) == 2 assert len(list(Coffee.nodes.all()[:1])) == 1 @@ -173,9 +173,9 @@ def test_slice(): def test_issue_208(): # calls to match persist across queries. - b = Coffee(name="basics").save_async() - l = Supplier(name="lidl").save_async() - a = Supplier(name="aldi").save_async() + b = Coffee(name="basics").save() + l = Supplier(name="lidl").save() + a = Supplier(name="aldi").save() b.suppliers.connect(l, {"courier": "fedex"}) b.suppliers.connect(a, {"courier": "dhl"}) @@ -185,15 +185,15 @@ def test_issue_208(): def test_issue_589(): - node1 = Extension().save_async() - node2 = Extension().save_async() + node1 = Extension().save() + node2 = Extension().save() node1.extension.connect(node2) assert node2 in node1.extension.all() def test_contains(): - expensive = Coffee(price=1000, name="Pricey").save_async() - asda = Coffee(name="Asda", price=1).save_async() + expensive = Coffee(price=1000, name="Pricey").save() + asda = Coffee(name="Asda", price=1).save() assert expensive in Coffee.nodes.filter(price__gt=999) assert asda not in Coffee.nodes.filter(price__gt=999) @@ -211,9 +211,9 @@ def test_order_by(): for c in Coffee.nodes: c.delete() - c1 = Coffee(name="Icelands finest", price=5).save_async() - c2 = Coffee(name="Britains finest", price=10).save_async() - c3 = Coffee(name="Japans finest", price=35).save_async() + c1 = Coffee(name="Icelands finest", price=5).save() + c2 = Coffee(name="Britains finest", price=10).save() + c3 = Coffee(name="Japans finest", price=35).save() assert Coffee.nodes.order_by("price").all()[0].price == 5 assert Coffee.nodes.order_by("-price").all()[0].price == 35 @@ -236,7 +236,7 @@ def test_order_by(): Coffee.nodes.order_by("id") # Test order by on a relationship - l = Supplier(name="lidl2").save_async() + l = Supplier(name="lidl2").save() l.coffees.connect(c1) l.coffees.connect(c2) l.coffees.connect(c3) @@ -251,10 +251,10 @@ def test_extra_filters(): for c in Coffee.nodes: c.delete() - c1 = Coffee(name="Icelands finest", price=5, id_=1).save_async() - c2 = Coffee(name="Britains finest", price=10, id_=2).save_async() - c3 = Coffee(name="Japans finest", price=35, id_=3).save_async() - c4 = Coffee(name="US extra-fine", price=None, id_=4).save_async() + c1 = Coffee(name="Icelands finest", price=5, id_=1).save() + c2 = Coffee(name="Britains finest", price=10, id_=2).save() + c3 = Coffee(name="Japans finest", price=35, id_=3).save() + c4 = Coffee(name="US extra-fine", price=None, id_=4).save() coffees_5_10 = Coffee.nodes.filter(price__in=[10, 5]).all() assert len(coffees_5_10) == 2, "unexpected number of results" @@ -325,8 +325,8 @@ def test_empty_filters(): for c in Coffee.nodes: c.delete() - c1 = Coffee(name="Super", price=5, id_=1).save_async() - c2 = Coffee(name="Puper", price=10, id_=2).save_async() + c1 = Coffee(name="Super", price=5, id_=1).save() + c2 = Coffee(name="Puper", price=10, id_=2).save() empty_filter = Coffee.nodes.filter() @@ -351,12 +351,12 @@ def test_q_filters(): for c in Coffee.nodes: c.delete() - c1 = Coffee(name="Icelands finest", price=5, id_=1).save_async() - c2 = Coffee(name="Britains finest", price=10, id_=2).save_async() - c3 = Coffee(name="Japans finest", price=35, id_=3).save_async() - c4 = Coffee(name="US extra-fine", price=None, id_=4).save_async() - c5 = Coffee(name="Latte", price=35, id_=5).save_async() - c6 = Coffee(name="Cappuccino", price=35, id_=6).save_async() + c1 = Coffee(name="Icelands finest", price=5, id_=1).save() + c2 = Coffee(name="Britains finest", price=10, id_=2).save() + c3 = Coffee(name="Japans finest", price=35, id_=3).save() + c4 = Coffee(name="US extra-fine", price=None, id_=4).save() + c5 = Coffee(name="Latte", price=35, id_=5).save() + c6 = Coffee(name="Cappuccino", price=35, id_=6).save() coffees_5_10 = Coffee.nodes.filter(Q(price=10) | Q(price=5)).all() assert len(coffees_5_10) == 2, "unexpected number of results" @@ -437,12 +437,12 @@ def test_qbase(): def test_traversal_filter_left_hand_statement(): - nescafe = Coffee(name="Nescafe2", price=99).save_async() - nescafe_gold = Coffee(name="Nescafe gold", price=11).save_async() + nescafe = Coffee(name="Nescafe2", price=99).save() + nescafe_gold = Coffee(name="Nescafe gold", price=11).save() - tesco = Supplier(name="Sainsburys", delivery_cost=3).save_async() - biedronka = Supplier(name="Biedronka", delivery_cost=5).save_async() - lidl = Supplier(name="Lidl", delivery_cost=3).save_async() + tesco = Supplier(name="Sainsburys", delivery_cost=3).save() + biedronka = Supplier(name="Biedronka", delivery_cost=5).save() + lidl = Supplier(name="Lidl", delivery_cost=3).save() nescafe.suppliers.connect(tesco) nescafe_gold.suppliers.connect(biedronka) @@ -456,12 +456,12 @@ def test_traversal_filter_left_hand_statement(): def test_fetch_relations(): - arabica = Species(name="Arabica").save_async() - robusta = Species(name="Robusta").save_async() - nescafe = Coffee(name="Nescafe 1000", price=99).save_async() - nescafe_gold = Coffee(name="Nescafe 1001", price=11).save_async() + arabica = Species(name="Arabica").save() + robusta = Species(name="Robusta").save() + nescafe = Coffee(name="Nescafe 1000", price=99).save() + nescafe_gold = Coffee(name="Nescafe 1001", price=11).save() - tesco = Supplier(name="Sainsburys", delivery_cost=3).save_async() + tesco = Supplier(name="Sainsburys", delivery_cost=3).save() nescafe.suppliers.connect(tesco) nescafe_gold.suppliers.connect(tesco) nescafe.species.connect(arabica) diff --git a/test/test_models.py b/test/test_models.py index 3e804e3f..028150b0 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -42,7 +42,7 @@ def __getitem__(self, item): class Issue233(BaseIssue233): uid = StringProperty(unique_index=True, required=True) - i = Issue233(uid="testgetitem").save_async() + i = Issue233(uid="testgetitem").save() assert i["uid"] == "testgetitem" @@ -53,7 +53,7 @@ def test_issue_72(): def test_required(): try: - User(age=3).save_async() + User(age=3).save() except RequiredProperty: assert True else: @@ -69,7 +69,7 @@ def test_repr_and_str(): def test_get_and_get_or_none(): u = User(email="robin@test.com", age=3) - assert u.save_async() + assert u.save() rob = User.nodes.get(email="robin@test.com") assert rob.email == "robin@test.com" assert rob.age == 3 @@ -83,9 +83,9 @@ def test_get_and_get_or_none(): def test_first_and_first_or_none(): u = User(email="matt@test.com", age=24) - assert u.save_async() + assert u.save() u2 = User(email="tbrady@test.com", age=40) - assert u2.save_async() + assert u2.save() tbrady = User.nodes.order_by("-age").first() assert tbrady.email == "tbrady@test.com" assert tbrady.age == 40 @@ -107,7 +107,7 @@ def test_bare_init_without_save(): def test_save_to_model(): u = User(email="jim@test.com", age=3) - assert u.save_async() + assert u.save() assert u.element_id is not None assert u.email == "jim@test.com" assert u.age == 3 @@ -115,43 +115,43 @@ def test_save_to_model(): def test_save_node_without_properties(): n = NodeWithoutProperty() - assert n.save_async() + assert n.save() assert n.element_id is not None def test_unique(): - adb.install_labels_async(User) - User(email="jim1@test.com", age=3).save_async() + adb.install_labels(User) + User(email="jim1@test.com", age=3).save() with raises(UniqueProperty): - User(email="jim1@test.com", age=3).save_async() + User(email="jim1@test.com", age=3).save() def test_update_unique(): - u = User(email="jimxx@test.com", age=3).save_async() - u.save_async() # this shouldn't fail + u = User(email="jimxx@test.com", age=3).save() + u.save() # this shouldn't fail def test_update(): - user = User(email="jim2@test.com", age=3).save_async() + user = User(email="jim2@test.com", age=3).save() assert user user.email = "jim2000@test.com" - user.save_async() + user.save() jim = User.nodes.get(email="jim2000@test.com") assert jim assert jim.email == "jim2000@test.com" def test_save_through_magic_property(): - user = User(email_alias="blah@test.com", age=8).save_async() + user = User(email_alias="blah@test.com", age=8).save() assert user.email_alias == "blah@test.com" user = User.nodes.get(email="blah@test.com") assert user.email == "blah@test.com" assert user.email_alias == "blah@test.com" - user1 = User(email="blah1@test.com", age=8).save_async() + user1 = User(email="blah1@test.com", age=8).save() assert user1.email_alias == "blah1@test.com" user1.email_alias = "blah2@test.com" - assert user1.save_async() + assert user1.save() user2 = User.nodes.get(email="blah2@test.com") assert user2 @@ -163,12 +163,12 @@ class Customer2(StructuredNodeAsync): def test_not_updated_on_unique_error(): - adb.install_labels_async(Customer2) - Customer2(email="jim@bob.com", age=7).save_async() - test = Customer2(email="jim1@bob.com", age=2).save_async() + adb.install_labels(Customer2) + Customer2(email="jim@bob.com", age=7).save() + test = Customer2(email="jim1@bob.com", age=2).save() test.email = "jim@bob.com" with raises(UniqueProperty): - test.save_async() + test.save() customers = Customer2.nodes.all() assert customers[0].email != customers[1].email assert Customer2.nodes.get(email="jim@bob.com").age == 7 @@ -180,18 +180,18 @@ class Customer3(Customer2): address = StringProperty() assert Customer3.__label__ == "Customer3" - c = Customer3(email="test@test.com").save_async() - assert "customers" in c.labels_async() - assert "Customer3" in c.labels_async() + c = Customer3(email="test@test.com").save() + assert "customers" in c.labels() + assert "Customer3" in c.labels() c = Customer2.nodes.get(email="test@test.com") assert isinstance(c, Customer2) - assert "customers" in c.labels_async() - assert "Customer3" in c.labels_async() + assert "customers" in c.labels() + assert "Customer3" in c.labels() def test_refresh(): - c = Customer2(email="my@email.com", age=16).save_async() + c = Customer2(email="my@email.com", age=16).save() c.my_custom_prop = "value" copy = Customer2.nodes.get(email="my@email.com") copy.age = 20 @@ -199,13 +199,13 @@ def test_refresh(): assert c.age == 16 - c.refresh_async() + c.refresh() assert c.age == 20 assert c.my_custom_prop == "value" c = Customer2.inflate(c.element_id) c.age = 30 - c.refresh_async() + c.refresh() assert c.age == 20 @@ -214,15 +214,15 @@ def test_refresh(): else: c = Customer2.inflate("4:xxxxxx:999") with raises(Customer2.DoesNotExist): - c.refresh_async() + c.refresh() def test_setting_value_to_none(): - c = Customer2(email="alice@bob.com", age=42).save_async() + c = Customer2(email="alice@bob.com", age=42).save() assert c.age is not None c.age = None - c.save_async() + c.save() copy = Customer2.nodes.get(email="alice@bob.com") assert copy.age is None @@ -238,16 +238,16 @@ class Shopper(User): def credit_account(self, amount): self.balance = self.balance + int(amount) - self.save_async() + self.save() - jim = Shopper(name="jimmy", balance=300).save_async() + jim = Shopper(name="jimmy", balance=300).save() jim.credit_account(50) assert Shopper.__label__ == "Shopper" assert jim.balance == 350 assert len(jim.inherited_labels()) == 1 - assert len(jim.labels_async()) == 1 - assert jim.labels_async()[0] == "Shopper" + assert len(jim.labels()) == 1 + assert jim.labels()[0] == "Shopper" def test_inherited_optional_labels(): @@ -261,15 +261,15 @@ class ExtendedOptional(BaseOptional): def credit_account(self, amount): self.balance = self.balance + int(amount) - self.save_async() + self.save() - henry = ExtendedOptional(name="henry", balance=300).save_async() + henry = ExtendedOptional(name="henry", balance=300).save() henry.credit_account(50) assert ExtendedOptional.__label__ == "ExtendedOptional" assert henry.balance == 350 assert len(henry.inherited_labels()) == 2 - assert len(henry.labels_async()) == 2 + assert len(henry.labels()) == 2 assert set(henry.inherited_optional_labels()) == {"Alive", "RewardsMember"} @@ -289,21 +289,21 @@ def credit_account(self, amount): class Shopper2(StructuredNodeAsync, UserMixin, CreditMixin): pass - jim = Shopper2(name="jimmy", balance=300).save_async() + jim = Shopper2(name="jimmy", balance=300).save() jim.credit_account(50) assert Shopper2.__label__ == "Shopper2" assert jim.balance == 350 assert len(jim.inherited_labels()) == 1 - assert len(jim.labels_async()) == 1 - assert jim.labels_async()[0] == "Shopper2" + assert len(jim.labels()) == 1 + assert jim.labels()[0] == "Shopper2" def test_date_property(): class DateTest(StructuredNodeAsync): birthdate = DateProperty() - user = DateTest(birthdate=datetime.now()).save_async() + user = DateTest(birthdate=datetime.now()).save() def test_reserved_property_keys(): diff --git a/test/test_multiprocessing.py b/test/test_multiprocessing.py index 6d48de70..8b637bff 100644 --- a/test/test_multiprocessing.py +++ b/test/test_multiprocessing.py @@ -9,7 +9,7 @@ class ThingyMaBob(StructuredNodeAsync): def thing_create(name): name = str(name) - (thing,) = ThingyMaBob.get_or_create_async({"name": name}) + (thing,) = ThingyMaBob.get_or_create({"name": name}) return thing.name, name diff --git a/test/test_paths.py b/test/test_paths.py index 6e20d949..4b6bf447 100644 --- a/test/test_paths.py +++ b/test/test_paths.py @@ -43,25 +43,25 @@ def test_path_instantiation(): such a mapping is available. """ - c1 = CountryOfOrigin(code="GR").save_async() - c2 = CountryOfOrigin(code="FR").save_async() + c1 = CountryOfOrigin(code="GR").save() + c2 = CountryOfOrigin(code="FR").save() - ct1 = CityOfResidence(name="Athens", country=c1).save_async() - ct2 = CityOfResidence(name="Paris", country=c2).save_async() + ct1 = CityOfResidence(name="Athens", country=c1).save() + ct2 = CityOfResidence(name="Paris", country=c2).save() - p1 = PersonOfInterest(name="Bill", age=22).save_async() + p1 = PersonOfInterest(name="Bill", age=22).save() p1.country.connect(c1) p1.city.connect(ct1) - p2 = PersonOfInterest(name="Jean", age=28).save_async() + p2 = PersonOfInterest(name="Jean", age=28).save() p2.country.connect(c2) p2.city.connect(ct2) - p3 = PersonOfInterest(name="Bo", age=32).save_async() + p3 = PersonOfInterest(name="Bo", age=32).save() p3.country.connect(c1) p3.city.connect(ct2) - p4 = PersonOfInterest(name="Drop", age=16).save_async() + p4 = PersonOfInterest(name="Drop", age=16).save() p4.country.connect(c1) p4.city.connect(ct2) @@ -83,11 +83,11 @@ def test_path_instantiation(): assert type(path_rels[0]) is PersonLivesInCity assert type(path_rels[1]) is StructuredRel - c1.delete_async() - c2.delete_async() - ct1.delete_async() - ct2.delete_async() - p1.delete_async() - p2.delete_async() - p3.delete_async() - p4.delete_async() + c1.delete() + c2.delete() + ct1.delete() + ct2.delete() + p1.delete() + p2.delete() + p3.delete() + p4.delete() diff --git a/test/test_properties.py b/test/test_properties.py index e594cb60..76ae1bcc 100644 --- a/test/test_properties.py +++ b/test/test_properties.py @@ -66,13 +66,13 @@ class TestChoices(StructuredNodeAsync): sex = StringProperty(required=True, choices=SEXES) try: - TestChoices(sex="Z").save_async() + TestChoices(sex="Z").save() except DeflateError as e: assert "choice" in str(e) else: assert False, "DeflateError not raised." - node = TestChoices(sex="M").save_async() + node = TestChoices(sex="M").save() assert node.get_sex_display() == "Male" @@ -191,7 +191,7 @@ class DefaultTestValue(StructuredNodeAsync): a = DefaultTestValue() assert a.name_xx == "jim" - a.save_async() + a.save() def test_default_value_callable(): @@ -201,7 +201,7 @@ def uid_generator(): class DefaultTestValueTwo(StructuredNodeAsync): uid = StringProperty(default=uid_generator, index=True) - a = DefaultTestValueTwo().save_async() + a = DefaultTestValueTwo().save() assert a.uid == "xx" @@ -219,9 +219,9 @@ class DefaultTestValueThree(StructuredNodeAsync): x = DefaultTestValueThree() assert x.uid == "123" - x.save_async() + x.save() assert x.uid == "123" - x.refresh_async() + x.refresh() assert x.uid == "123" @@ -231,7 +231,7 @@ class TestDBNamePropertyNode(StructuredNodeAsync): x = TestDBNamePropertyNode() x.name_ = "jim" - x.save_async() + x.save() # check database property name on low level results, meta = adb.cypher_query("MATCH (n:TestDBNamePropertyNode) RETURN n") @@ -245,7 +245,7 @@ class TestDBNamePropertyNode(StructuredNodeAsync): assert TestDBNamePropertyNode.nodes.filter(name_="jim").all()[0].name_ == x.name_ assert TestDBNamePropertyNode.nodes.get(name_="jim").name_ == x.name_ - x.delete_async() + x.delete() def test_independent_property_name_get_or_create(): @@ -254,9 +254,9 @@ class TestNode(StructuredNodeAsync): name_ = StringProperty(db_property="name", required=True) # create the node - TestNode.get_or_create_async({"uid": 123, "name_": "jim"}) + TestNode.get_or_create({"uid": 123, "name_": "jim"}) # test that the node is retrieved correctly - x = TestNode.get_or_create_async({"uid": 123, "name_": "jim"})[0] + x = TestNode.get_or_create({"uid": 123, "name_": "jim"})[0] # check database property name on low level results, meta = adb.cypher_query("MATCH (n:TestNode) RETURN n") @@ -265,7 +265,7 @@ class TestNode(StructuredNodeAsync): assert "name_" not in node_properties # delete node afterwards - x.delete_async() + x.delete() @mark.parametrize("normalized_class", (NormalizedProperty,)) @@ -341,7 +341,7 @@ def test_uid_property(): class CheckMyId(StructuredNodeAsync): uid = UniqueIdProperty() - cmid = CheckMyId().save_async() + cmid = CheckMyId().save() assert len(cmid.uid) @@ -353,20 +353,20 @@ class ArrayProps(StructuredNodeAsync): def test_array_properties(): # untyped - ap1 = ArrayProps(uid="1", untyped_arr=["Tim", "Bob"]).save_async() + ap1 = ArrayProps(uid="1", untyped_arr=["Tim", "Bob"]).save() assert "Tim" in ap1.untyped_arr ap1 = ArrayProps.nodes.get(uid="1") assert "Tim" in ap1.untyped_arr # typed try: - ArrayProps(uid="2", typed_arr=["a", "b"]).save_async() + ArrayProps(uid="2", typed_arr=["a", "b"]).save() except DeflateError as e: assert "unsaved node" in str(e) else: assert False, "DeflateError not raised." - ap2 = ArrayProps(uid="2", typed_arr=[1, 2]).save_async() + ap2 = ArrayProps(uid="2", typed_arr=[1, 2]).save() assert 1 in ap2.typed_arr ap2 = ArrayProps.nodes.get(uid="2") assert 2 in ap2.typed_arr @@ -381,7 +381,7 @@ def test_indexed_array(): class IndexArray(StructuredNodeAsync): ai = ArrayProperty(unique_index=True) - b = IndexArray(ai=[1, 2]).save_async() + b = IndexArray(ai=[1, 2]).save() c = IndexArray.nodes.get(ai=[1, 2]) assert b.element_id == c.element_id @@ -396,14 +396,14 @@ class ConstrainedTestNode(StructuredNodeAsync): # Create a node with a missing required property with raises(RequiredProperty): x = ConstrainedTestNode(required_property="required", unique_property="unique") - x.save_async() + x.save() # Create a node with a missing unique (but not required) property. x = ConstrainedTestNode() x.required_property = "required" x.unique_required_property = "unique and required" x.unconstrained_property = "no contraints" - x.save_async() + x.save() # check database property name on low level results, meta = adb.cypher_query("MATCH (n:ConstrainedTestNode) RETURN n") @@ -411,7 +411,7 @@ class ConstrainedTestNode(StructuredNodeAsync): assert node_properties["unique_required_property"] == "unique and required" # delete node afterwards - x.delete_async() + x.delete() def test_unique_index_prop_enforced(): @@ -420,22 +420,22 @@ class UniqueNullableNameNode(StructuredNodeAsync): # Nameless x = UniqueNullableNameNode() - x.save_async() + x.save() y = UniqueNullableNameNode() - y.save_async() + y.save() # Named z = UniqueNullableNameNode(name="named") - z.save_async() + z.save() with raises(UniqueProperty): a = UniqueNullableNameNode(name="named") - a.save_async() + a.save() # Check nodes are in database results, meta = adb.cypher_query("MATCH (n:UniqueNullableNameNode) RETURN n") assert len(results) == 3 # Delete nodes afterwards - x.delete_async() - y.delete_async() - z.delete_async() + x.delete() + y.delete() + z.delete() diff --git a/test/test_relationship_models.py b/test/test_relationship_models.py index 59669ae9..2377caf4 100644 --- a/test/test_relationship_models.py +++ b/test/test_relationship_models.py @@ -42,8 +42,8 @@ class Stoat(StructuredNodeAsync): def test_either_connect_with_rel_model(): - paul = Badger(name="Paul").save_async() - tom = Badger(name="Tom").save_async() + paul = Badger(name="Paul").save() + tom = Badger(name="Tom").save() # creating rels new_rel = tom.friend.disconnect(paul) @@ -64,8 +64,8 @@ def test_either_connect_with_rel_model(): def test_direction_connect_with_rel_model(): - paul = Badger(name="Paul the badger").save_async() - ian = Stoat(name="Ian the stoat").save_async() + paul = Badger(name="Paul the badger").save() + ian = Stoat(name="Ian the stoat").save() rel = ian.hates.connect(paul, {"reason": "thinks paul should bath more often"}) assert isinstance(rel.since, datetime) @@ -104,9 +104,9 @@ def test_direction_connect_with_rel_model(): def test_traversal_where_clause(): - phill = Badger(name="Phill the badger").save_async() - tim = Badger(name="Tim the badger").save_async() - bob = Badger(name="Bob the badger").save_async() + phill = Badger(name="Phill the badger").save() + tim = Badger(name="Tim the badger").save() + bob = Badger(name="Bob the badger").save() rel = tim.friend.connect(bob) now = datetime.now(pytz.utc) assert rel.since < now @@ -118,8 +118,8 @@ def test_traversal_where_clause(): def test_multiple_rels_exist_issue_223(): # check a badger can dislike a stoat for multiple reasons - phill = Badger(name="Phill").save_async() - ian = Stoat(name="Stoat").save_async() + phill = Badger(name="Phill").save() + ian = Stoat(name="Stoat").save() rel_a = phill.hates.connect(ian, {"reason": "a"}) rel_b = phill.hates.connect(ian, {"reason": "b"}) @@ -131,8 +131,8 @@ def test_multiple_rels_exist_issue_223(): def test_retrieve_all_rels(): - tom = Badger(name="tom").save_async() - ian = Stoat(name="ian").save_async() + tom = Badger(name="tom").save() + ian = Stoat(name="ian").save() rel_a = tom.hates.connect(ian, {"reason": "a"}) rel_b = tom.hates.connect(ian, {"reason": "b"}) @@ -147,8 +147,8 @@ def test_save_hook_on_rel_model(): HOOKS_CALLED["pre_save"] = 0 HOOKS_CALLED["post_save"] = 0 - paul = Badger(name="PaulB").save_async() - ian = Stoat(name="IanS").save_async() + paul = Badger(name="PaulB").save() + ian = Stoat(name="IanS").save() rel = ian.hates.connect(paul, {"reason": "yadda yadda"}) rel.save() diff --git a/test/test_relationships.py b/test/test_relationships.py index 75c98c90..4ee4b632 100644 --- a/test/test_relationships.py +++ b/test/test_relationships.py @@ -42,8 +42,8 @@ def special_power(self): def test_actions_on_deleted_node(): - u = PersonWithRels(name="Jim2", age=3).save_async() - u.delete_async() + u = PersonWithRels(name="Jim2", age=3).save() + u.delete() with raises(ValueError): u.is_from.connect(None) @@ -51,14 +51,14 @@ def test_actions_on_deleted_node(): u.is_from.get() with raises(ValueError): - u.save_async() + u.save() def test_bidirectional_relationships(): - u = PersonWithRels(name="Jim", age=3).save_async() + u = PersonWithRels(name="Jim", age=3).save() assert u - de = Country(code="DE").save_async() + de = Country(code="DE").save() assert de assert not u.is_from @@ -82,15 +82,15 @@ def test_bidirectional_relationships(): def test_either_direction_connect(): - rey = PersonWithRels(name="Rey", age=3).save_async() - sakis = PersonWithRels(name="Sakis", age=3).save_async() + rey = PersonWithRels(name="Rey", age=3).save() + sakis = PersonWithRels(name="Sakis", age=3).save() rey.knows.connect(sakis) assert rey.knows.is_connected(sakis) assert sakis.knows.is_connected(rey) sakis.knows.connect(rey) - result, _ = sakis.cypher_async( + result, _ = sakis.cypher( f"""MATCH (us), (them) WHERE {adb.get_id_method()}(us)=$self and {adb.get_id_method()}(them)=$them MATCH (us)-[r:KNOWS]-(them) RETURN COUNT(r)""", @@ -106,10 +106,10 @@ def test_either_direction_connect(): def test_search_and_filter_and_exclude(): - fred = PersonWithRels(name="Fred", age=13).save_async() - zz = Country(code="ZZ").save_async() - zx = Country(code="ZX").save_async() - zt = Country(code="ZY").save_async() + fred = PersonWithRels(name="Fred", age=13).save() + zz = Country(code="ZZ").save() + zx = Country(code="ZX").save() + zt = Country(code="ZY").save() fred.is_from.connect(zz) fred.is_from.connect(zx) fred.is_from.connect(zt) @@ -130,21 +130,21 @@ def test_search_and_filter_and_exclude(): def test_custom_methods(): - u = PersonWithRels(name="Joe90", age=13).save_async() + u = PersonWithRels(name="Joe90", age=13).save() assert u.special_power() == "I have no powers" - u = SuperHero(name="Joe91", age=13, power="xxx").save_async() + u = SuperHero(name="Joe91", age=13, power="xxx").save() assert u.special_power() == "I have powers" assert u.special_name == "Joe91" def test_valid_reconnection(): - p = PersonWithRels(name="ElPresidente", age=93).save_async() + p = PersonWithRels(name="ElPresidente", age=93).save() assert p - pp = PersonWithRels(name="TheAdversary", age=33).save_async() + pp = PersonWithRels(name="TheAdversary", age=33).save() assert pp - c = Country(code="CU").save_async() + c = Country(code="CU").save() assert c c.president.connect(p) @@ -160,16 +160,16 @@ def test_valid_reconnection(): def test_valid_replace(): - brady = PersonWithRels(name="Tom Brady", age=40).save_async() + brady = PersonWithRels(name="Tom Brady", age=40).save() assert brady - gronk = PersonWithRels(name="Rob Gronkowski", age=28).save_async() + gronk = PersonWithRels(name="Rob Gronkowski", age=28).save() assert gronk - colbert = PersonWithRels(name="Stephen Colbert", age=53).save_async() + colbert = PersonWithRels(name="Stephen Colbert", age=53).save() assert colbert - hanks = PersonWithRels(name="Tom Hanks", age=61).save_async() + hanks = PersonWithRels(name="Tom Hanks", age=61).save() assert hanks brady.knows.connect(gronk) @@ -186,13 +186,13 @@ def test_valid_replace(): def test_props_relationship(): - u = PersonWithRels(name="Mar", age=20).save_async() + u = PersonWithRels(name="Mar", age=20).save() assert u - c = Country(code="AT").save_async() + c = Country(code="AT").save() assert c - c2 = Country(code="LA").save_async() + c2 = Country(code="LA").save() assert c2 with raises(NotImplementedError): diff --git a/test/test_relative_relationships.py b/test/test_relative_relationships.py index 81cca8fd..7bd9bc21 100644 --- a/test/test_relative_relationships.py +++ b/test/test_relative_relationships.py @@ -9,10 +9,10 @@ class Cat(StructuredNodeAsync): def test_relative_relationship(): - a = Cat(name="snufkin").save_async() + a = Cat(name="snufkin").save() assert a - c = Country(code="MG").save_async() + c = Country(code="MG").save() assert c # connecting an instance of the class defined above diff --git a/test/test_scripts.py b/test/test_scripts.py index cd182bb8..ccc24420 100644 --- a/test/test_scripts.py +++ b/test/test_scripts.py @@ -41,7 +41,7 @@ def test_neomodel_install_labels(): ) assert result.returncode == 0 assert "Setting up indexes and constraints" in result.stdout - constraints = adb.list_constraints_async() + constraints = adb.list_constraints() parsed_constraints = [ (element["type"], element["labelsOrTypes"], element["properties"]) for element in constraints @@ -53,7 +53,7 @@ def test_neomodel_install_labels(): ["REL"], ["some_unique_property"], ) in parsed_constraints - indexes = adb.lise_indexes_async() + indexes = adb.list_indexes() parsed_indexes = [ (element["labelsOrTypes"], element["properties"]) for element in indexes ] @@ -81,8 +81,8 @@ def test_neomodel_remove_labels(): "Dropping unique constraint and index on label ScriptsTestNode" in result.stdout ) assert result.returncode == 0 - constraints = adb.list_constraints_async() - indexes = adb.lise_indexes_async(exclude_token_lookup=True) + constraints = adb.list_constraints() + indexes = adb.list_indexes(exclude_token_lookup=True) assert len(constraints) == 0 assert len(indexes) == 0 @@ -98,13 +98,13 @@ def test_neomodel_inspect_database(): assert "usage: neomodel_inspect_database" in result.stdout assert result.returncode == 0 - adb.clear_neo4j_database_async() - adb.install_labels_async(ScriptsTestNode) - adb.install_labels_async(ScriptsTestRel) + adb.clear_neo4j_database() + adb.install_labels(ScriptsTestNode) + adb.install_labels(ScriptsTestRel) # Create a few nodes and a rel, with indexes and constraints - node1 = ScriptsTestNode(personal_id="1", name="test").save_async() - node2 = ScriptsTestNode(personal_id="2", name="test").save_async() + node1 = ScriptsTestNode(personal_id="1", name="test").save() + node2 = ScriptsTestNode(personal_id="2", name="test").save() node1.rel.connect(node2, {"some_unique_property": "1", "some_index_property": "2"}) # Create a node with all the parsable property types diff --git a/test/test_transactions.py b/test/test_transactions.py index f7780c6d..b8d12520 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -15,16 +15,16 @@ def test_rollback_and_commit_transaction(): for p in APerson.nodes: p.delete() - APerson(name="Roger").save_async() + APerson(name="Roger").save() adb.begin() - APerson(name="Terry S").save_async() + APerson(name="Terry S").save() adb.rollback() assert len(APerson.nodes) == 1 adb.begin() - APerson(name="Terry S").save_async() + APerson(name="Terry S").save() adb.commit() assert len(APerson.nodes) == 2 @@ -33,11 +33,11 @@ def test_rollback_and_commit_transaction(): @adb.transaction def in_a_tx(*names): for n in names: - APerson(name=n).save_async() + APerson(name=n).save() def test_transaction_decorator(): - adb.install_labels_async(APerson) + adb.install_labels(APerson) for p in APerson.nodes: p.delete() @@ -54,13 +54,13 @@ def test_transaction_decorator(): def test_transaction_as_a_context(): with adb.transaction: - APerson(name="Tim").save_async() + APerson(name="Tim").save() assert APerson.nodes.filter(name="Tim") with raises(UniqueProperty): with adb.transaction: - APerson(name="Tim").save_async() + APerson(name="Tim").save() def test_query_inside_transaction(): @@ -68,14 +68,14 @@ def test_query_inside_transaction(): p.delete() with adb.transaction: - APerson(name="Alice").save_async() - APerson(name="Bob").save_async() + APerson(name="Alice").save() + APerson(name="Bob").save() assert len([p.name for p in APerson.nodes]) == 2 def test_read_transaction(): - APerson(name="Johnny").save_async() + APerson(name="Johnny").save() with adb.read_transaction: people = APerson.nodes.all() @@ -84,13 +84,13 @@ def test_read_transaction(): with raises(TransactionError): with adb.read_transaction: with raises(ClientError) as e: - APerson(name="Gina").save_async() + APerson(name="Gina").save() assert e.value.code == "Neo.ClientError.Statement.AccessMode" def test_write_transaction(): with adb.write_transaction: - APerson(name="Amelia").save_async() + APerson(name="Amelia").save() amelia = APerson.nodes.get(name="Amelia") assert amelia @@ -107,7 +107,7 @@ def double_transaction(): @adb.transaction.with_bookmark def in_a_tx(*names): for n in names: - APerson(name=n).save_async() + APerson(name=n).save() def test_bookmark_transaction_decorator(): @@ -128,14 +128,14 @@ def test_bookmark_transaction_decorator(): def test_bookmark_transaction_as_a_context(): with adb.transaction as transaction: - APerson(name="Tanya").save_async() + APerson(name="Tanya").save() assert isinstance(transaction.last_bookmark, Bookmarks) assert APerson.nodes.filter(name="Tanya") with raises(UniqueProperty): with adb.transaction as transaction: - APerson(name="Tanya").save_async() + APerson(name="Tanya").save() assert not hasattr(transaction, "last_bookmark") @@ -174,8 +174,8 @@ def test_query_inside_bookmark_transaction(): p.delete() with adb.transaction as transaction: - APerson(name="Alice").save_async() - APerson(name="Bob").save_async() + APerson(name="Alice").save() + APerson(name="Bob").save() assert len([p.name for p in APerson.nodes]) == 2 From 26a63c4b5b371aea1b16107833f86ba2df210535 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Tue, 12 Dec 2023 10:21:26 +0100 Subject: [PATCH 06/73] Move Async in class name to prefix --- neomodel/__init__.py | 3 ++- neomodel/_async/core.py | 20 ++++++++-------- neomodel/_sync/core.py | 20 ++++++++-------- neomodel/contrib/semi_structured.py | 4 ++-- neomodel/match.py | 16 ++++++------- test/async_/test_cypher.py | 8 +++---- test/sync/test_cypher.py | 8 +++---- test/test_alias.py | 4 ++-- test/test_batch.py | 10 ++++---- test/test_cardinality.py | 12 +++++----- test/test_connection.py | 4 ++-- test/test_contrib/test_spatial_properties.py | 6 ++--- test/test_database_management.py | 6 ++--- test/test_exceptions.py | 4 ++-- test/test_hooks.py | 4 ++-- test/test_indexing.py | 6 ++--- test/test_issue112.py | 4 ++-- test/test_issue283.py | 4 ++-- test/test_issue600.py | 4 ++-- test/test_label_drop.py | 4 ++-- test/test_label_install.py | 24 ++++++++++---------- test/test_match_api.py | 10 ++++---- test/test_migration_neo4j_5.py | 6 ++--- test/test_models.py | 24 ++++++++++---------- test/test_multiprocessing.py | 4 ++-- test/test_paths.py | 8 +++---- test/test_properties.py | 24 ++++++++++---------- test/test_relationship_models.py | 6 ++--- test/test_relationships.py | 6 ++--- test/test_relative_relationships.py | 4 ++-- test/test_scripts.py | 4 ++-- test/test_transactions.py | 4 ++-- 32 files changed, 138 insertions(+), 137 deletions(-) diff --git a/neomodel/__init__.py b/neomodel/__init__.py index 5cf8ccdc..e6380f67 100644 --- a/neomodel/__init__.py +++ b/neomodel/__init__.py @@ -1,7 +1,7 @@ # pep8: noqa # TODO : Check imports here from neomodel._async.core import ( - StructuredNodeAsync, + AsyncStructuredNode, change_neo4j_password, clear_neo4j_database, drop_constraints, @@ -10,6 +10,7 @@ install_labels, remove_all_labels, ) +from neomodel._sync.core import StructuredNode from neomodel.cardinality import One, OneOrMore, ZeroOrMore, ZeroOrOne from neomodel.exceptions import * from neomodel.match import EITHER, INCOMING, OUTGOING, NodeSet, Traversal diff --git a/neomodel/_async/core.py b/neomodel/_async/core.py index 7225816f..e2c07e66 100644 --- a/neomodel/_async/core.py +++ b/neomodel/_async/core.py @@ -212,15 +212,15 @@ def transaction(self): """ Returns the current transaction object """ - return TransactionProxyAsync(self) + return AsyncTransactionProxy(self) @property def write_transaction(self): - return TransactionProxyAsync(self, access_mode="WRITE") + return AsyncTransactionProxy(self, access_mode="WRITE") @property def read_transaction(self): - return TransactionProxyAsync(self, access_mode="READ") + return AsyncTransactionProxy(self, access_mode="READ") def impersonate(self, user: str) -> "ImpersonationHandler": """All queries executed within this context manager will be executed as impersonated user @@ -661,7 +661,7 @@ def subsub(cls): # recursively return all subclasses stdout.write("Setting up indexes and constraints...\n\n") i = 0 - for cls in subsub(StructuredNodeAsync): + for cls in subsub(AsyncStructuredNode): stdout.write(f"Found {cls.__module__}.{cls.__name__}\n") await install_labels(cls, quiet=False, stdout=stdout) i += 1 @@ -903,7 +903,7 @@ async def install_all_labels(stdout=None): await adb.install_all_labels(stdout) -class TransactionProxyAsync: +class AsyncTransactionProxy: bookmarks: Optional[Bookmarks] = None def __init__(self, db: AsyncDatabase, access_mode=None): @@ -938,7 +938,7 @@ def wrapper(*args, **kwargs): @property def with_bookmark(self): - return BookmarkingTransactionProxyAsync(self.db, self.access_mode) + return BookmarkingAsyncTransactionProxy(self.db, self.access_mode) class ImpersonationHandler: @@ -965,7 +965,7 @@ def wrapper(*args, **kwargs): return wrapper -class BookmarkingTransactionProxyAsync(TransactionProxyAsync): +class BookmarkingAsyncTransactionProxy(AsyncTransactionProxy): def __call__(self, func): def wrapper(*args, **kwargs): self.bookmarks = kwargs.pop("bookmarks", None) @@ -1067,7 +1067,7 @@ def build_class_registry(cls): NodeBase = NodeMeta("NodeBase", (PropertyManager,), {"__abstract_node__": True}) -class StructuredNodeAsync(NodeBase): +class AsyncStructuredNode(NodeBase): """ Base class for all node definitions to inherit from. @@ -1091,7 +1091,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def __eq__(self, other): - if not isinstance(other, (StructuredNodeAsync,)): + if not isinstance(other, (AsyncStructuredNode,)): return False if hasattr(self, "element_id") and hasattr(other, "element_id"): return self.element_id == other.element_id @@ -1168,7 +1168,7 @@ def _build_merge_query( query = f"UNWIND $merge_params as params\n MERGE ({n_merge})\n " else: # validate relationship - if not isinstance(relationship.source, StructuredNodeAsync): + if not isinstance(relationship.source, AsyncStructuredNode): raise ValueError( f"relationship source [{repr(relationship.source)}] is not a StructuredNode" ) diff --git a/neomodel/_sync/core.py b/neomodel/_sync/core.py index 7b408705..14eb4a14 100644 --- a/neomodel/_sync/core.py +++ b/neomodel/_sync/core.py @@ -212,15 +212,15 @@ def transaction(self): """ Returns the current transaction object """ - return TransactionProxyAsync(self) + return TransactionProxy(self) @property def write_transaction(self): - return TransactionProxyAsync(self, access_mode="WRITE") + return TransactionProxy(self, access_mode="WRITE") @property def read_transaction(self): - return TransactionProxyAsync(self, access_mode="READ") + return TransactionProxy(self, access_mode="READ") def impersonate(self, user: str) -> "ImpersonationHandler": """All queries executed within this context manager will be executed as impersonated user @@ -661,7 +661,7 @@ def subsub(cls): # recursively return all subclasses stdout.write("Setting up indexes and constraints...\n\n") i = 0 - for cls in subsub(StructuredNodeAsync): + for cls in subsub(StructuredNode): stdout.write(f"Found {cls.__module__}.{cls.__name__}\n") install_labels(cls, quiet=False, stdout=stdout) i += 1 @@ -903,7 +903,7 @@ def install_all_labels(stdout=None): adb.install_all_labels(stdout) -class TransactionProxyAsync: +class TransactionProxy: bookmarks: Optional[Bookmarks] = None def __init__(self, db: Database, access_mode=None): @@ -938,7 +938,7 @@ def wrapper(*args, **kwargs): @property def with_bookmark(self): - return BookmarkingTransactionProxyAsync(self.db, self.access_mode) + return BookmarkingAsyncTransactionProxy(self.db, self.access_mode) class ImpersonationHandler: @@ -965,7 +965,7 @@ def wrapper(*args, **kwargs): return wrapper -class BookmarkingTransactionProxyAsync(TransactionProxyAsync): +class BookmarkingAsyncTransactionProxy(TransactionProxy): def __call__(self, func): def wrapper(*args, **kwargs): self.bookmarks = kwargs.pop("bookmarks", None) @@ -1067,7 +1067,7 @@ def build_class_registry(cls): NodeBase = NodeMeta("NodeBase", (PropertyManager,), {"__abstract_node__": True}) -class StructuredNodeAsync(NodeBase): +class StructuredNode(NodeBase): """ Base class for all node definitions to inherit from. @@ -1091,7 +1091,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def __eq__(self, other): - if not isinstance(other, (StructuredNodeAsync,)): + if not isinstance(other, (StructuredNode,)): return False if hasattr(self, "element_id") and hasattr(other, "element_id"): return self.element_id == other.element_id @@ -1168,7 +1168,7 @@ def _build_merge_query( query = f"UNWIND $merge_params as params\n MERGE ({n_merge})\n " else: # validate relationship - if not isinstance(relationship.source, StructuredNodeAsync): + if not isinstance(relationship.source, StructuredNode): raise ValueError( f"relationship source [{repr(relationship.source)}] is not a StructuredNode" ) diff --git a/neomodel/contrib/semi_structured.py b/neomodel/contrib/semi_structured.py index 580514ba..869763dd 100644 --- a/neomodel/contrib/semi_structured.py +++ b/neomodel/contrib/semi_structured.py @@ -1,9 +1,9 @@ -from neomodel._async.core import StructuredNodeAsync +from neomodel._async.core import AsyncStructuredNode from neomodel.exceptions import DeflateConflict, InflateConflict from neomodel.util import _get_node_properties -class SemiStructuredNode(StructuredNodeAsync): +class SemiStructuredNode(AsyncStructuredNode): """ A base class allowing properties to be stored on a node that aren't specified in its definition. Conflicting properties are signaled with the diff --git a/neomodel/match.py b/neomodel/match.py index 46e55f40..65ad99b9 100644 --- a/neomodel/match.py +++ b/neomodel/match.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from typing import Optional -from neomodel._async.core import StructuredNodeAsync, adb +from neomodel._async.core import AsyncStructuredNode, adb from neomodel.exceptions import MultipleNodesReturned from neomodel.match_q import Q, QBase from neomodel.properties import AliasProperty @@ -382,7 +382,7 @@ def build_source(self, source): return self.build_traversal(source) if isinstance(source, NodeSet): if inspect.isclass(source.source) and issubclass( - source.source, StructuredNodeAsync + source.source, AsyncStructuredNode ): ident = self.build_label(source.source.__label__.lower(), source.source) else: @@ -402,7 +402,7 @@ def build_source(self, source): ) return ident - if isinstance(source, StructuredNodeAsync): + if isinstance(source, AsyncStructuredNode): return self.build_node(source) raise ValueError("Unknown source type " + repr(source)) @@ -747,7 +747,7 @@ def __nonzero__(self): return self.query_cls(self).build_ast()._count() > 0 def __contains__(self, obj): - if isinstance(obj, StructuredNodeAsync): + if isinstance(obj, AsyncStructuredNode): if hasattr(obj, "element_id") and obj.element_id is not None: return self.query_cls(self).build_ast()._contains(obj.element_id) raise ValueError("Unsaved node: " + repr(obj)) @@ -791,9 +791,9 @@ def __init__(self, source): self.source = source # could be a Traverse object or a node class if isinstance(source, Traversal): self.source_class = source.target_class - elif inspect.isclass(source) and issubclass(source, StructuredNodeAsync): + elif inspect.isclass(source) and issubclass(source, AsyncStructuredNode): self.source_class = source - elif isinstance(source, StructuredNodeAsync): + elif isinstance(source, AsyncStructuredNode): self.source_class = source.__class__ else: raise ValueError("Bad source for nodeset " + repr(source)) @@ -995,9 +995,9 @@ def __init__(self, source, name, definition): if isinstance(source, Traversal): self.source_class = source.target_class - elif inspect.isclass(source) and issubclass(source, StructuredNodeAsync): + elif inspect.isclass(source) and issubclass(source, AsyncStructuredNode): self.source_class = source - elif isinstance(source, StructuredNodeAsync): + elif isinstance(source, AsyncStructuredNode): self.source_class = source.__class__ elif isinstance(source, NodeSet): self.source_class = source.source_class diff --git a/test/async_/test_cypher.py b/test/async_/test_cypher.py index 190b4f9e..f3cb6d7d 100644 --- a/test/async_/test_cypher.py +++ b/test/async_/test_cypher.py @@ -6,21 +6,21 @@ from numpy import ndarray from pandas import DataFrame, Series -from neomodel import StringProperty, StructuredNodeAsync +from neomodel import AsyncStructuredNode, StringProperty from neomodel._async.core import adb -class User2(StructuredNodeAsync): +class User2(AsyncStructuredNode): name = StringProperty() email = StringProperty() -class UserPandas(StructuredNodeAsync): +class UserPandas(AsyncStructuredNode): name = StringProperty() email = StringProperty() -class UserNP(StructuredNodeAsync): +class UserNP(AsyncStructuredNode): name = StringProperty() email = StringProperty() diff --git a/test/sync/test_cypher.py b/test/sync/test_cypher.py index 2875d640..a304c9fb 100644 --- a/test/sync/test_cypher.py +++ b/test/sync/test_cypher.py @@ -6,21 +6,21 @@ from numpy import ndarray from pandas import DataFrame, Series -from neomodel import StringProperty, StructuredNodeAsync +from neomodel import StructuredNode, StringProperty from neomodel._async.core import adb -class User2(StructuredNodeAsync): +class User2(StructuredNode): name = StringProperty() email = StringProperty() -class UserPandas(StructuredNodeAsync): +class UserPandas(StructuredNode): name = StringProperty() email = StringProperty() -class UserNP(StructuredNodeAsync): +class UserNP(StructuredNode): name = StringProperty() email = StringProperty() diff --git a/test/test_alias.py b/test/test_alias.py index c0c4877f..6f810b03 100644 --- a/test/test_alias.py +++ b/test/test_alias.py @@ -1,4 +1,4 @@ -from neomodel import AliasProperty, StringProperty, StructuredNodeAsync +from neomodel import AliasProperty, AsyncStructuredNode, StringProperty class MagicProperty(AliasProperty): @@ -6,7 +6,7 @@ def setup(self): self.owner.setup_hook_called = True -class AliasTestNode(StructuredNodeAsync): +class AliasTestNode(AsyncStructuredNode): name = StringProperty(unique_index=True) full_name = AliasProperty(to="name") long_name = MagicProperty(to="name") diff --git a/test/test_batch.py b/test/test_batch.py index 3085de3d..fc582509 100644 --- a/test/test_batch.py +++ b/test/test_batch.py @@ -1,11 +1,11 @@ from pytest import raises from neomodel import ( + AsyncStructuredNode, IntegerProperty, RelationshipFrom, RelationshipTo, StringProperty, - StructuredNodeAsync, UniqueIdProperty, config, ) @@ -14,7 +14,7 @@ config.AUTO_INSTALL_LABELS = True -class UniqueUser(StructuredNodeAsync): +class UniqueUser(AsyncStructuredNode): uid = UniqueIdProperty() name = StringProperty() age = IntegerProperty() @@ -31,7 +31,7 @@ def test_unique_id_property_batch(): assert users[1].uid -class Customer(StructuredNodeAsync): +class Customer(AsyncStructuredNode): email = StringProperty(unique_index=True, required=True) age = IntegerProperty(index=True) @@ -97,12 +97,12 @@ def test_batch_index_violation(): assert not Customer.nodes.filter(email="jim7@aol.com") -class Dog(StructuredNodeAsync): +class Dog(AsyncStructuredNode): name = StringProperty(required=True) owner = RelationshipTo("Person", "owner") -class Person(StructuredNodeAsync): +class Person(AsyncStructuredNode): name = StringProperty(unique_index=True) pets = RelationshipFrom("Dog", "owner") diff --git a/test/test_cardinality.py b/test/test_cardinality.py index b21e652e..8a83c3ee 100644 --- a/test/test_cardinality.py +++ b/test/test_cardinality.py @@ -1,6 +1,7 @@ from pytest import raises from neomodel import ( + AsyncStructuredNode, AttemptedCardinalityViolation, CardinalityViolation, IntegerProperty, @@ -8,26 +9,25 @@ OneOrMore, RelationshipTo, StringProperty, - StructuredNodeAsync, ZeroOrMore, ZeroOrOne, adb, ) -class HairDryer(StructuredNodeAsync): +class HairDryer(AsyncStructuredNode): version = IntegerProperty() -class ScrewDriver(StructuredNodeAsync): +class ScrewDriver(AsyncStructuredNode): version = IntegerProperty() -class Car(StructuredNodeAsync): +class Car(AsyncStructuredNode): version = IntegerProperty() -class Monkey(StructuredNodeAsync): +class Monkey(AsyncStructuredNode): name = StringProperty() dryers = RelationshipTo("HairDryer", "OWNS_DRYER", cardinality=ZeroOrMore) driver = RelationshipTo("ScrewDriver", "HAS_SCREWDRIVER", cardinality=ZeroOrOne) @@ -35,7 +35,7 @@ class Monkey(StructuredNodeAsync): toothbrush = RelationshipTo("ToothBrush", "HAS_TOOTHBRUSH", cardinality=One) -class ToothBrush(StructuredNodeAsync): +class ToothBrush(AsyncStructuredNode): name = StringProperty() diff --git a/test/test_connection.py b/test/test_connection.py index fb5fad42..4bb0f091 100644 --- a/test/test_connection.py +++ b/test/test_connection.py @@ -5,7 +5,7 @@ from neo4j import GraphDatabase from neo4j.debug import watch -from neomodel import StringProperty, StructuredNodeAsync, adb, config +from neomodel import AsyncStructuredNode, StringProperty, adb, config from neomodel.conftest import NEO4J_PASSWORD, NEO4J_URL, NEO4J_USERNAME @@ -37,7 +37,7 @@ def get_current_database_name() -> str: return results_as_dict[0]["name"] -class Pastry(StructuredNodeAsync): +class Pastry(AsyncStructuredNode): name = StringProperty(unique_index=True) diff --git a/test/test_contrib/test_spatial_properties.py b/test/test_contrib/test_spatial_properties.py index b7986b58..89b8641a 100644 --- a/test/test_contrib/test_spatial_properties.py +++ b/test/test_contrib/test_spatial_properties.py @@ -166,7 +166,7 @@ def get_some_point(): (random.random(), random.random()) ) - class LocalisableEntity(neomodel.StructuredNodeAsync): + class LocalisableEntity(neomodel.AsyncStructuredNode): """ A very simple entity to try out the default value assignment. """ @@ -200,7 +200,7 @@ def test_array_of_points(): :return: """ - class AnotherLocalisableEntity(neomodel.StructuredNodeAsync): + class AnotherLocalisableEntity(neomodel.AsyncStructuredNode): """ A very simple entity with an array of locations """ @@ -242,7 +242,7 @@ def test_simple_storage_retrieval(): :return: """ - class TestStorageRetrievalProperty(neomodel.StructuredNodeAsync): + class TestStorageRetrievalProperty(neomodel.AsyncStructuredNode): uid = neomodel.UniqueIdProperty() description = neomodel.StringProperty() location = neomodel.contrib.spatial_properties.PointProperty(crs="cartesian") diff --git a/test/test_database_management.py b/test/test_database_management.py index 6da05818..1a277d16 100644 --- a/test/test_database_management.py +++ b/test/test_database_management.py @@ -2,16 +2,16 @@ from neo4j.exceptions import AuthError from neomodel import ( + AsyncStructuredNode, IntegerProperty, RelationshipTo, StringProperty, - StructuredNodeAsync, StructuredRel, ) from neomodel._async.core import adb -class City(StructuredNodeAsync): +class City(AsyncStructuredNode): name = StringProperty() @@ -19,7 +19,7 @@ class InCity(StructuredRel): creation_year = IntegerProperty(index=True) -class Venue(StructuredNodeAsync): +class Venue(AsyncStructuredNode): name = StringProperty(unique_index=True) creator = StringProperty(index=True) in_city = RelationshipTo(City, relation_type="IN", model=InCity) diff --git a/test/test_exceptions.py b/test/test_exceptions.py index f631fa4b..c6976515 100644 --- a/test/test_exceptions.py +++ b/test/test_exceptions.py @@ -1,9 +1,9 @@ import pickle -from neomodel import DoesNotExist, StringProperty, StructuredNodeAsync +from neomodel import AsyncStructuredNode, DoesNotExist, StringProperty -class EPerson(StructuredNodeAsync): +class EPerson(AsyncStructuredNode): name = StringProperty(unique_index=True) diff --git a/test/test_hooks.py b/test/test_hooks.py index c77e2845..8fb9b8e5 100644 --- a/test/test_hooks.py +++ b/test/test_hooks.py @@ -1,9 +1,9 @@ -from neomodel import StringProperty, StructuredNodeAsync +from neomodel import AsyncStructuredNode, StringProperty HOOKS_CALLED = {} -class HookTest(StructuredNodeAsync): +class HookTest(AsyncStructuredNode): name = StringProperty() def post_create(self): diff --git a/test/test_indexing.py b/test/test_indexing.py index 9611df9a..5f1df506 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -2,16 +2,16 @@ from pytest import raises from neomodel import ( + AsyncStructuredNode, IntegerProperty, StringProperty, - StructuredNodeAsync, UniqueProperty, ) from neomodel._async.core import adb from neomodel.exceptions import ConstraintValidationFailed -class Human(StructuredNodeAsync): +class Human(AsyncStructuredNode): name = StringProperty(unique_index=True) age = IntegerProperty(index=True) @@ -67,7 +67,7 @@ def test_does_not_exist(): def test_custom_label_name(): - class Giraffe(StructuredNodeAsync): + class Giraffe(AsyncStructuredNode): __label__ = "Giraffes" name = StringProperty(unique_index=True) diff --git a/test/test_issue112.py b/test/test_issue112.py index dd569fdc..c24fe1b2 100644 --- a/test/test_issue112.py +++ b/test/test_issue112.py @@ -1,7 +1,7 @@ -from neomodel import RelationshipTo, StructuredNodeAsync +from neomodel import AsyncStructuredNode, RelationshipTo -class SomeModel(StructuredNodeAsync): +class SomeModel(AsyncStructuredNode): test = RelationshipTo("SomeModel", "SELF") diff --git a/test/test_issue283.py b/test/test_issue283.py index d42d3fb8..ebdeb97d 100644 --- a/test/test_issue283.py +++ b/test/test_issue283.py @@ -36,7 +36,7 @@ class PersonalRelationship(neomodel.StructuredRel): on_date = neomodel.DateTimeProperty(default_now=True) -class BasePerson(neomodel.StructuredNodeAsync): +class BasePerson(neomodel.AsyncStructuredNode): """ Base class for defining some basic sort of an actor. """ @@ -64,7 +64,7 @@ class PilotPerson(BasePerson): airplane = neomodel.StringProperty(required=True) -class BaseOtherPerson(neomodel.StructuredNodeAsync): +class BaseOtherPerson(neomodel.AsyncStructuredNode): """ An obviously "wrong" class of actor to befriend BasePersons with. """ diff --git a/test/test_issue600.py b/test/test_issue600.py index c6d624a4..d26240f9 100644 --- a/test/test_issue600.py +++ b/test/test_issue600.py @@ -30,7 +30,7 @@ class SubClass2(Class1): pass -class RelationshipDefinerSecondSibling(neomodel.StructuredNodeAsync): +class RelationshipDefinerSecondSibling(neomodel.AsyncStructuredNode): rel_1 = neomodel.Relationship( "RelationshipDefinerSecondSibling", "SOME_REL_LABEL", model=Class1 ) @@ -42,7 +42,7 @@ class RelationshipDefinerSecondSibling(neomodel.StructuredNodeAsync): ) -class RelationshipDefinerParentLast(neomodel.StructuredNodeAsync): +class RelationshipDefinerParentLast(neomodel.AsyncStructuredNode): rel_2 = neomodel.Relationship( "RelationshipDefinerParentLast", "SOME_REL_LABEL", model=SubClass1 ) diff --git a/test/test_label_drop.py b/test/test_label_drop.py index e62f5caf..aadbad0c 100644 --- a/test/test_label_drop.py +++ b/test/test_label_drop.py @@ -1,12 +1,12 @@ from neo4j.exceptions import ClientError -from neomodel import StringProperty, StructuredNodeAsync, config +from neomodel import AsyncStructuredNode, StringProperty, config from neomodel._async.core import adb config.AUTO_INSTALL_LABELS = True -class ConstraintAndIndex(StructuredNodeAsync): +class ConstraintAndIndex(AsyncStructuredNode): name = StringProperty(unique_index=True) last_name = StringProperty(index=True) diff --git a/test/test_label_install.py b/test/test_label_install.py index abbff9a5..256ed1bd 100644 --- a/test/test_label_install.py +++ b/test/test_label_install.py @@ -1,9 +1,9 @@ import pytest from neomodel import ( + AsyncStructuredNode, RelationshipTo, StringProperty, - StructuredNodeAsync, StructuredRel, UniqueIdProperty, config, @@ -14,15 +14,15 @@ config.AUTO_INSTALL_LABELS = False -class NodeWithIndex(StructuredNodeAsync): +class NodeWithIndex(AsyncStructuredNode): name = StringProperty(index=True) -class NodeWithConstraint(StructuredNodeAsync): +class NodeWithConstraint(AsyncStructuredNode): name = StringProperty(unique_index=True) -class NodeWithRelationship(StructuredNodeAsync): +class NodeWithRelationship(AsyncStructuredNode): ... @@ -30,18 +30,18 @@ class IndexedRelationship(StructuredRel): indexed_rel_prop = StringProperty(index=True) -class OtherNodeWithRelationship(StructuredNodeAsync): +class OtherNodeWithRelationship(AsyncStructuredNode): has_rel = RelationshipTo( NodeWithRelationship, "INDEXED_REL", model=IndexedRelationship ) -class AbstractNode(StructuredNodeAsync): +class AbstractNode(AsyncStructuredNode): __abstract_node__ = True name = StringProperty(unique_index=True) -class SomeNotUniqueNode(StructuredNodeAsync): +class SomeNotUniqueNode(AsyncStructuredNode): id_ = UniqueIdProperty(db_property="id") @@ -104,7 +104,7 @@ def test_install_label_twice(capsys): class UniqueIndexRelationship(StructuredRel): unique_index_rel_prop = StringProperty(unique_index=True) - class OtherNodeWithUniqueIndexRelationship(StructuredNodeAsync): + class OtherNodeWithUniqueIndexRelationship(AsyncStructuredNode): has_rel = RelationshipTo( NodeWithRelationship, "UNIQUE_INDEX_REL", model=UniqueIndexRelationship ) @@ -136,14 +136,14 @@ def test_relationship_unique_index_not_supported(): class UniqueIndexRelationship(StructuredRel): name = StringProperty(unique_index=True) - class TargetNodeForUniqueIndexRelationship(StructuredNodeAsync): + class TargetNodeForUniqueIndexRelationship(AsyncStructuredNode): pass with pytest.raises( FeatureNotSupported, match=r".*Please upgrade to Neo4j 5.7 or higher" ): - class NodeWithUniqueIndexRelationship(StructuredNodeAsync): + class NodeWithUniqueIndexRelationship(AsyncStructuredNode): has_rel = RelationshipTo( TargetNodeForUniqueIndexRelationship, "UNIQUE_INDEX_REL", @@ -156,10 +156,10 @@ def test_relationship_unique_index(): class UniqueIndexRelationshipBis(StructuredRel): name = StringProperty(unique_index=True) - class TargetNodeForUniqueIndexRelationship(StructuredNodeAsync): + class TargetNodeForUniqueIndexRelationship(AsyncStructuredNode): pass - class NodeWithUniqueIndexRelationship(StructuredNodeAsync): + class NodeWithUniqueIndexRelationship(AsyncStructuredNode): has_rel = RelationshipTo( TargetNodeForUniqueIndexRelationship, "UNIQUE_INDEX_REL_BIS", diff --git a/test/test_match_api.py b/test/test_match_api.py index e6e521d5..50828523 100644 --- a/test/test_match_api.py +++ b/test/test_match_api.py @@ -4,13 +4,13 @@ from neomodel import ( INCOMING, + AsyncStructuredNode, DateTimeProperty, IntegerProperty, Q, RelationshipFrom, RelationshipTo, StringProperty, - StructuredNodeAsync, StructuredRel, ) from neomodel.exceptions import MultipleNodesReturned @@ -22,18 +22,18 @@ class SupplierRel(StructuredRel): courier = StringProperty() -class Supplier(StructuredNodeAsync): +class Supplier(AsyncStructuredNode): name = StringProperty() delivery_cost = IntegerProperty() coffees = RelationshipTo("Coffee", "COFFEE SUPPLIERS") -class Species(StructuredNodeAsync): +class Species(AsyncStructuredNode): name = StringProperty() coffees = RelationshipFrom("Coffee", "COFFEE SPECIES", model=StructuredRel) -class Coffee(StructuredNodeAsync): +class Coffee(AsyncStructuredNode): name = StringProperty(unique_index=True) price = IntegerProperty() suppliers = RelationshipFrom(Supplier, "COFFEE SUPPLIERS", model=SupplierRel) @@ -41,7 +41,7 @@ class Coffee(StructuredNodeAsync): id_ = IntegerProperty() -class Extension(StructuredNodeAsync): +class Extension(AsyncStructuredNode): extension = RelationshipTo("Extension", "extension") diff --git a/test/test_migration_neo4j_5.py b/test/test_migration_neo4j_5.py index 7f36a619..a5efe800 100644 --- a/test/test_migration_neo4j_5.py +++ b/test/test_migration_neo4j_5.py @@ -1,16 +1,16 @@ import pytest from neomodel import ( + AsyncStructuredNode, IntegerProperty, RelationshipTo, StringProperty, - StructuredNodeAsync, StructuredRel, ) from neomodel._async.core import adb -class Album(StructuredNodeAsync): +class Album(AsyncStructuredNode): name = StringProperty() @@ -18,7 +18,7 @@ class Released(StructuredRel): year = IntegerProperty() -class Band(StructuredNodeAsync): +class Band(AsyncStructuredNode): name = StringProperty() released = RelationshipTo(Album, relation_type="RELEASED", model=Released) diff --git a/test/test_models.py b/test/test_models.py index 028150b0..fb773055 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -5,17 +5,17 @@ from pytest import raises from neomodel import ( + AsyncStructuredNode, DateProperty, IntegerProperty, StringProperty, - StructuredNodeAsync, StructuredRel, ) from neomodel._async.core import adb from neomodel.exceptions import RequiredProperty, UniqueProperty -class User(StructuredNodeAsync): +class User(AsyncStructuredNode): email = StringProperty(unique_index=True, required=True) age = IntegerProperty(index=True) @@ -28,12 +28,12 @@ def email_alias(self, value): self.email = value -class NodeWithoutProperty(StructuredNodeAsync): +class NodeWithoutProperty(AsyncStructuredNode): pass def test_issue_233(): - class BaseIssue233(StructuredNodeAsync): + class BaseIssue233(AsyncStructuredNode): __abstract_node__ = True def __getitem__(self, item): @@ -156,7 +156,7 @@ def test_save_through_magic_property(): assert user2 -class Customer2(StructuredNodeAsync): +class Customer2(AsyncStructuredNode): __label__ = "customers" email = StringProperty(unique_index=True, required=True) age = IntegerProperty(index=True) @@ -229,7 +229,7 @@ def test_setting_value_to_none(): def test_inheritance(): - class User(StructuredNodeAsync): + class User(AsyncStructuredNode): __abstract_node__ = True name = StringProperty(unique_index=True) @@ -251,7 +251,7 @@ def credit_account(self, amount): def test_inherited_optional_labels(): - class BaseOptional(StructuredNodeAsync): + class BaseOptional(AsyncStructuredNode): __optional_labels__ = ["Alive"] name = StringProperty(unique_index=True) @@ -286,7 +286,7 @@ def credit_account(self, amount): self.balance = self.balance + int(amount) self.save() - class Shopper2(StructuredNodeAsync, UserMixin, CreditMixin): + class Shopper2(AsyncStructuredNode, UserMixin, CreditMixin): pass jim = Shopper2(name="jimmy", balance=300).save() @@ -300,7 +300,7 @@ class Shopper2(StructuredNodeAsync, UserMixin, CreditMixin): def test_date_property(): - class DateTest(StructuredNodeAsync): + class DateTest(AsyncStructuredNode): birthdate = DateProperty() user = DateTest(birthdate=datetime.now()).save() @@ -310,17 +310,17 @@ def test_reserved_property_keys(): error_match = r".*is not allowed as it conflicts with neomodel internals.*" with raises(ValueError, match=error_match): - class ReservedPropertiesDeletedNode(StructuredNodeAsync): + class ReservedPropertiesDeletedNode(AsyncStructuredNode): deleted = StringProperty() with raises(ValueError, match=error_match): - class ReservedPropertiesIdNode(StructuredNodeAsync): + class ReservedPropertiesIdNode(AsyncStructuredNode): id = StringProperty() with raises(ValueError, match=error_match): - class ReservedPropertiesElementIdNode(StructuredNodeAsync): + class ReservedPropertiesElementIdNode(AsyncStructuredNode): element_id = StringProperty() with raises(ValueError, match=error_match): diff --git a/test/test_multiprocessing.py b/test/test_multiprocessing.py index 8b637bff..28f1422c 100644 --- a/test/test_multiprocessing.py +++ b/test/test_multiprocessing.py @@ -1,9 +1,9 @@ from multiprocessing.pool import ThreadPool as Pool -from neomodel import StringProperty, StructuredNodeAsync, adb +from neomodel import AsyncStructuredNode, StringProperty, adb -class ThingyMaBob(StructuredNodeAsync): +class ThingyMaBob(AsyncStructuredNode): name = StringProperty(unique_index=True, required=True) diff --git a/test/test_paths.py b/test/test_paths.py index 4b6bf447..f6f8bbbc 100644 --- a/test/test_paths.py +++ b/test/test_paths.py @@ -1,9 +1,9 @@ from neomodel import ( + AsyncStructuredNode, IntegerProperty, NeomodelPath, RelationshipTo, StringProperty, - StructuredNodeAsync, StructuredRel, UniqueIdProperty, adb, @@ -18,16 +18,16 @@ class PersonLivesInCity(StructuredRel): some_num = IntegerProperty(index=True, default=12) -class CountryOfOrigin(StructuredNodeAsync): +class CountryOfOrigin(AsyncStructuredNode): code = StringProperty(unique_index=True, required=True) -class CityOfResidence(StructuredNodeAsync): +class CityOfResidence(AsyncStructuredNode): name = StringProperty(required=True) country = RelationshipTo(CountryOfOrigin, "FROM_COUNTRY") -class PersonOfInterest(StructuredNodeAsync): +class PersonOfInterest(AsyncStructuredNode): uid = UniqueIdProperty() name = StringProperty(unique_index=True) age = IntegerProperty(index=True, default=0) diff --git a/test/test_properties.py b/test/test_properties.py index 76ae1bcc..d03f3284 100644 --- a/test/test_properties.py +++ b/test/test_properties.py @@ -3,7 +3,7 @@ from pytest import mark, raises from pytz import timezone -from neomodel import StructuredNodeAsync, adb, config +from neomodel import AsyncStructuredNode, adb, config from neomodel.exceptions import ( DeflateError, InflateError, @@ -61,7 +61,7 @@ def test_string_property_exceeds_max_length(): def test_string_property_w_choice(): - class TestChoices(StructuredNodeAsync): + class TestChoices(AsyncStructuredNode): SEXES = {"F": "Female", "M": "Male", "O": "Other"} sex = StringProperty(required=True, choices=SEXES) @@ -186,7 +186,7 @@ def test_json(): def test_default_value(): - class DefaultTestValue(StructuredNodeAsync): + class DefaultTestValue(AsyncStructuredNode): name_xx = StringProperty(default="jim", index=True) a = DefaultTestValue() @@ -198,7 +198,7 @@ def test_default_value_callable(): def uid_generator(): return "xx" - class DefaultTestValueTwo(StructuredNodeAsync): + class DefaultTestValueTwo(AsyncStructuredNode): uid = StringProperty(default=uid_generator, index=True) a = DefaultTestValueTwo().save() @@ -214,7 +214,7 @@ def __str__(self): return Foo() - class DefaultTestValueThree(StructuredNodeAsync): + class DefaultTestValueThree(AsyncStructuredNode): uid = StringProperty(default=factory, index=True) x = DefaultTestValueThree() @@ -226,7 +226,7 @@ class DefaultTestValueThree(StructuredNodeAsync): def test_independent_property_name(): - class TestDBNamePropertyNode(StructuredNodeAsync): + class TestDBNamePropertyNode(AsyncStructuredNode): name_ = StringProperty(db_property="name") x = TestDBNamePropertyNode() @@ -249,7 +249,7 @@ class TestDBNamePropertyNode(StructuredNodeAsync): def test_independent_property_name_get_or_create(): - class TestNode(StructuredNodeAsync): + class TestNode(AsyncStructuredNode): uid = UniqueIdProperty() name_ = StringProperty(db_property="name", required=True) @@ -338,14 +338,14 @@ def test_uid_property(): myuid = prop.default_value() assert len(myuid) - class CheckMyId(StructuredNodeAsync): + class CheckMyId(AsyncStructuredNode): uid = UniqueIdProperty() cmid = CheckMyId().save() assert len(cmid.uid) -class ArrayProps(StructuredNodeAsync): +class ArrayProps(AsyncStructuredNode): uid = StringProperty(unique_index=True) untyped_arr = ArrayProperty() typed_arr = ArrayProperty(IntegerProperty()) @@ -378,7 +378,7 @@ def test_illegal_array_base_prop_raises(): def test_indexed_array(): - class IndexArray(StructuredNodeAsync): + class IndexArray(AsyncStructuredNode): ai = ArrayProperty(unique_index=True) b = IndexArray(ai=[1, 2]).save() @@ -387,7 +387,7 @@ class IndexArray(StructuredNodeAsync): def test_unique_index_prop_not_required(): - class ConstrainedTestNode(StructuredNodeAsync): + class ConstrainedTestNode(AsyncStructuredNode): required_property = StringProperty(required=True) unique_property = StringProperty(unique_index=True) unique_required_property = StringProperty(unique_index=True, required=True) @@ -415,7 +415,7 @@ class ConstrainedTestNode(StructuredNodeAsync): def test_unique_index_prop_enforced(): - class UniqueNullableNameNode(StructuredNodeAsync): + class UniqueNullableNameNode(AsyncStructuredNode): name = StringProperty(unique_index=True) # Nameless diff --git a/test/test_relationship_models.py b/test/test_relationship_models.py index 2377caf4..2e07b684 100644 --- a/test/test_relationship_models.py +++ b/test/test_relationship_models.py @@ -4,12 +4,12 @@ from pytest import raises from neomodel import ( + AsyncStructuredNode, DateTimeProperty, DeflateError, Relationship, RelationshipTo, StringProperty, - StructuredNodeAsync, StructuredRel, ) @@ -30,13 +30,13 @@ def post_save(self): HOOKS_CALLED["post_save"] += 1 -class Badger(StructuredNodeAsync): +class Badger(AsyncStructuredNode): name = StringProperty(unique_index=True) friend = Relationship("Badger", "FRIEND", model=FriendRel) hates = RelationshipTo("Stoat", "HATES", model=HatesRel) -class Stoat(StructuredNodeAsync): +class Stoat(AsyncStructuredNode): name = StringProperty(unique_index=True) hates = RelationshipTo("Badger", "HATES", model=HatesRel) diff --git a/test/test_relationships.py b/test/test_relationships.py index 4ee4b632..4c047eaf 100644 --- a/test/test_relationships.py +++ b/test/test_relationships.py @@ -1,6 +1,7 @@ from pytest import raises from neomodel import ( + AsyncStructuredNode, IntegerProperty, One, Q, @@ -8,13 +9,12 @@ RelationshipFrom, RelationshipTo, StringProperty, - StructuredNodeAsync, StructuredRel, ) from neomodel._async.core import adb -class PersonWithRels(StructuredNodeAsync): +class PersonWithRels(AsyncStructuredNode): name = StringProperty(unique_index=True) age = IntegerProperty(index=True) is_from = RelationshipTo("Country", "IS_FROM") @@ -28,7 +28,7 @@ def special_power(self): return "I have no powers" -class Country(StructuredNodeAsync): +class Country(AsyncStructuredNode): code = StringProperty(unique_index=True) inhabitant = RelationshipFrom(PersonWithRels, "IS_FROM") president = RelationshipTo(PersonWithRels, "PRESIDENT", cardinality=One) diff --git a/test/test_relative_relationships.py b/test/test_relative_relationships.py index 7bd9bc21..0a82ff57 100644 --- a/test/test_relative_relationships.py +++ b/test/test_relative_relationships.py @@ -1,8 +1,8 @@ -from neomodel import RelationshipTo, StringProperty, StructuredNodeAsync +from neomodel import AsyncStructuredNode, RelationshipTo, StringProperty from neomodel.test_relationships import Country -class Cat(StructuredNodeAsync): +class Cat(AsyncStructuredNode): name = StringProperty() # Relationship is defined using a relative class path is_from = RelationshipTo(".test_relationships.Country", "IS_FROM") diff --git a/test/test_scripts.py b/test/test_scripts.py index c9543675..f925603e 100644 --- a/test/test_scripts.py +++ b/test/test_scripts.py @@ -3,9 +3,9 @@ import pytest from neomodel import ( + AsyncStructuredNode, RelationshipTo, StringProperty, - StructuredNodeAsync, StructuredRel, config, ) @@ -19,7 +19,7 @@ class ScriptsTestRel(StructuredRel): some_index_property = StringProperty(index=True) -class ScriptsTestNode(StructuredNodeAsync): +class ScriptsTestNode(AsyncStructuredNode): personal_id = StringProperty(unique_index=True) name = StringProperty(index=True) rel = RelationshipTo("ScriptsTestNode", "REL", model=ScriptsTestRel) diff --git a/test/test_transactions.py b/test/test_transactions.py index b8d12520..4bdec8af 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -3,11 +3,11 @@ from neo4j.exceptions import ClientError, TransactionError from pytest import raises -from neomodel import StringProperty, StructuredNodeAsync, UniqueProperty +from neomodel import AsyncStructuredNode, StringProperty, UniqueProperty from neomodel._async.core import adb -class APerson(StructuredNodeAsync): +class APerson(AsyncStructuredNode): name = StringProperty(unique_index=True) From 98904077cbe6ec51ce3691bca027fcd2806a2f65 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Fri, 15 Dec 2023 16:10:20 +0100 Subject: [PATCH 07/73] Fix token replacements --- bin/make-unasync | 1 - 1 file changed, 1 deletion(-) diff --git a/bin/make-unasync b/bin/make-unasync index 9be3b8bc..c66d8d49 100755 --- a/bin/make-unasync +++ b/bin/make-unasync @@ -77,7 +77,6 @@ class CustomRule(unasync.Rule): def __init__(self, *args, **kwargs): super(CustomRule, self).__init__(*args, **kwargs) self.out_files = [] - self.token_replacements = {} def _unasync_tokens(self, tokens): # copy from unasync to fix handling of multiline strings From 12b0601f6d4a0ad478e82d4aba0e58fd98fce1f4 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Fri, 15 Dec 2023 16:11:17 +0100 Subject: [PATCH 08/73] Fix sync version --- neomodel/_sync/core.py | 70 ++++++++++++++++++++-------------------- test/sync/conftest.py | 22 ++++++------- test/sync/test_cypher.py | 36 ++++++++++----------- 3 files changed, 64 insertions(+), 64 deletions(-) diff --git a/neomodel/_sync/core.py b/neomodel/_sync/core.py index 14eb4a14..3c7897c2 100644 --- a/neomodel/_sync/core.py +++ b/neomodel/_sync/core.py @@ -820,15 +820,15 @@ def _install_relationship(self, cls, relationship, quiet, stdout): # Create a singleton instance of the database object -adb = Database() +db = Database() # Deprecated methods def change_neo4j_password(db: Database, user, new_password): deprecated( """ - This method has been moved to the Database singleton (db for sync, adb for async). - Please use adb.change_neo4j_password(user, new_password) instead. + This method has been moved to the Database singleton (db for sync, db for async). + Please use db.change_neo4j_password(user, new_password) instead. This direct call will be removed in an upcoming version. """ ) @@ -840,8 +840,8 @@ def clear_neo4j_database( ): deprecated( """ - This method has been moved to the Database singleton (db for sync, adb for async). - Please use adb.clear_neo4j_database(clear_constraints, clear_indexes) instead. + This method has been moved to the Database singleton (db for sync, db for async). + Please use db.clear_neo4j_database(clear_constraints, clear_indexes) instead. This direct call will be removed in an upcoming version. """ ) @@ -851,56 +851,56 @@ def clear_neo4j_database( def drop_constraints(quiet=True, stdout=None): deprecated( """ - This method has been moved to the Database singleton (db for sync, adb for async). - Please use adb.drop_constraints(quiet, stdout) instead. + This method has been moved to the Database singleton (db for sync, db for async). + Please use db.drop_constraints(quiet, stdout) instead. This direct call will be removed in an upcoming version. """ ) - adb.drop_constraints(quiet, stdout) + db.drop_constraints(quiet, stdout) def drop_indexes(quiet=True, stdout=None): deprecated( """ - This method has been moved to the Database singleton (db for sync, adb for async). - Please use adb.drop_indexes(quiet, stdout) instead. + This method has been moved to the Database singleton (db for sync, db for async). + Please use db.drop_indexes(quiet, stdout) instead. This direct call will be removed in an upcoming version. """ ) - adb.drop_indexes(quiet, stdout) + db.drop_indexes(quiet, stdout) def remove_all_labels(stdout=None): deprecated( """ - This method has been moved to the Database singleton (db for sync, adb for async). - Please use adb.remove_all_labels(stdout) instead. + This method has been moved to the Database singleton (db for sync, db for async). + Please use db.remove_all_labels(stdout) instead. This direct call will be removed in an upcoming version. """ ) - adb.remove_all_labels(stdout) + db.remove_all_labels(stdout) def install_labels(cls, quiet=True, stdout=None): deprecated( """ - This method has been moved to the Database singleton (db for sync, adb for async). - Please use adb.install_labels(cls, quiet, stdout) instead. + This method has been moved to the Database singleton (db for sync, db for async). + Please use db.install_labels(cls, quiet, stdout) instead. This direct call will be removed in an upcoming version. """ ) - adb.install_labels(cls, quiet, stdout) + db.install_labels(cls, quiet, stdout) def install_all_labels(stdout=None): deprecated( """ - This method has been moved to the Database singleton (db for sync, adb for async). - Please use adb.install_all_labels(stdout) instead. + This method has been moved to the Database singleton (db for sync, db for async). + Please use db.install_all_labels(stdout) instead. This direct call will be removed in an upcoming version. """ ) - adb.install_all_labels(stdout) + db.install_all_labels(stdout) class TransactionProxy: @@ -1058,10 +1058,10 @@ def build_class_registry(cls): possible_label_combinations.append(base_label_set) for label_set in possible_label_combinations: - if label_set not in adb._NODE_CLASS_REGISTRY: - adb._NODE_CLASS_REGISTRY[label_set] = cls + if label_set not in db._NODE_CLASS_REGISTRY: + db._NODE_CLASS_REGISTRY[label_set] = cls else: - raise NodeClassAlreadyDefined(cls, adb._NODE_CLASS_REGISTRY) + raise NodeClassAlreadyDefined(cls, db._NODE_CLASS_REGISTRY) NodeBase = NodeMeta("NodeBase", (PropertyManager,), {"__abstract_node__": True}) @@ -1124,7 +1124,7 @@ def element_id(self): if hasattr(self, "element_id_property"): return ( int(self.element_id_property) - if adb.database_version.startswith("4") + if db.database_version.startswith("4") else self.element_id_property ) return None @@ -1181,7 +1181,7 @@ def _build_merge_query( from neomodel.match import _rel_helper query_params["source_id"] = relationship.source.element_id - query = f"MATCH (source:{relationship.source.__label__}) WHERE {adb.get_id_method()}(source) = $source_id\n " + query = f"MATCH (source:{relationship.source.__label__}) WHERE {db.get_id_method()}(source) = $source_id\n " query += "WITH source\n UNWIND $merge_params as params \n " query += "MERGE " query += _rel_helper( @@ -1199,7 +1199,7 @@ def _build_merge_query( # close query if lazy: - query += f"RETURN {adb.get_id_method()}(n)" + query += f"RETURN {db.get_id_method()}(n)" else: query += "RETURN n" @@ -1230,7 +1230,7 @@ def create(cls, *props, **kwargs): # close query if lazy: - query += f" RETURN {adb.get_id_method()}(n)" + query += f" RETURN {db.get_id_method()}(n)" else: query += " RETURN n" @@ -1238,7 +1238,7 @@ def create(cls, *props, **kwargs): for item in [ cls.deflate(p, obj=_UnsavedNode(), skip_empty=True) for p in props ]: - node, _ = adb.cypher_query(query, {"create_params": item}) + node, _ = db.cypher_query(query, {"create_params": item}) results.extend(node[0]) nodes = [cls.inflate(node) for node in results] @@ -1294,7 +1294,7 @@ def create_or_update(cls, *props, **kwargs): ) # fetch and build instance for each result - results = adb.cypher_query(query, params) + results = db.cypher_query(query, params) return [cls.inflate(r[0]) for r in results[0]] def cypher(self, query, params=None): @@ -1311,7 +1311,7 @@ def cypher(self, query, params=None): self._pre_action_check("cypher") params = params or {} params.update({"self": self.element_id}) - return adb.cypher_query(query, params) + return db.cypher_query(query, params) @hooks def delete(self): @@ -1322,7 +1322,7 @@ def delete(self): """ self._pre_action_check("delete") self.cypher( - f"MATCH (self) WHERE {adb.get_id_method()}(self)=$self DETACH DELETE self" + f"MATCH (self) WHERE {db.get_id_method()}(self)=$self DETACH DELETE self" ) delattr(self, "element_id_property") self.deleted = True @@ -1363,7 +1363,7 @@ def get_or_create(cls, *props, **kwargs): ) # fetch and build instance for each result - results = adb.cypher_query(query, params) + results = db.cypher_query(query, params) return [cls.inflate(r[0]) for r in results[0]] @classmethod @@ -1433,7 +1433,7 @@ def labels(self): """ self._pre_action_check("labels") return self.cypher( - f"MATCH (n) WHERE {adb.get_id_method()}(n)=$self " "RETURN labels(n)" + f"MATCH (n) WHERE {db.get_id_method()}(n)=$self " "RETURN labels(n)" )[0][0][0] def _pre_action_check(self, action): @@ -1453,7 +1453,7 @@ def refresh(self): self._pre_action_check("refresh") if hasattr(self, "element_id"): request = self.cypher( - f"MATCH (n) WHERE {adb.get_id_method()}(n)=$self RETURN n" + f"MATCH (n) WHERE {db.get_id_method()}(n)=$self RETURN n" )[0] if not request or not request[0]: raise self.__class__.DoesNotExist("Can't refresh non existent node") @@ -1475,7 +1475,7 @@ def save(self): if hasattr(self, "element_id_property"): # update params = self.deflate(self.__properties__, self) - query = f"MATCH (n) WHERE {adb.get_id_method()}(n)=$self\n" + query = f"MATCH (n) WHERE {db.get_id_method()}(n)=$self\n" if params: query += "SET " diff --git a/test/sync/conftest.py b/test/sync/conftest.py index 12e1183f..05001b38 100644 --- a/test/sync/conftest.py +++ b/test/sync/conftest.py @@ -1,17 +1,17 @@ import asyncio import os import warnings -from test._async_compat import mark_async_test +from test._async_compat import mark_sync_test import pytest import pytest_asyncio from neomodel import config -from neomodel._async.core import adb +from neomodel._sync.core import db @pytest_asyncio.fixture(scope="session", autouse=True) -@mark_async_test +@mark_sync_test def setup_neo4j_session(request): """ Provides initial connection to the database and sets up the rest of the test suite @@ -28,7 +28,7 @@ def setup_neo4j_session(request): config.AUTO_INSTALL_LABELS = True # Clear the database if required - database_is_populated, _ = adb.cypher_query( + database_is_populated, _ = db.cypher_query( "MATCH (a) return count(a)>0 as database_is_populated" ) if database_is_populated[0][0] and not request.config.getoption("resetdb"): @@ -36,21 +36,21 @@ def setup_neo4j_session(request): "Please note: The database seems to be populated.\n\tEither delete all nodes and edges manually, or set the --resetdb parameter when calling pytest\n\n\tpytest --resetdb." ) - adb.clear_neo4j_database(clear_constraints=True, clear_indexes=True) + db.clear_neo4j_database(clear_constraints=True, clear_indexes=True) - adb.cypher_query( + db.cypher_query( "CREATE OR REPLACE USER troygreene SET PASSWORD 'foobarbaz' CHANGE NOT REQUIRED" ) - if adb.database_edition == "enterprise": - adb.cypher_query("GRANT ROLE publisher TO troygreene") - adb.cypher_query("GRANT IMPERSONATE (troygreene) ON DBMS TO admin") + if db.database_edition == "enterprise": + db.cypher_query("GRANT ROLE publisher TO troygreene") + db.cypher_query("GRANT IMPERSONATE (troygreene) ON DBMS TO admin") @pytest_asyncio.fixture(scope="session", autouse=True) -@mark_async_test +@mark_sync_test def cleanup(): yield - adb.close_connection() + db.close_connection() @pytest.fixture(scope="session") diff --git a/test/sync/test_cypher.py b/test/sync/test_cypher.py index a304c9fb..0d1da78e 100644 --- a/test/sync/test_cypher.py +++ b/test/sync/test_cypher.py @@ -1,13 +1,13 @@ import builtins -from test._async_compat import mark_async_test +from test._async_compat import mark_sync_test import pytest from neo4j.exceptions import ClientError as CypherError from numpy import ndarray from pandas import DataFrame, Series -from neomodel import StructuredNode, StringProperty -from neomodel._async.core import adb +from neomodel import StringProperty, StructuredNode +from neomodel._sync.core import db class User2(StructuredNode): @@ -37,7 +37,7 @@ def mocked_import(name, *args, **kwargs): monkeypatch.setattr(builtins, "__import__", mocked_import) -@mark_async_test +@mark_sync_test def test_cypher(): """ test result format is backward compatible with earlier versions of neomodel @@ -45,14 +45,14 @@ def test_cypher(): jim = User2(email="jim1@test.com").save() data, meta = jim.cypher( - f"MATCH (a) WHERE {adb.get_id_method()}(a)=$self RETURN a.email" + f"MATCH (a) WHERE {db.get_id_method()}(a)=$self RETURN a.email" ) assert data[0][0] == "jim1@test.com" assert "a.email" in meta data, meta = jim.cypher( f""" - MATCH (a) WHERE {adb.get_id_method()}(a)=$self + MATCH (a) WHERE {db.get_id_method()}(a)=$self MATCH (a)<-[:USER2]-(b) RETURN a, b, 3 """ @@ -60,11 +60,11 @@ def test_cypher(): assert "a" in meta and "b" in meta -@mark_async_test +@mark_sync_test def test_cypher_syntax_error_async(): jim = User2(email="jim1@test.com").save() try: - jim.cypher(f"MATCH a WHERE {adb.get_id_method()}(a)={ self} RETURN xx") + jim.cypher(f"MATCH a WHERE {db.get_id_method()}(a)={ self} RETURN xx") except CypherError as e: assert hasattr(e, "message") assert hasattr(e, "code") @@ -72,7 +72,7 @@ def test_cypher_syntax_error_async(): assert False, "CypherError not raised." -@mark_async_test +@mark_sync_test @pytest.mark.parametrize("hide_available_pkg", ["pandas"], indirect=True) def test_pandas_not_installed_async(hide_available_pkg): with pytest.raises(ImportError): @@ -82,10 +82,10 @@ def test_pandas_not_installed_async(hide_available_pkg): ): from neomodel.integration.pandas import to_dataframe - _ = to_dataframe(adb.cypher_query("MATCH (a) RETURN a.name AS name")) + _ = to_dataframe(db.cypher_query("MATCH (a) RETURN a.name AS name")) -@mark_async_test +@mark_sync_test def test_pandas_integration_async(): from neomodel.integration.pandas import to_dataframe, to_series @@ -94,7 +94,7 @@ def test_pandas_integration_async(): # Test to_dataframe df = to_dataframe( - adb.cypher_query( + db.cypher_query( "MATCH (a:UserPandas) RETURN a.name AS name, a.email AS email" ) ) @@ -105,7 +105,7 @@ def test_pandas_integration_async(): # Also test passing an index and dtype to to_dataframe df = to_dataframe( - adb.cypher_query( + db.cypher_query( "MATCH (a:UserPandas) RETURN a.name AS name, a.email AS email" ), index=df["email"], @@ -116,7 +116,7 @@ def test_pandas_integration_async(): # Next test to_series series = to_series( - adb.cypher_query("MATCH (a:UserPandas) RETURN a.name AS name") + db.cypher_query("MATCH (a:UserPandas) RETURN a.name AS name") ) assert isinstance(series, Series) @@ -124,7 +124,7 @@ def test_pandas_integration_async(): assert df["name"].tolist() == ["jimla", "jimlo"] -@mark_async_test +@mark_sync_test @pytest.mark.parametrize("hide_available_pkg", ["numpy"], indirect=True) def test_numpy_not_installed_async(hide_available_pkg): with pytest.raises(ImportError): @@ -134,10 +134,10 @@ def test_numpy_not_installed_async(hide_available_pkg): ): from neomodel.integration.numpy import to_ndarray - _ = to_ndarray(adb.cypher_query("MATCH (a) RETURN a.name AS name")) + _ = to_ndarray(db.cypher_query("MATCH (a) RETURN a.name AS name")) -@mark_async_test +@mark_sync_test def test_numpy_integration_async(): from neomodel.integration.numpy import to_ndarray @@ -145,7 +145,7 @@ def test_numpy_integration_async(): jimlu = UserNP(email="jimlu@test.com", name="jimlu").save() array = to_ndarray( - adb.cypher_query( + db.cypher_query( "MATCH (a:UserNP) RETURN a.name AS name, a.email AS email" ) ) From 60581f83e4c0149d3a878a98c04c3f63c28f81b6 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Tue, 19 Dec 2023 14:45:01 +0100 Subject: [PATCH 09/73] Some fixes to make-unasync --- bin/make-unasync | 1 + test/_async_compat/__init__.py | 4 ++++ test/_async_compat/mark_decorator.py | 3 +++ test/async_/conftest.py | 13 +++++-------- test/async_/test_cypher.py | 10 +++++----- test/sync/conftest.py | 13 +++++-------- test/sync/test_cypher.py | 10 +++++----- 7 files changed, 28 insertions(+), 26 deletions(-) diff --git a/bin/make-unasync b/bin/make-unasync index c66d8d49..2fabfd24 100755 --- a/bin/make-unasync +++ b/bin/make-unasync @@ -196,6 +196,7 @@ def apply_unasync(files): "_async": "_sync", "adb": "db", "mark_async_test": "mark_sync_test", + "mark_async_session_auto_fixture": "mark_sync_session_auto_fixture", } rules = [ CustomRule( diff --git a/test/_async_compat/__init__.py b/test/_async_compat/__init__.py index d5053965..342678c3 100644 --- a/test/_async_compat/__init__.py +++ b/test/_async_compat/__init__.py @@ -1,7 +1,9 @@ from .mark_decorator import ( AsyncTestDecorators, TestDecorators, + mark_async_session_auto_fixture, mark_async_test, + mark_sync_session_auto_fixture, mark_sync_test, ) @@ -10,4 +12,6 @@ "mark_async_test", "mark_sync_test", "TestDecorators", + "mark_async_session_auto_fixture", + "mark_sync_session_auto_fixture", ] diff --git a/test/_async_compat/mark_decorator.py b/test/_async_compat/mark_decorator.py index 1195baa5..a8c5eead 100644 --- a/test/_async_compat/mark_decorator.py +++ b/test/_async_compat/mark_decorator.py @@ -1,6 +1,9 @@ import pytest +import pytest_asyncio mark_async_test = pytest.mark.asyncio +mark_async_session_auto_fixture = pytest_asyncio.fixture(scope="session", autouse=True) +mark_sync_session_auto_fixture = pytest.fixture(scope="session", autouse=True) def mark_sync_test(f): diff --git a/test/async_/conftest.py b/test/async_/conftest.py index 65250261..e82da39b 100644 --- a/test/async_/conftest.py +++ b/test/async_/conftest.py @@ -1,18 +1,16 @@ import asyncio import os import warnings -from test._async_compat import mark_async_test +from test._async_compat import mark_async_session_auto_fixture import pytest -import pytest_asyncio from neomodel import config from neomodel._async.core import adb -@pytest_asyncio.fixture(scope="session", autouse=True) -@mark_async_test -async def setup_neo4j_session(request): +@mark_async_session_auto_fixture +async def setup_neo4j_session(request, event_loop): """ Provides initial connection to the database and sets up the rest of the test suite @@ -46,9 +44,8 @@ async def setup_neo4j_session(request): await adb.cypher_query("GRANT IMPERSONATE (troygreene) ON DBMS TO admin") -@pytest_asyncio.fixture(scope="session", autouse=True) -@mark_async_test -async def cleanup(): +@mark_async_session_auto_fixture +async def cleanup(event_loop): yield await adb.close_connection() diff --git a/test/async_/test_cypher.py b/test/async_/test_cypher.py index f3cb6d7d..31aa2f68 100644 --- a/test/async_/test_cypher.py +++ b/test/async_/test_cypher.py @@ -61,7 +61,7 @@ async def test_cypher(): @mark_async_test -async def test_cypher_syntax_error_async(): +async def test_cypher_syntax_error(): jim = await User2(email="jim1@test.com").save() try: await jim.cypher(f"MATCH a WHERE {adb.get_id_method()}(a)={{self}} RETURN xx") @@ -74,7 +74,7 @@ async def test_cypher_syntax_error_async(): @mark_async_test @pytest.mark.parametrize("hide_available_pkg", ["pandas"], indirect=True) -async def test_pandas_not_installed_async(hide_available_pkg): +async def test_pandas_not_installed(hide_available_pkg): with pytest.raises(ImportError): with pytest.warns( UserWarning, @@ -86,7 +86,7 @@ async def test_pandas_not_installed_async(hide_available_pkg): @mark_async_test -async def test_pandas_integration_async(): +async def test_pandas_integration(): from neomodel.integration.pandas import to_dataframe, to_series jimla = await UserPandas(email="jimla@test.com", name="jimla").save() @@ -126,7 +126,7 @@ async def test_pandas_integration_async(): @mark_async_test @pytest.mark.parametrize("hide_available_pkg", ["numpy"], indirect=True) -async def test_numpy_not_installed_async(hide_available_pkg): +async def test_numpy_not_installed(hide_available_pkg): with pytest.raises(ImportError): with pytest.warns( UserWarning, @@ -138,7 +138,7 @@ async def test_numpy_not_installed_async(hide_available_pkg): @mark_async_test -async def test_numpy_integration_async(): +async def test_numpy_integration(): from neomodel.integration.numpy import to_ndarray jimly = await UserNP(email="jimly@test.com", name="jimly").save() diff --git a/test/sync/conftest.py b/test/sync/conftest.py index 05001b38..867906b4 100644 --- a/test/sync/conftest.py +++ b/test/sync/conftest.py @@ -1,18 +1,16 @@ import asyncio import os import warnings -from test._async_compat import mark_sync_test +from test._async_compat import mark_sync_session_auto_fixture import pytest -import pytest_asyncio from neomodel import config from neomodel._sync.core import db -@pytest_asyncio.fixture(scope="session", autouse=True) -@mark_sync_test -def setup_neo4j_session(request): +@mark_sync_session_auto_fixture +def setup_neo4j_session(request, event_loop): """ Provides initial connection to the database and sets up the rest of the test suite @@ -46,9 +44,8 @@ def setup_neo4j_session(request): db.cypher_query("GRANT IMPERSONATE (troygreene) ON DBMS TO admin") -@pytest_asyncio.fixture(scope="session", autouse=True) -@mark_sync_test -def cleanup(): +@mark_sync_session_auto_fixture +def cleanup(event_loop): yield db.close_connection() diff --git a/test/sync/test_cypher.py b/test/sync/test_cypher.py index 0d1da78e..44c1e859 100644 --- a/test/sync/test_cypher.py +++ b/test/sync/test_cypher.py @@ -61,7 +61,7 @@ def test_cypher(): @mark_sync_test -def test_cypher_syntax_error_async(): +def test_cypher_syntax_error(): jim = User2(email="jim1@test.com").save() try: jim.cypher(f"MATCH a WHERE {db.get_id_method()}(a)={ self} RETURN xx") @@ -74,7 +74,7 @@ def test_cypher_syntax_error_async(): @mark_sync_test @pytest.mark.parametrize("hide_available_pkg", ["pandas"], indirect=True) -def test_pandas_not_installed_async(hide_available_pkg): +def test_pandas_not_installed(hide_available_pkg): with pytest.raises(ImportError): with pytest.warns( UserWarning, @@ -86,7 +86,7 @@ def test_pandas_not_installed_async(hide_available_pkg): @mark_sync_test -def test_pandas_integration_async(): +def test_pandas_integration(): from neomodel.integration.pandas import to_dataframe, to_series jimla = UserPandas(email="jimla@test.com", name="jimla").save() @@ -126,7 +126,7 @@ def test_pandas_integration_async(): @mark_sync_test @pytest.mark.parametrize("hide_available_pkg", ["numpy"], indirect=True) -def test_numpy_not_installed_async(hide_available_pkg): +def test_numpy_not_installed(hide_available_pkg): with pytest.raises(ImportError): with pytest.warns( UserWarning, @@ -138,7 +138,7 @@ def test_numpy_not_installed_async(hide_available_pkg): @mark_sync_test -def test_numpy_integration_async(): +def test_numpy_integration(): from neomodel.integration.numpy import to_ndarray jimly = UserNP(email="jimly@test.com", name="jimly").save() From ebe0cb7247bda0e627c32e305234101c25c38cb0 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Tue, 19 Dec 2023 15:51:40 +0100 Subject: [PATCH 10/73] Run mak-unasync with python 3.11 --- neomodel/_sync/core.py | 2 +- test/sync/test_cypher.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/neomodel/_sync/core.py b/neomodel/_sync/core.py index 3c7897c2..a46867f6 100644 --- a/neomodel/_sync/core.py +++ b/neomodel/_sync/core.py @@ -1162,7 +1162,7 @@ def _build_merge_query( for p in cls.__required_properties__ ) ) - n_merge = f"n:{n_merge_labels} { {n_merge_prm}} " + n_merge = f"n:{n_merge_labels} {{{n_merge_prm}}}" if relationship is None: # create "simple" unwind query query = f"UNWIND $merge_params as params\n MERGE ({n_merge})\n " diff --git a/test/sync/test_cypher.py b/test/sync/test_cypher.py index 44c1e859..3abe1981 100644 --- a/test/sync/test_cypher.py +++ b/test/sync/test_cypher.py @@ -64,7 +64,7 @@ def test_cypher(): def test_cypher_syntax_error(): jim = User2(email="jim1@test.com").save() try: - jim.cypher(f"MATCH a WHERE {db.get_id_method()}(a)={ self} RETURN xx") + jim.cypher(f"MATCH a WHERE {db.get_id_method()}(a)={{self}} RETURN xx") except CypherError as e: assert hasattr(e, "message") assert hasattr(e, "code") From 34d28474fb8f5a19a5f7ab51e295457e96cc1242 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Tue, 19 Dec 2023 16:00:59 +0100 Subject: [PATCH 11/73] Fix scripts - sync code --- neomodel/scripts/neomodel_inspect_database.py | 32 +++++++++---------- neomodel/scripts/neomodel_install_labels.py | 6 ++-- neomodel/scripts/neomodel_remove_labels.py | 6 ++-- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/neomodel/scripts/neomodel_inspect_database.py b/neomodel/scripts/neomodel_inspect_database.py index 6231066e..35b99d42 100644 --- a/neomodel/scripts/neomodel_inspect_database.py +++ b/neomodel/scripts/neomodel_inspect_database.py @@ -33,7 +33,7 @@ import textwrap from os import environ -from neomodel._async.core import adb +from neomodel._sync.core import db IMPORTS = [] @@ -80,13 +80,13 @@ def get_properties_for_label(label): ORDER BY size(properties) DESC RETURN apoc.meta.cypher.types(properties(sampleNode)) AS properties LIMIT 1 """ - result, _ = adb.cypher_query(query) + result, _ = db.cypher_query(query) if result is not None and len(result) > 0: return result[0][0] @staticmethod def get_constraints_for_label(label): - constraints, meta_constraints = adb.cypher_query( + constraints, meta_constraints = db.cypher_query( f"SHOW CONSTRAINTS WHERE entityType='NODE' AND '{label}' IN labelsOrTypes AND type='UNIQUENESS'" ) constraints_as_dict = [dict(zip(meta_constraints, row)) for row in constraints] @@ -99,12 +99,12 @@ def get_constraints_for_label(label): @staticmethod def get_indexed_properties_for_label(label): - if adb.version_is_higher_than("5.0"): - indexes, meta_indexes = adb.cypher_query( + if db.version_is_higher_than("5.0"): + indexes, meta_indexes = db.cypher_query( f"SHOW INDEXES WHERE entityType='NODE' AND '{label}' IN labelsOrTypes AND type='RANGE' AND owningConstraint IS NULL" ) else: - indexes, meta_indexes = adb.cypher_query( + indexes, meta_indexes = db.cypher_query( f"SHOW INDEXES WHERE entityType='NODE' AND '{label}' IN labelsOrTypes AND type='BTREE' AND uniqueness='NONUNIQUE'" ) indexes_as_dict = [dict(zip(meta_indexes, row)) for row in indexes] @@ -132,12 +132,12 @@ def outgoing_relationships(cls, start_label, get_properties: bool = True): WITH DISTINCT type(r) as rel_type, head(labels(m)) AS target_label RETURN rel_type, target_label, {{}} AS properties LIMIT 1 """ - result, _ = adb.cypher_query(query) + result, _ = db.cypher_query(query) return [(record[0], record[1], record[2]) for record in result] @staticmethod def get_constraints_for_type(rel_type): - constraints, meta_constraints = adb.cypher_query( + constraints, meta_constraints = db.cypher_query( f"SHOW CONSTRAINTS WHERE entityType='RELATIONSHIP' AND '{rel_type}' IN labelsOrTypes AND type='RELATIONSHIP_UNIQUENESS'" ) constraints_as_dict = [dict(zip(meta_constraints, row)) for row in constraints] @@ -150,12 +150,12 @@ def get_constraints_for_type(rel_type): @staticmethod def get_indexed_properties_for_type(rel_type): - if adb.version_is_higher_than("5.0"): - indexes, meta_indexes = adb.cypher_query( + if db.version_is_higher_than("5.0"): + indexes, meta_indexes = db.cypher_query( f"SHOW INDEXES WHERE entityType='RELATIONSHIP' AND '{rel_type}' IN labelsOrTypes AND type='RANGE' AND owningConstraint IS NULL" ) else: - indexes, meta_indexes = adb.cypher_query( + indexes, meta_indexes = db.cypher_query( f"SHOW INDEXES WHERE entityType='RELATIONSHIP' AND '{rel_type}' IN labelsOrTypes AND type='BTREE' AND uniqueness='NONUNIQUE'" ) indexes_as_dict = [dict(zip(meta_indexes, row)) for row in indexes] @@ -169,7 +169,7 @@ def get_indexed_properties_for_type(rel_type): @staticmethod def infer_cardinality(rel_type, start_label): range_start_query = f"MATCH (n:`{start_label}`) WHERE NOT EXISTS ((n)-[:`{rel_type}`]->()) WITH n LIMIT 1 RETURN count(n)" - result, _ = adb.cypher_query(range_start_query) + result, _ = db.cypher_query(range_start_query) is_start_zero = result[0][0] > 0 range_end_query = f""" @@ -179,7 +179,7 @@ def infer_cardinality(rel_type, start_label): WITH n LIMIT 1 RETURN count(n) """ - result, _ = adb.cypher_query(range_end_query) + result, _ = db.cypher_query(range_end_query) is_end_one = result[0][0] == 0 cardinality = "Zero" if is_start_zero else "One" @@ -193,7 +193,7 @@ def infer_cardinality(rel_type, start_label): def get_node_labels(): query = "CALL db.labels()" - result, _ = adb.cypher_query(query) + result, _ = db.cypher_query(query) return [record[0] for record in result] @@ -245,7 +245,7 @@ def build_rel_type_definition( unique_properties = ( RelationshipInspector.get_constraints_for_type(rel_type) - if adb.version_is_higher_than("5.7") + if db.version_is_higher_than("5.7") else [] ) indexed_properties = RelationshipInspector.get_indexed_properties_for_type( @@ -286,7 +286,7 @@ def inspect_database( ): # Connect to the database print(f"Connecting to {bolt_url}") - adb.set_connection(bolt_url) + db.set_connection(bolt_url) node_labels = get_node_labels() defined_rel_types = [] diff --git a/neomodel/scripts/neomodel_install_labels.py b/neomodel/scripts/neomodel_install_labels.py index 39a59c77..c0d9c82a 100755 --- a/neomodel/scripts/neomodel_install_labels.py +++ b/neomodel/scripts/neomodel_install_labels.py @@ -32,7 +32,7 @@ from importlib import import_module from os import environ, path -from neomodel._async.core import adb +from neomodel._sync.core import db def load_python_module_or_file(name): @@ -109,9 +109,9 @@ def main(): # Connect after to override any code in the module that may set the connection print(f"Connecting to {bolt_url}") - adb.set_connection(url=bolt_url) + db.set_connection(url=bolt_url) - adb.install_all_labels() + db.install_all_labels() if __name__ == "__main__": diff --git a/neomodel/scripts/neomodel_remove_labels.py b/neomodel/scripts/neomodel_remove_labels.py index 25cf25bc..2272c7fa 100755 --- a/neomodel/scripts/neomodel_remove_labels.py +++ b/neomodel/scripts/neomodel_remove_labels.py @@ -27,7 +27,7 @@ from argparse import ArgumentParser, RawDescriptionHelpFormatter from os import environ -from neomodel._async.core import adb +from neomodel._sync.core import db def main(): @@ -61,9 +61,9 @@ def main(): # Connect after to override any code in the module that may set the connection print(f"Connecting to {bolt_url}") - adb.set_connection(url=bolt_url) + db.set_connection(url=bolt_url) - adb.remove_all_labels() + db.remove_all_labels() if __name__ == "__main__": From eff09280b42220c43576591976f173d29911685f Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Tue, 19 Dec 2023 16:56:11 +0100 Subject: [PATCH 12/73] Fix unasync script for Python 3.12 --- bin/make-unasync | 16 ++++++++++++++++ test/sync/test_cypher.py | 2 +- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/bin/make-unasync b/bin/make-unasync index 2fabfd24..e390c186 100755 --- a/bin/make-unasync +++ b/bin/make-unasync @@ -64,6 +64,16 @@ def _tokenize(f): last_end = tok.end if tok.type in [std_tokenize.NEWLINE, std_tokenize.NL]: last_end = (tok.end[0] + 1, 0) + elif ( + sys.version_info >= (3, 12) + and tok.type == std_tokenize.FSTRING_MIDDLE + ): + last_end = ( + last_end[0], + last_end[1] + + tok.string.count("{") + + tok.string.count("}") + ) def _untokenize(tokens): @@ -110,6 +120,12 @@ class CustomRule(unasync.Rule): tokval[-1:], ) tokval = left_quote + self._unasync_string(name) + right_quote + elif ( + sys.version_info >= (3, 12) + and toknum == std_tokenize.FSTRING_MIDDLE + ): + tokval = tokval.replace("{", "{{").replace("}", "}}") + tokval = self._unasync_string(tokval) if used_space is None: used_space = space yield (used_space, tokval) diff --git a/test/sync/test_cypher.py b/test/sync/test_cypher.py index 3abe1981..a76e3ba8 100644 --- a/test/sync/test_cypher.py +++ b/test/sync/test_cypher.py @@ -6,7 +6,7 @@ from numpy import ndarray from pandas import DataFrame, Series -from neomodel import StringProperty, StructuredNode +from neomodel import StructuredNode, StringProperty from neomodel._sync.core import db From 3b574e53e951612dcf804f1d9d73844a8ec41d09 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Wed, 20 Dec 2023 10:47:28 +0100 Subject: [PATCH 13/73] Migrate more to async --- bin/make-unasync | 11 +- neomodel/__init__.py | 35 +- neomodel/_async/cardinality.py | 135 +++ neomodel/_async/core.py | 10 +- neomodel/_async/match.py | 1031 ++++++++++++++++++ neomodel/{ => _async}/path.py | 7 +- neomodel/{ => _async}/relationship.py | 16 +- neomodel/_async/relationship_manager.py | 542 +++++++++ neomodel/{ => _sync}/cardinality.py | 4 +- neomodel/_sync/core.py | 18 +- neomodel/{ => _sync}/match.py | 53 +- neomodel/_sync/path.py | 53 + neomodel/_sync/relationship.py | 171 +++ neomodel/{ => _sync}/relationship_manager.py | 40 +- neomodel/properties.py | 6 +- neomodel/util.py | 2 + test/test_cardinality.py | 14 +- test/test_database_management.py | 4 +- test/test_issue283.py | 4 +- test/test_issue600.py | 2 +- test/test_label_install.py | 10 +- test/test_match_api.py | 57 +- test/test_migration_neo4j_5.py | 4 +- test/test_models.py | 10 +- test/test_paths.py | 10 +- test/test_relationship_models.py | 4 +- test/test_relationships.py | 10 +- test/test_scripts.py | 4 +- 28 files changed, 2102 insertions(+), 165 deletions(-) create mode 100644 neomodel/_async/cardinality.py create mode 100644 neomodel/_async/match.py rename neomodel/{ => _async}/path.py (89%) rename neomodel/{ => _async}/relationship.py (95%) create mode 100644 neomodel/_async/relationship_manager.py rename neomodel/{ => _sync}/cardinality.py (97%) rename neomodel/{ => _sync}/match.py (95%) create mode 100644 neomodel/_sync/path.py create mode 100644 neomodel/_sync/relationship.py rename neomodel/{ => _sync}/relationship_manager.py (92%) diff --git a/bin/make-unasync b/bin/make-unasync index e390c186..33ccbf90 100755 --- a/bin/make-unasync +++ b/bin/make-unasync @@ -64,15 +64,10 @@ def _tokenize(f): last_end = tok.end if tok.type in [std_tokenize.NEWLINE, std_tokenize.NL]: last_end = (tok.end[0] + 1, 0) - elif ( - sys.version_info >= (3, 12) - and tok.type == std_tokenize.FSTRING_MIDDLE - ): + elif sys.version_info >= (3, 12) and tok.type == std_tokenize.FSTRING_MIDDLE: last_end = ( last_end[0], - last_end[1] - + tok.string.count("{") - + tok.string.count("}") + last_end[1] + tok.string.count("{") + tok.string.count("}"), ) @@ -207,7 +202,7 @@ class CustomRule(unasync.Rule): def apply_unasync(files): """Generate sync code from async code.""" - additional_main_replacements = {"adb": "db"} + additional_main_replacements = {"adb": "db", "_async": "_sync"} additional_test_replacements = { "_async": "_sync", "adb": "db", diff --git a/neomodel/__init__.py b/neomodel/__init__.py index e6380f67..0997322d 100644 --- a/neomodel/__init__.py +++ b/neomodel/__init__.py @@ -1,5 +1,11 @@ # pep8: noqa # TODO : Check imports here +from neomodel._async.cardinality import ( + AsyncOne, + AsyncOneOrMore, + AsyncZeroOrMore, + AsyncZeroOrOne, +) from neomodel._async.core import ( AsyncStructuredNode, change_neo4j_password, @@ -10,12 +16,25 @@ install_labels, remove_all_labels, ) +from neomodel._async.match import AsyncNodeSet, AsyncTraversal +from neomodel._async.path import AsyncNeomodelPath +from neomodel._async.relationship import AsyncStructuredRel +from neomodel._async.relationship_manager import ( + AsyncRelationshipManager, + NotConnected, + Relationship, + RelationshipDefinition, + RelationshipFrom, + RelationshipTo, +) +from neomodel._sync.cardinality import One, OneOrMore, ZeroOrMore, ZeroOrOne from neomodel._sync.core import StructuredNode -from neomodel.cardinality import One, OneOrMore, ZeroOrMore, ZeroOrOne +from neomodel._sync.match import NodeSet, Traversal +from neomodel._sync.path import NeomodelPath +from neomodel._sync.relationship import StructuredRel +from neomodel._sync.relationship_manager import RelationshipManager from neomodel.exceptions import * -from neomodel.match import EITHER, INCOMING, OUTGOING, NodeSet, Traversal from neomodel.match_q import Q # noqa -from neomodel.path import NeomodelPath from neomodel.properties import ( AliasProperty, ArrayProperty, @@ -32,15 +51,7 @@ StringProperty, UniqueIdProperty, ) -from neomodel.relationship import StructuredRel -from neomodel.relationship_manager import ( - NotConnected, - Relationship, - RelationshipDefinition, - RelationshipFrom, - RelationshipManager, - RelationshipTo, -) +from neomodel.util import EITHER, INCOMING, OUTGOING __author__ = "Robin Edwards" __email__ = "robin.ge@gmail.com" diff --git a/neomodel/_async/cardinality.py b/neomodel/_async/cardinality.py new file mode 100644 index 00000000..7b1f5cf0 --- /dev/null +++ b/neomodel/_async/cardinality.py @@ -0,0 +1,135 @@ +from neomodel._async.relationship_manager import ( # pylint:disable=unused-import + AsyncRelationshipManager, + AsyncZeroOrMore, +) +from neomodel.exceptions import AttemptedCardinalityViolation, CardinalityViolation + + +class AsyncZeroOrOne(AsyncRelationshipManager): + """A relationship to zero or one node.""" + + description = "zero or one relationship" + + def single(self): + """ + Return the associated node. + + :return: node + """ + nodes = super().all() + if len(nodes) == 1: + return nodes[0] + if len(nodes) > 1: + raise CardinalityViolation(self, len(nodes)) + return None + + def all(self): + node = self.single() + return [node] if node else [] + + async def connect(self, node, properties=None): + """ + Connect to a node. + + :param node: + :type: StructuredNode + :param properties: relationship properties + :type: dict + :return: True / rel instance + """ + if len(self): + raise AttemptedCardinalityViolation( + f"Node already has {self} can't connect more" + ) + return await super().connect(node, properties) + + +class AsyncOneOrMore(AsyncRelationshipManager): + """A relationship to zero or more nodes.""" + + description = "one or more relationships" + + def single(self): + """ + Fetch one of the related nodes + + :return: Node + """ + nodes = super().all() + if nodes: + return nodes[0] + raise CardinalityViolation(self, "none") + + def all(self): + """ + Returns all related nodes. + + :return: [node1, node2...] + """ + nodes = super().all() + if nodes: + return nodes + raise CardinalityViolation(self, "none") + + async def disconnect(self, node): + """ + Disconnect node + :param node: + :return: + """ + if super().__len__() < 2: + raise AttemptedCardinalityViolation("One or more expected") + return await super().disconnect(node) + + +class AsyncOne(AsyncRelationshipManager): + """ + A relationship to a single node + """ + + description = "one relationship" + + def single(self): + """ + Return the associated node. + + :return: node + """ + nodes = super().all() + if nodes: + if len(nodes) == 1: + return nodes[0] + raise CardinalityViolation(self, len(nodes)) + raise CardinalityViolation(self, "none") + + def all(self): + """ + Return single node in an array + + :return: [node] + """ + return [self.single()] + + async def disconnect(self, node): + raise AttemptedCardinalityViolation( + "Cardinality one, cannot disconnect use reconnect." + ) + + async def disconnect_all(self): + raise AttemptedCardinalityViolation( + "Cardinality one, cannot disconnect_all use reconnect." + ) + + async def connect(self, node, properties=None): + """ + Connect a node + + :param node: + :param properties: relationship properties + :return: True / rel instance + """ + if not hasattr(self.source, "element_id") or self.source.element_id is None: + raise ValueError("Node has not been saved cannot connect!") + if len(self): + raise AttemptedCardinalityViolation("Node already has one relationship") + return await super().connect(node, properties) diff --git a/neomodel/_async/core.py b/neomodel/_async/core.py index e2c07e66..348f23a7 100644 --- a/neomodel/_async/core.py +++ b/neomodel/_async/core.py @@ -344,9 +344,9 @@ def _object_resolution(self, object_to_resolve): return self._NODE_CLASS_REGISTRY[rel_type].inflate(object_to_resolve) if isinstance(object_to_resolve, Path): - from neomodel.path import NeomodelPath + from neomodel._async.path import AsyncNeomodelPath - return NeomodelPath(object_to_resolve) + return AsyncNeomodelPath(object_to_resolve) if isinstance(object_to_resolve, list): return self._result_resolution([object_to_resolve]) @@ -1115,9 +1115,9 @@ def nodes(cls): :return: NodeSet :rtype: NodeSet """ - from neomodel.match import NodeSet + from neomodel._async.match import AsyncNodeSet - return NodeSet(cls) + return AsyncNodeSet(cls) @property def element_id(self): @@ -1178,7 +1178,7 @@ def _build_merge_query( "No relation_type is specified on provided relationship" ) - from neomodel.match import _rel_helper + from neomodel._async.match import _rel_helper query_params["source_id"] = relationship.source.element_id query = f"MATCH (source:{relationship.source.__label__}) WHERE {adb.get_id_method()}(source) = $source_id\n " diff --git a/neomodel/_async/match.py b/neomodel/_async/match.py new file mode 100644 index 00000000..71d5a653 --- /dev/null +++ b/neomodel/_async/match.py @@ -0,0 +1,1031 @@ +import inspect +import re +from collections import defaultdict +from dataclasses import dataclass +from typing import Optional + +from neomodel._async.core import AsyncStructuredNode, adb +from neomodel.exceptions import MultipleNodesReturned +from neomodel.match_q import Q, QBase +from neomodel.properties import AliasProperty +from neomodel.util import INCOMING, OUTGOING + + +def _rel_helper( + lhs, + rhs, + ident=None, + relation_type=None, + direction=None, + relation_properties=None, + **kwargs, # NOSONAR +): + """ + Generate a relationship matching string, with specified parameters. + Examples: + relation_direction = OUTGOING: (lhs)-[relation_ident:relation_type]->(rhs) + relation_direction = INCOMING: (lhs)<-[relation_ident:relation_type]-(rhs) + relation_direction = EITHER: (lhs)-[relation_ident:relation_type]-(rhs) + + :param lhs: The left hand statement. + :type lhs: str + :param rhs: The right hand statement. + :type rhs: str + :param ident: A specific identity to name the relationship, or None. + :type ident: str + :param relation_type: None for all direct rels, * for all of any length, or a name of an explicit rel. + :type relation_type: str + :param direction: None or EITHER for all OUTGOING,INCOMING,EITHER. Otherwise OUTGOING or INCOMING. + :param relation_properties: dictionary of relationship properties to match + :returns: string + """ + rel_props = "" + + if relation_properties: + rel_props_str = ", ".join( + (f"{key}: {value}" for key, value in relation_properties.items()) + ) + rel_props = f" {{{rel_props_str}}}" + + rel_def = "" + # relation_type is unspecified + if relation_type is None: + rel_def = "" + # all("*" wildcard) relation_type + elif relation_type == "*": + rel_def = "[*]" + else: + # explicit relation_type + rel_def = f"[{ident if ident else ''}:`{relation_type}`{rel_props}]" + + stmt = "" + if direction == OUTGOING: + stmt = f"-{rel_def}->" + elif direction == INCOMING: + stmt = f"<-{rel_def}-" + else: + stmt = f"-{rel_def}-" + + # Make sure not to add parenthesis when they are already present + if lhs[-1] != ")": + lhs = f"({lhs})" + if rhs[-1] != ")": + rhs = f"({rhs})" + + return f"{lhs}{stmt}{rhs}" + + +def _rel_merge_helper( + lhs, + rhs, + ident="neomodelident", + relation_type=None, + direction=None, + relation_properties=None, + **kwargs, # NOSONAR +): + """ + Generate a relationship merging string, with specified parameters. + Examples: + relation_direction = OUTGOING: (lhs)-[relation_ident:relation_type]->(rhs) + relation_direction = INCOMING: (lhs)<-[relation_ident:relation_type]-(rhs) + relation_direction = EITHER: (lhs)-[relation_ident:relation_type]-(rhs) + + :param lhs: The left hand statement. + :type lhs: str + :param rhs: The right hand statement. + :type rhs: str + :param ident: A specific identity to name the relationship, or None. + :type ident: str + :param relation_type: None for all direct rels, * for all of any length, or a name of an explicit rel. + :type relation_type: str + :param direction: None or EITHER for all OUTGOING,INCOMING,EITHER. Otherwise OUTGOING or INCOMING. + :param relation_properties: dictionary of relationship properties to merge + :returns: string + """ + + if direction == OUTGOING: + stmt = "-{0}->" + elif direction == INCOMING: + stmt = "<-{0}-" + else: + stmt = "-{0}-" + + rel_props = "" + rel_none_props = "" + + if relation_properties: + rel_props_str = ", ".join( + ( + f"{key}: {value}" + for key, value in relation_properties.items() + if value is not None + ) + ) + rel_props = f" {{{rel_props_str}}}" + if None in relation_properties.values(): + rel_prop_val_str = ", ".join( + ( + f"{ident}.{key}=${key!s}" + for key, value in relation_properties.items() + if value is None + ) + ) + rel_none_props = ( + f" ON CREATE SET {rel_prop_val_str} ON MATCH SET {rel_prop_val_str}" + ) + # relation_type is unspecified + if relation_type is None: + stmt = stmt.format("") + # all("*" wildcard) relation_type + elif relation_type == "*": + stmt = stmt.format("[*]") + else: + # explicit relation_type + stmt = stmt.format(f"[{ident}:`{relation_type}`{rel_props}]") + + return f"({lhs}){stmt}({rhs}){rel_none_props}" + + +# special operators +_SPECIAL_OPERATOR_IN = "IN" +_SPECIAL_OPERATOR_INSENSITIVE = "(?i)" +_SPECIAL_OPERATOR_ISNULL = "IS NULL" +_SPECIAL_OPERATOR_ISNOTNULL = "IS NOT NULL" +_SPECIAL_OPERATOR_REGEX = "=~" + +_UNARY_OPERATORS = (_SPECIAL_OPERATOR_ISNULL, _SPECIAL_OPERATOR_ISNOTNULL) + +_REGEX_INSESITIVE = _SPECIAL_OPERATOR_INSENSITIVE + "{}" +_REGEX_CONTAINS = ".*{}.*" +_REGEX_STARTSWITH = "{}.*" +_REGEX_ENDSWITH = ".*{}" + +# regex operations that require escaping +_STRING_REGEX_OPERATOR_TABLE = { + "iexact": _REGEX_INSESITIVE, + "contains": _REGEX_CONTAINS, + "icontains": _SPECIAL_OPERATOR_INSENSITIVE + _REGEX_CONTAINS, + "startswith": _REGEX_STARTSWITH, + "istartswith": _SPECIAL_OPERATOR_INSENSITIVE + _REGEX_STARTSWITH, + "endswith": _REGEX_ENDSWITH, + "iendswith": _SPECIAL_OPERATOR_INSENSITIVE + _REGEX_ENDSWITH, +} +# regex operations that do not require escaping +_REGEX_OPERATOR_TABLE = { + "iregex": _REGEX_INSESITIVE, +} +# list all regex operations, these will require formatting of the value +_REGEX_OPERATOR_TABLE.update(_STRING_REGEX_OPERATOR_TABLE) + +# list all supported operators +OPERATOR_TABLE = { + "lt": "<", + "gt": ">", + "lte": "<=", + "gte": ">=", + "ne": "<>", + "in": _SPECIAL_OPERATOR_IN, + "isnull": _SPECIAL_OPERATOR_ISNULL, + "regex": _SPECIAL_OPERATOR_REGEX, + "exact": "=", +} +# add all regex operators +OPERATOR_TABLE.update(_REGEX_OPERATOR_TABLE) + + +def install_traversals(cls, node_set): + """ + For a StructuredNode class install Traversal objects for each + relationship definition on a NodeSet instance + """ + rels = cls.defined_properties(rels=True, aliases=False, properties=False) + + for key in rels.keys(): + if hasattr(node_set, key): + raise ValueError(f"Cannot install traversal '{key}' exists on NodeSet") + + rel = getattr(cls, key) + rel.lookup_node_class() + + traversal = AsyncTraversal(source=node_set, name=key, definition=rel.definition) + setattr(node_set, key, traversal) + + +def process_filter_args(cls, kwargs): + """ + loop through properties in filter parameters check they match class definition + deflate them and convert into something easy to generate cypher from + """ + + output = {} + + for key, value in kwargs.items(): + if "__" in key: + prop, operator = key.rsplit("__") + operator = OPERATOR_TABLE[operator] + else: + prop = key + operator = "=" + + if prop not in cls.defined_properties(rels=False): + raise ValueError( + f"No such property {prop} on {cls.__name__}. Note that Neo4j internals like id or element_id are not allowed for use in this operation." + ) + + property_obj = getattr(cls, prop) + if isinstance(property_obj, AliasProperty): + prop = property_obj.aliased_to() + deflated_value = getattr(cls, prop).deflate(value) + else: + operator, deflated_value = transform_operator_to_filter( + operator=operator, + filter_key=key, + filter_value=value, + property_obj=property_obj, + ) + + # map property to correct property name in the database + db_property = cls.defined_properties(rels=False)[prop].db_property or prop + + output[db_property] = (operator, deflated_value) + + return output + + +def transform_operator_to_filter(operator, filter_key, filter_value, property_obj): + # handle special operators + if operator == _SPECIAL_OPERATOR_IN: + if not isinstance(filter_value, tuple) and not isinstance(filter_value, list): + raise ValueError( + f"Value must be a tuple or list for IN operation {filter_key}={filter_value}" + ) + deflated_value = [property_obj.deflate(v) for v in filter_value] + elif operator == _SPECIAL_OPERATOR_ISNULL: + if not isinstance(filter_value, bool): + raise ValueError( + f"Value must be a bool for isnull operation on {filter_key}" + ) + operator = "IS NULL" if filter_value else "IS NOT NULL" + deflated_value = None + elif operator in _REGEX_OPERATOR_TABLE.values(): + deflated_value = property_obj.deflate(filter_value) + if not isinstance(deflated_value, str): + raise ValueError(f"Must be a string value for {filter_key}") + if operator in _STRING_REGEX_OPERATOR_TABLE.values(): + deflated_value = re.escape(deflated_value) + deflated_value = operator.format(deflated_value) + operator = _SPECIAL_OPERATOR_REGEX + else: + deflated_value = property_obj.deflate(filter_value) + + return operator, deflated_value + + +def process_has_args(cls, kwargs): + """ + loop through has parameters check they correspond to class rels defined + """ + rel_definitions = cls.defined_properties(properties=False, rels=True, aliases=False) + + match, dont_match = {}, {} + + for key, value in kwargs.items(): + if key not in rel_definitions: + raise ValueError(f"No such relation {key} defined on a {cls.__name__}") + + rhs_ident = key + + rel_definitions[key].lookup_node_class() + + if value is True: + match[rhs_ident] = rel_definitions[key].definition + elif value is False: + dont_match[rhs_ident] = rel_definitions[key].definition + elif isinstance(value, AsyncNodeSet): + raise NotImplementedError("Not implemented yet") + else: + raise ValueError("Expecting True / False / NodeSet got: " + repr(value)) + + return match, dont_match + + +class QueryAST: + match: Optional[list] + optional_match: Optional[list] + where: Optional[list] + with_clause: Optional[str] + return_clause: Optional[str] + order_by: Optional[str] + skip: Optional[int] + limit: Optional[int] + result_class: Optional[type] + lookup: Optional[str] + additional_return: Optional[list] + is_count: Optional[bool] + + def __init__( + self, + match: Optional[list] = None, + optional_match: Optional[list] = None, + where: Optional[list] = None, + with_clause: Optional[str] = None, + return_clause: Optional[str] = None, + order_by: Optional[str] = None, + skip: Optional[int] = None, + limit: Optional[int] = None, + result_class: Optional[type] = None, + lookup: Optional[str] = None, + additional_return: Optional[list] = None, + is_count: Optional[bool] = False, + ): + self.match = match if match else [] + self.optional_match = optional_match if optional_match else [] + self.where = where if where else [] + self.with_clause = with_clause + self.return_clause = return_clause + self.order_by = order_by + self.skip = skip + self.limit = limit + self.result_class = result_class + self.lookup = lookup + self.additional_return = additional_return if additional_return else [] + self.is_count = is_count + + +class AsyncQueryBuilder: + def __init__(self, node_set): + self.node_set = node_set + self._ast = QueryAST() + self._query_params = {} + self._place_holder_registry = {} + self._ident_count = 0 + self._node_counters = defaultdict(int) + + def build_ast(self): + if hasattr(self.node_set, "relations_to_fetch"): + for relation in self.node_set.relations_to_fetch: + self.build_traversal_from_path(relation, self.node_set.source) + + self.build_source(self.node_set) + + if hasattr(self.node_set, "skip"): + self._ast.skip = self.node_set.skip + if hasattr(self.node_set, "limit"): + self._ast.limit = self.node_set.limit + + return self + + def build_source(self, source): + if isinstance(source, AsyncTraversal): + return self.build_traversal(source) + if isinstance(source, AsyncNodeSet): + if inspect.isclass(source.source) and issubclass( + source.source, AsyncStructuredNode + ): + ident = self.build_label(source.source.__label__.lower(), source.source) + else: + ident = self.build_source(source.source) + + self.build_additional_match(ident, source) + + if hasattr(source, "order_by_elements"): + self.build_order_by(ident, source) + + if source.filters or source.q_filters: + self.build_where_stmt( + ident, + source.filters, + source.q_filters, + source_class=source.source_class, + ) + + return ident + if isinstance(source, AsyncStructuredNode): + return self.build_node(source) + raise ValueError("Unknown source type " + repr(source)) + + def create_ident(self): + self._ident_count += 1 + return "r" + str(self._ident_count) + + def build_order_by(self, ident, source): + if "?" in source.order_by_elements: + self._ast.with_clause = f"{ident}, rand() as r" + self._ast.order_by = "r" + else: + self._ast.order_by = [f"{ident}.{p}" for p in source.order_by_elements] + + def build_traversal(self, traversal): + """ + traverse a relationship from a node to a set of nodes + """ + # build source + rhs_label = ":" + traversal.target_class.__label__ + + # build source + rel_ident = self.create_ident() + lhs_ident = self.build_source(traversal.source) + traversal_ident = f"{traversal.name}_{rel_ident}" + rhs_ident = traversal_ident + rhs_label + self._ast.return_clause = traversal_ident + self._ast.result_class = traversal.target_class + + stmt = _rel_helper( + lhs=lhs_ident, + rhs=rhs_ident, + ident=rel_ident, + **traversal.definition, + ) + self._ast.match.append(stmt) + + if traversal.filters: + self.build_where_stmt(rel_ident, traversal.filters) + + return traversal_ident + + def _additional_return(self, name): + if name not in self._ast.additional_return and name != self._ast.return_clause: + self._ast.additional_return.append(name) + + def build_traversal_from_path(self, relation: dict, source_class) -> str: + path: str = relation["path"] + stmt: str = "" + source_class_iterator = source_class + for index, part in enumerate(path.split("__")): + relationship = getattr(source_class_iterator, part) + # build source + if "node_class" not in relationship.definition: + relationship.lookup_node_class() + rhs_label = relationship.definition["node_class"].__label__ + rel_reference = f'{relationship.definition["node_class"]}_{part}' + self._node_counters[rel_reference] += 1 + rhs_name = ( + f"{rhs_label.lower()}_{part}_{self._node_counters[rel_reference]}" + ) + rhs_ident = f"{rhs_name}:{rhs_label}" + self._additional_return(rhs_name) + if not stmt: + lhs_label = source_class_iterator.__label__ + lhs_name = lhs_label.lower() + lhs_ident = f"{lhs_name}:{lhs_label}" + if not index: + # This is the first one, we make sure that 'return' + # contains the primary node so _contains() works + # as usual + self._ast.return_clause = lhs_name + else: + self._additional_return(lhs_name) + else: + lhs_ident = stmt + + rel_ident = self.create_ident() + self._additional_return(rel_ident) + stmt = _rel_helper( + lhs=lhs_ident, + rhs=rhs_ident, + ident=rel_ident, + direction=relationship.definition["direction"], + relation_type=relationship.definition["relation_type"], + ) + source_class_iterator = relationship.definition["node_class"] + + if relation.get("optional"): + self._ast.optional_match.append(stmt) + else: + self._ast.match.append(stmt) + return rhs_name + + def build_node(self, node): + ident = node.__class__.__name__.lower() + place_holder = self._register_place_holder(ident) + + # Hack to emulate START to lookup a node by id + _node_lookup = f"MATCH ({ident}) WHERE {adb.get_id_method()}({ident})=${place_holder} WITH {ident}" + self._ast.lookup = _node_lookup + + self._query_params[place_holder] = node.element_id + + self._ast.return_clause = ident + self._ast.result_class = node.__class__ + return ident + + def build_label(self, ident, cls): + """ + match nodes by a label + """ + ident_w_label = ident + ":" + cls.__label__ + + if not self._ast.return_clause and ( + not self._ast.additional_return or ident not in self._ast.additional_return + ): + self._ast.match.append(f"({ident_w_label})") + self._ast.return_clause = ident + self._ast.result_class = cls + return ident + + def build_additional_match(self, ident, node_set): + """ + handle additional matches supplied by 'has()' calls + """ + source_ident = ident + + for _, value in node_set.must_match.items(): + if isinstance(value, dict): + label = ":" + value["node_class"].__label__ + stmt = _rel_helper(lhs=source_ident, rhs=label, ident="", **value) + self._ast.where.append(stmt) + else: + raise ValueError("Expecting dict got: " + repr(value)) + + for _, val in node_set.dont_match.items(): + if isinstance(val, dict): + label = ":" + val["node_class"].__label__ + stmt = _rel_helper(lhs=source_ident, rhs=label, ident="", **val) + self._ast.where.append("NOT " + stmt) + else: + raise ValueError("Expecting dict got: " + repr(val)) + + def _register_place_holder(self, key): + if key in self._place_holder_registry: + self._place_holder_registry[key] += 1 + else: + self._place_holder_registry[key] = 1 + return key + "_" + str(self._place_holder_registry[key]) + + def _parse_q_filters(self, ident, q, source_class): + target = [] + for child in q.children: + if isinstance(child, QBase): + q_childs = self._parse_q_filters(ident, child, source_class) + if child.connector == Q.OR: + q_childs = "(" + q_childs + ")" + target.append(q_childs) + else: + kwargs = {child[0]: child[1]} + filters = process_filter_args(source_class, kwargs) + for prop, op_and_val in filters.items(): + operator, val = op_and_val + if operator in _UNARY_OPERATORS: + # unary operators do not have a parameter + statement = f"{ident}.{prop} {operator}" + else: + place_holder = self._register_place_holder(ident + "_" + prop) + statement = f"{ident}.{prop} {operator} ${place_holder}" + self._query_params[place_holder] = val + target.append(statement) + ret = f" {q.connector} ".join(target) + if q.negated: + ret = f"NOT ({ret})" + return ret + + def build_where_stmt(self, ident, filters, q_filters=None, source_class=None): + """ + construct a where statement from some filters + """ + if q_filters is not None: + stmts = self._parse_q_filters(ident, q_filters, source_class) + if stmts: + self._ast.where.append(stmts) + else: + stmts = [] + for row in filters: + negate = False + + # pre-process NOT cases as they are nested dicts + if "__NOT__" in row and len(row) == 1: + negate = True + row = row["__NOT__"] + + for prop, operator_and_val in row.items(): + operator, val = operator_and_val + if operator in _UNARY_OPERATORS: + # unary operators do not have a parameter + statement = ( + f"{'NOT' if negate else ''} {ident}.{prop} {operator}" + ) + else: + place_holder = self._register_place_holder(ident + "_" + prop) + statement = f"{'NOT' if negate else ''} {ident}.{prop} {operator} ${place_holder}" + self._query_params[place_holder] = val + stmts.append(statement) + + self._ast.where.append(" AND ".join(stmts)) + + def build_query(self): + query = "" + + if self._ast.lookup: + query += self._ast.lookup + + # Instead of using only one MATCH statement for every relation + # to follow, we use one MATCH per relation (to avoid cartesian + # product issues...). + # There might be optimizations to be done, using projections, + # or pusing patterns instead of a chain of OPTIONAL MATCH. + if self._ast.match: + query += " MATCH " + query += " MATCH ".join(i for i in self._ast.match) + + if self._ast.optional_match: + query += " OPTIONAL MATCH " + query += " OPTIONAL MATCH ".join(i for i in self._ast.optional_match) + + if self._ast.where: + query += " WHERE " + query += " AND ".join(self._ast.where) + + if self._ast.with_clause: + query += " WITH " + query += self._ast.with_clause + + query += " RETURN " + if self._ast.return_clause: + query += self._ast.return_clause + if self._ast.additional_return: + if self._ast.return_clause: + query += ", " + query += ", ".join(self._ast.additional_return) + + if self._ast.order_by: + query += " ORDER BY " + query += ", ".join(self._ast.order_by) + + # If we return a count with pagination, pagination has to happen before RETURN + # It will then be included in the WITH clause already + if self._ast.skip and not self._ast.is_count: + query += f" SKIP {self._ast.skip}" + + if self._ast.limit and not self._ast.is_count: + query += f" LIMIT {self._ast.limit}" + + return query + + async def _count(self): + self._ast.is_count = True + # If we return a count with pagination, pagination has to happen before RETURN + # Like : WITH my_var SKIP 10 LIMIT 10 RETURN count(my_var) + self._ast.with_clause = f"{self._ast.return_clause}" + if self._ast.skip: + self._ast.with_clause += f" SKIP {self._ast.skip}" + + if self._ast.limit: + self._ast.with_clause += f" LIMIT {self._ast.limit}" + + self._ast.return_clause = f"count({self._ast.return_clause})" + # drop order_by, results in an invalid query + self._ast.order_by = None + # drop additional_return to avoid unexpected result + self._ast.additional_return = None + query = self.build_query() + results, _ = await adb.cypher_query(query, self._query_params) + return int(results[0][0]) + + def _contains(self, node_element_id): + # inject id = into ast + if not self._ast.return_clause: + print(self._ast.additional_return) + self._ast.return_clause = self._ast.additional_return[0] + ident = self._ast.return_clause + place_holder = self._register_place_holder(ident + "_contains") + self._ast.where.append(f"{adb.get_id_method()}({ident}) = ${place_holder}") + self._query_params[place_holder] = node_element_id + return self._count() >= 1 + + async def _execute(self, lazy=False): + if lazy: + # inject id() into return or return_set + if self._ast.return_clause: + self._ast.return_clause = ( + f"{adb.get_id_method()}({self._ast.return_clause})" + ) + else: + self._ast.additional_return = [ + f"{adb.get_id_method()}({item})" + for item in self._ast.additional_return + ] + query = self.build_query() + results, _ = await adb.cypher_query( + query, self._query_params, resolve_objects=True + ) + # The following is not as elegant as it could be but had to be copied from the + # version prior to cypher_query with the resolve_objects capability. + # It seems that certain calls are only supposed to be focusing to the first + # result item returned (?) + if results and len(results[0]) == 1: + return [n[0] for n in results] + return results + + +class AsyncBaseSet: + """ + Base class for all node sets. + + Contains common python magic methods, __len__, __contains__ etc + """ + + query_cls = AsyncQueryBuilder + + async def __aiter__(self): + async for i in await self.query_cls(self).build_ast()._execute(): + yield i + + async def __len__(self): + return await self.query_cls(self).build_ast()._count() + + async def __abool__(self): + return bool(await self.query_cls(self).build_ast()._count() > 0) + + async def __nonzero__(self): + return bool(await self.query_cls(self).build_ast()._count() > 0) + + def __contains__(self, obj): + if isinstance(obj, AsyncStructuredNode): + if hasattr(obj, "element_id") and obj.element_id is not None: + return self.query_cls(self).build_ast()._contains(obj.element_id) + raise ValueError("Unsaved node: " + repr(obj)) + + raise ValueError("Expecting StructuredNode instance") + + async def __getitem__(self, key): + if isinstance(key, slice): + if key.stop and key.start: + self.limit = key.stop - key.start + self.skip = key.start + elif key.stop: + self.limit = key.stop + elif key.start: + self.skip = key.start + + return self + + if isinstance(key, int): + self.skip = key + self.limit = 1 + + return await self.query_cls(self).build_ast()._execute()[0] + + return None + + +@dataclass +class Optional: + """Simple relation qualifier.""" + + relation: str + + +class AsyncNodeSet(AsyncBaseSet): + """ + A class representing as set of nodes matching common query parameters + """ + + def __init__(self, source): + self.source = source # could be a Traverse object or a node class + if isinstance(source, AsyncTraversal): + self.source_class = source.target_class + elif inspect.isclass(source) and issubclass(source, AsyncStructuredNode): + self.source_class = source + elif isinstance(source, AsyncStructuredNode): + self.source_class = source.__class__ + else: + raise ValueError("Bad source for nodeset " + repr(source)) + + # setup Traversal objects using relationship definitions + install_traversals(self.source_class, self) + + self.filters = [] + self.q_filters = Q() + + # used by has() + self.must_match = {} + self.dont_match = {} + + self.relations_to_fetch: list = [] + + async def _get(self, limit=None, lazy=False, **kwargs): + self.filter(**kwargs) + if limit: + self.limit = limit + return await self.query_cls(self).build_ast()._execute(lazy) + + async def get(self, lazy=False, **kwargs): + """ + Retrieve one node from the set matching supplied parameters + :param lazy: False by default, specify True to get nodes with id only without the parameters. + :param kwargs: same syntax as `filter()` + :return: node + """ + result = await self._get(limit=2, lazy=lazy, **kwargs) + if len(result) > 1: + raise MultipleNodesReturned(repr(kwargs)) + if not result: + raise self.source_class.DoesNotExist(repr(kwargs)) + return result[0] + + async def get_or_none(self, **kwargs): + """ + Retrieve a node from the set matching supplied parameters or return none + + :param kwargs: same syntax as `filter()` + :return: node or none + """ + try: + return await self.get(**kwargs) + except self.source_class.DoesNotExist: + return None + + async def first(self, **kwargs): + """ + Retrieve the first node from the set matching supplied parameters + + :param kwargs: same syntax as `filter()` + :return: node + """ + result = await self._get(limit=1, **kwargs) + if result: + return result[0] + else: + raise self.source_class.DoesNotExist(repr(kwargs)) + + async def first_or_none(self, **kwargs): + """ + Retrieve the first node from the set matching supplied parameters or return none + + :param kwargs: same syntax as `filter()` + :return: node or none + """ + try: + return await self.first(**kwargs) + except self.source_class.DoesNotExist: + pass + return None + + def filter(self, *args, **kwargs): + """ + Apply filters to the existing nodes in the set. + + :param args: a Q object + + e.g `.filter(Q(salary__lt=10000) | Q(salary__gt=20000))`. + + :param kwargs: filter parameters + + Filters mimic Django's syntax with the double '__' to separate field and operators. + + e.g `.filter(salary__gt=20000)` results in `salary > 20000`. + + The following operators are available: + + * 'lt': less than + * 'gt': greater than + * 'lte': less than or equal to + * 'gte': greater than or equal to + * 'ne': not equal to + * 'in': matches one of list (or tuple) + * 'isnull': is null + * 'regex': matches supplied regex (neo4j regex format) + * 'exact': exactly match string (just '=') + * 'iexact': case insensitive match string + * 'contains': contains string + * 'icontains': case insensitive contains + * 'startswith': string starts with + * 'istartswith': case insensitive string starts with + * 'endswith': string ends with + * 'iendswith': case insensitive string ends with + + :return: self + """ + if args or kwargs: + self.q_filters = Q(self.q_filters & Q(*args, **kwargs)) + return self + + def exclude(self, *args, **kwargs): + """ + Exclude nodes from the NodeSet via filters. + + :param kwargs: filter parameters see syntax for the filter method + :return: self + """ + if args or kwargs: + self.q_filters = Q(self.q_filters & ~Q(*args, **kwargs)) + return self + + def has(self, **kwargs): + must_match, dont_match = process_has_args(self.source_class, kwargs) + self.must_match.update(must_match) + self.dont_match.update(dont_match) + return self + + def order_by(self, *props): + """ + Order by properties. Prepend with minus to do descending. Pass None to + remove ordering. + """ + should_remove = len(props) == 1 and props[0] is None + if not hasattr(self, "order_by_elements") or should_remove: + self.order_by_elements = [] + if should_remove: + return self + if "?" in props: + self.order_by_elements.append("?") + else: + for prop in props: + prop = prop.strip() + if prop.startswith("-"): + prop = prop[1:] + desc = True + else: + desc = False + + if prop not in self.source_class.defined_properties(rels=False): + raise ValueError( + f"No such property {prop} on {self.source_class.__name__}. Note that Neo4j internals like id or element_id are not allowed for use in this operation." + ) + + property_obj = getattr(self.source_class, prop) + if isinstance(property_obj, AliasProperty): + prop = property_obj.aliased_to() + + self.order_by_elements.append(prop + (" DESC" if desc else "")) + + return self + + def fetch_relations(self, *relation_names): + """Specify a set of relations to return.""" + relations = [] + for relation_name in relation_names: + if isinstance(relation_name, Optional): + item = {"path": relation_name.relation, "optional": True} + else: + item = {"path": relation_name} + relations.append(item) + self.relations_to_fetch = relations + return self + + +class AsyncTraversal(AsyncBaseSet): + """ + Models a traversal from a node to another. + + :param source: Starting of the traversal. + :type source: A :class:`~neomodel.core.StructuredNode` subclass, an + instance of such, a :class:`~neomodel.match.NodeSet` instance + or a :class:`~neomodel.match.Traversal` instance. + :param name: A name for the traversal. + :type name: :class:`str` + :param definition: A relationship definition that most certainly deserves + a documentation here. + :type defintion: :class:`dict` + """ + + def __init__(self, source, name, definition): + """ + Create a traversal + + """ + self.source = source + + if isinstance(source, AsyncTraversal): + self.source_class = source.target_class + elif inspect.isclass(source) and issubclass(source, AsyncStructuredNode): + self.source_class = source + elif isinstance(source, AsyncStructuredNode): + self.source_class = source.__class__ + elif isinstance(source, AsyncNodeSet): + self.source_class = source.source_class + else: + raise TypeError(f"Bad source for traversal: {type(source)}") + + invalid_keys = set(definition) - { + "direction", + "model", + "node_class", + "relation_type", + } + if invalid_keys: + raise ValueError(f"Prohibited keys in Traversal definition: {invalid_keys}") + + self.definition = definition + self.target_class = definition["node_class"] + self.name = name + self.filters = [] + + def match(self, **kwargs): + """ + Traverse relationships with properties matching the given parameters. + + e.g: `.match(price__lt=10)` + + :param kwargs: see `NodeSet.filter()` for syntax + :return: self + """ + if kwargs: + if self.definition.get("model") is None: + raise ValueError( + "match() with filter only available on relationships with a model" + ) + output = process_filter_args(self.definition["model"], kwargs) + if output: + self.filters.append(output) + return self diff --git a/neomodel/path.py b/neomodel/_async/path.py similarity index 89% rename from neomodel/path.py rename to neomodel/_async/path.py index 85a92ec7..cf53ed9d 100644 --- a/neomodel/path.py +++ b/neomodel/_async/path.py @@ -1,11 +1,10 @@ from neo4j.graph import Path from neomodel._async.core import adb -from neomodel.exceptions import RelationshipClassNotDefined -from neomodel.relationship import StructuredRel +from neomodel._async.relationship import AsyncStructuredRel -class NeomodelPath(Path): +class AsyncNeomodelPath(Path): """ Represents paths within neomodel. @@ -42,7 +41,7 @@ def __init__(self, a_neopath): if rel_type in adb._NODE_CLASS_REGISTRY: new_rel = adb._object_resolution(a_relationship) else: - new_rel = StructuredRel.inflate(a_relationship) + new_rel = AsyncStructuredRel.inflate(a_relationship) self._relationships.append(new_rel) @property diff --git a/neomodel/relationship.py b/neomodel/_async/relationship.py similarity index 95% rename from neomodel/relationship.py rename to neomodel/_async/relationship.py index d25990f0..637f35be 100644 --- a/neomodel/relationship.py +++ b/neomodel/_async/relationship.py @@ -1,5 +1,3 @@ -import warnings - from neomodel._async.core import adb from neomodel.hooks import hooks from neomodel.properties import Property, PropertyManager @@ -40,7 +38,7 @@ def __new__(mcs, name, bases, dct): StructuredRelBase = RelationshipMeta("RelationshipBase", (PropertyManager,), {}) -class StructuredRel(StructuredRelBase): +class AsyncStructuredRel(StructuredRelBase): """ Base class for relationship objects """ @@ -103,7 +101,7 @@ def _end_node_id(self): ) @hooks - def save(self): + async def save(self): """ Save the relationship @@ -114,17 +112,17 @@ def save(self): query += "".join([f" SET r.{key} = ${key}" for key in props]) props["self"] = self.element_id - adb.cypher_query(query, props) + await adb.cypher_query(query, props) return self - def start_node(self): + async def start_node(self): """ Get start node :return: StructuredNode """ - test = adb.cypher_query( + test = await adb.cypher_query( f""" MATCH (aNode) WHERE {adb.get_id_method()}(aNode)=$start_node_element_id @@ -135,13 +133,13 @@ def start_node(self): ) return test[0][0][0] - def end_node(self): + async def end_node(self): """ Get end node :return: StructuredNode """ - return adb.cypher_query( + return await adb.cypher_query( f""" MATCH (aNode) WHERE {adb.get_id_method()}(aNode)=$end_node_element_id diff --git a/neomodel/_async/relationship_manager.py b/neomodel/_async/relationship_manager.py new file mode 100644 index 00000000..85a7ef21 --- /dev/null +++ b/neomodel/_async/relationship_manager.py @@ -0,0 +1,542 @@ +import functools +import inspect +import sys +from importlib import import_module + +from neomodel._async.core import adb +from neomodel._async.match import ( + AsyncNodeSet, + AsyncTraversal, + _rel_helper, + _rel_merge_helper, +) +from neomodel._async.relationship import AsyncStructuredRel +from neomodel.exceptions import NotConnected, RelationshipClassRedefined +from neomodel.util import ( + EITHER, + INCOMING, + OUTGOING, + _get_node_properties, + enumerate_traceback, +) + +# basestring python 3.x fallback +try: + basestring +except NameError: + basestring = str + + +# check source node is saved and not deleted +def check_source(fn): + fn_name = fn.func_name if hasattr(fn, "func_name") else fn.__name__ + + @functools.wraps(fn) + def checker(self, *args, **kwargs): + self.source._pre_action_check(self.name + "." + fn_name) + return fn(self, *args, **kwargs) + + return checker + + +# checks if obj is a direct subclass, 1 level +def is_direct_subclass(obj, classinfo): + for base in obj.__bases__: + if base == classinfo: + return True + return False + + +class AsyncRelationshipManager(object): + """ + Base class for all relationships managed through neomodel. + + I.e the 'friends' object in `user.friends.all()` + """ + + def __init__(self, source, key, definition): + self.source = source + self.source_class = source.__class__ + self.name = key + self.definition = definition + + def __str__(self): + direction = "either" + if self.definition["direction"] == OUTGOING: + direction = "a outgoing" + elif self.definition["direction"] == INCOMING: + direction = "a incoming" + + return f"{self.description} in {direction} direction of type {self.definition['relation_type']} on node ({self.source.element_id}) of class '{self.source_class.__name__}'" + + def _check_node(self, obj): + """check for valid node i.e correct class and is saved""" + if not issubclass(type(obj), self.definition["node_class"]): + raise ValueError( + "Expected node of class " + self.definition["node_class"].__name__ + ) + if not hasattr(obj, "element_id"): + raise ValueError("Can't perform operation on unsaved node " + repr(obj)) + + @check_source + async def connect(self, node, properties=None): + """ + Connect a node + + :param node: + :param properties: for the new relationship + :type: dict + :return: + """ + self._check_node(node) + + if not self.definition["model"] and properties: + raise NotImplementedError( + "Relationship properties without using a relationship model " + "is no longer supported." + ) + + params = {} + rel_model = self.definition["model"] + rel_prop = None + + if rel_model: + rel_prop = {} + # need to generate defaults etc to create fake instance + tmp = rel_model(**properties) if properties else rel_model() + # build params and place holders to pass to rel_helper + for prop, val in rel_model.deflate(tmp.__properties__).items(): + if val is not None: + rel_prop[prop] = "$" + prop + else: + rel_prop[prop] = None + params[prop] = val + + if hasattr(tmp, "pre_save"): + tmp.pre_save() + + new_rel = _rel_merge_helper( + lhs="us", + rhs="them", + ident="r", + relation_properties=rel_prop, + **self.definition, + ) + q = ( + f"MATCH (them), (us) WHERE {adb.get_id_method()}(them)=$them and {adb.get_id_method()}(us)=$self " + "MERGE" + new_rel + ) + + params["them"] = node.element_id + + if not rel_model: + await self.source.cypher(q, params) + return True + + rel_ = await self.source.cypher(q + " RETURN r", params)[0][0][0] + rel_instance = self._set_start_end_cls(rel_model.inflate(rel_), node) + + if hasattr(rel_instance, "post_save"): + rel_instance.post_save() + + return rel_instance + + @check_source + async def replace(self, node, properties=None): + """ + Disconnect all existing nodes and connect the supplied node + + :param node: + :param properties: for the new relationship + :type: dict + :return: + """ + await self.disconnect_all() + await self.connect(node, properties) + + @check_source + async def relationship(self, node): + """ + Retrieve the relationship object for this first relationship between self and node. + + :param node: + :return: StructuredRel + """ + self._check_node(node) + my_rel = _rel_helper(lhs="us", rhs="them", ident="r", **self.definition) + q = ( + "MATCH " + + my_rel + + f" WHERE {adb.get_id_method()}(them)=$them and {adb.get_id_method()}(us)=$self RETURN r LIMIT 1" + ) + rels = await self.source.cypher(q, {"them": node.element_id})[0] + if not rels: + return + + rel_model = self.definition.get("model") or AsyncStructuredRel + + return self._set_start_end_cls(rel_model.inflate(rels[0][0]), node) + + @check_source + async def all_relationships(self, node): + """ + Retrieve all relationship objects between self and node. + + :param node: + :return: [StructuredRel] + """ + self._check_node(node) + + my_rel = _rel_helper(lhs="us", rhs="them", ident="r", **self.definition) + q = f"MATCH {my_rel} WHERE {adb.get_id_method()}(them)=$them and {adb.get_id_method()}(us)=$self RETURN r " + rels = await self.source.cypher(q, {"them": node.element_id})[0] + if not rels: + return [] + + rel_model = self.definition.get("model") or AsyncStructuredRel + return [ + self._set_start_end_cls(rel_model.inflate(rel[0]), node) for rel in rels + ] + + def _set_start_end_cls(self, rel_instance, obj): + if self.definition["direction"] == INCOMING: + rel_instance._start_node_class = obj.__class__ + rel_instance._end_node_class = self.source_class + else: + rel_instance._start_node_class = self.source_class + rel_instance._end_node_class = obj.__class__ + return rel_instance + + @check_source + async def reconnect(self, old_node, new_node): + """ + Disconnect old_node and connect new_node copying over any properties on the original relationship. + + Useful for preventing cardinality violations + + :param old_node: + :param new_node: + :return: None + """ + + self._check_node(old_node) + self._check_node(new_node) + if old_node.element_id == new_node.element_id: + return + old_rel = _rel_helper(lhs="us", rhs="old", ident="r", **self.definition) + + # get list of properties on the existing rel + result, _ = await self.source.cypher( + f""" + MATCH (us), (old) WHERE {adb.get_id_method()}(us)=$self and {adb.get_id_method()}(old)=$old + MATCH {old_rel} RETURN r + """, + {"old": old_node.element_id}, + ) + if result: + node_properties = _get_node_properties(result[0][0]) + existing_properties = node_properties.keys() + else: + raise NotConnected("reconnect", self.source, old_node) + + # remove old relationship and create new one + new_rel = _rel_merge_helper(lhs="us", rhs="new", ident="r2", **self.definition) + q = ( + "MATCH (us), (old), (new) " + f"WHERE {adb.get_id_method()}(us)=$self and {adb.get_id_method()}(old)=$old and {adb.get_id_method()}(new)=$new " + "MATCH " + old_rel + ) + q += " MERGE" + new_rel + + # copy over properties if we have + q += "".join([f" SET r2.{prop} = r.{prop}" for prop in existing_properties]) + q += " WITH r DELETE r" + + await self.source.cypher( + q, {"old": old_node.element_id, "new": new_node.element_id} + ) + + @check_source + async def disconnect(self, node): + """ + Disconnect a node + + :param node: + :return: + """ + rel = _rel_helper(lhs="a", rhs="b", ident="r", **self.definition) + q = f""" + MATCH (a), (b) WHERE {adb.get_id_method()}(a)=$self and {adb.get_id_method()}(b)=$them + MATCH {rel} DELETE r + """ + await self.source.cypher(q, {"them": node.element_id}) + + @check_source + async def disconnect_all(self): + """ + Disconnect all nodes + + :return: + """ + rhs = "b:" + self.definition["node_class"].__label__ + rel = _rel_helper(lhs="a", rhs=rhs, ident="r", **self.definition) + q = f"MATCH (a) WHERE {adb.get_id_method()}(a)=$self MATCH " + rel + " DELETE r" + await self.source.cypher(q) + + @check_source + def _new_traversal(self): + return AsyncTraversal(self.source, self.name, self.definition) + + # The methods below simply proxy the match engine. + def get(self, **kwargs): + """ + Retrieve a related node with the matching node properties. + + :param kwargs: same syntax as `NodeSet.filter()` + :return: node + """ + return AsyncNodeSet(self._new_traversal()).get(**kwargs) + + def get_or_none(self, **kwargs): + """ + Retrieve a related node with the matching node properties or return None. + + :param kwargs: same syntax as `NodeSet.filter()` + :return: node + """ + return AsyncNodeSet(self._new_traversal()).get_or_none(**kwargs) + + def filter(self, *args, **kwargs): + """ + Retrieve related nodes matching the provided properties. + + :param args: a Q object + :param kwargs: same syntax as `NodeSet.filter()` + :return: NodeSet + """ + return AsyncNodeSet(self._new_traversal()).filter(*args, **kwargs) + + def order_by(self, *props): + """ + Order related nodes by specified properties + + :param props: + :return: NodeSet + """ + return AsyncNodeSet(self._new_traversal()).order_by(*props) + + def exclude(self, *args, **kwargs): + """ + Exclude nodes that match the provided properties. + + :param args: a Q object + :param kwargs: same syntax as `NodeSet.filter()` + :return: NodeSet + """ + return AsyncNodeSet(self._new_traversal()).exclude(*args, **kwargs) + + def is_connected(self, node): + """ + Check if a node is connected with this relationship type + :param node: + :return: bool + """ + return self._new_traversal().__contains__(node) + + def single(self): + """ + Get a single related node or none. + + :return: StructuredNode + """ + try: + return self[0] + except IndexError: + pass + + def match(self, **kwargs): + """ + Return set of nodes who's relationship properties match supplied args + + :param kwargs: same syntax as `NodeSet.filter()` + :return: NodeSet + """ + return self._new_traversal().match(**kwargs) + + def all(self): + """ + Return all related nodes. + + :return: list + """ + return self._new_traversal().all() + + def __iter__(self): + return self._new_traversal().__iter__() + + def __len__(self): + return self._new_traversal().__len__() + + def __bool__(self): + return self._new_traversal().__bool__() + + def __nonzero__(self): + return self._new_traversal().__nonzero__() + + def __contains__(self, obj): + return self._new_traversal().__contains__(obj) + + def __getitem__(self, key): + return self._new_traversal().__getitem__(key) + + +class RelationshipDefinition: + def __init__( + self, + relation_type, + cls_name, + direction, + manager=AsyncRelationshipManager, + model=None, + ): + self._validate_class(cls_name, model) + + current_frame = inspect.currentframe() + + frame_number = 3 + for i, frame in enumerate_traceback(current_frame): + if cls_name in frame.f_globals: + frame_number = i + break + self.module_name = sys._getframe(frame_number).f_globals["__name__"] + if "__file__" in sys._getframe(frame_number).f_globals: + self.module_file = sys._getframe(frame_number).f_globals["__file__"] + self._raw_class = cls_name + self.manager = manager + self.definition = { + "relation_type": relation_type, + "direction": direction, + "model": model, + } + + if model is not None: + # Relationships are easier to instantiate because + # they cannot have multiple labels. + # So, a relationship's type determines the class that should be + # instantiated uniquely. + # Here however, we still use a `frozenset([relation_type])` + # to preserve the mapping type. + label_set = frozenset([relation_type]) + try: + # If the relationship mapping exists then it is attempted + # to be redefined so that it applies to the same label. + # In this case, it has to be ensured that the class + # that is overriding the relationship is a descendant + # of the already existing class. + model_from_registry = adb._NODE_CLASS_REGISTRY[label_set] + if not issubclass(model, model_from_registry): + is_parent = issubclass(model_from_registry, model) + if is_direct_subclass(model, AsyncStructuredRel) and not is_parent: + raise RelationshipClassRedefined( + relation_type, adb._NODE_CLASS_REGISTRY, model + ) + else: + adb._NODE_CLASS_REGISTRY[label_set] = model + except KeyError: + # If the mapping does not exist then it is simply created. + adb._NODE_CLASS_REGISTRY[label_set] = model + + def _validate_class(self, cls_name, model): + if not isinstance(cls_name, (basestring, object)): + raise ValueError("Expected class name or class got " + repr(cls_name)) + + if model and not issubclass(model, (AsyncStructuredRel,)): + raise ValueError("model must be a StructuredRel") + + def lookup_node_class(self): + if not isinstance(self._raw_class, basestring): + self.definition["node_class"] = self._raw_class + else: + name = self._raw_class + if name.find(".") == -1: + module = self.module_name + else: + module, _, name = name.rpartition(".") + + if module not in sys.modules: + # yet another hack to get around python semantics + # __name__ is the namespace of the parent module for __init__.py files, + # and the namespace of the current module for other .py files, + # therefore there's a need to define the namespace differently for + # these two cases in order for . in relative imports to work correctly + # (i.e. to mean the same thing for both cases). + # For example in the comments below, namespace == myapp, always + if not hasattr(self, "module_file"): + raise ImportError(f"Couldn't lookup '{name}'") + + if "__init__.py" in self.module_file: + # e.g. myapp/__init__.py -[__name__]-> myapp + namespace = self.module_name + else: + # e.g. myapp/models.py -[__name__]-> myapp.models + namespace = self.module_name.rpartition(".")[0] + + # load a module from a namespace (e.g. models from myapp) + if module: + module = import_module(module, namespace).__name__ + # load the namespace itself (e.g. myapp) + # (otherwise it would look like import . from myapp) + else: + module = import_module(namespace).__name__ + self.definition["node_class"] = getattr(sys.modules[module], name) + + def build_manager(self, source, name): + self.lookup_node_class() + return self.manager(source, name, self.definition) + + +class AsyncZeroOrMore(AsyncRelationshipManager): + """ + A relationship of zero or more nodes (the default) + """ + + description = "zero or more relationships" + + +class RelationshipTo(RelationshipDefinition): + def __init__( + self, + cls_name, + relation_type, + cardinality=AsyncZeroOrMore, + model=None, + ): + super().__init__( + relation_type, cls_name, OUTGOING, manager=cardinality, model=model + ) + + +class RelationshipFrom(RelationshipDefinition): + def __init__( + self, + cls_name, + relation_type, + cardinality=AsyncZeroOrMore, + model=None, + ): + super().__init__( + relation_type, cls_name, INCOMING, manager=cardinality, model=model + ) + + +class Relationship(RelationshipDefinition): + def __init__( + self, + cls_name, + relation_type, + cardinality=AsyncZeroOrMore, + model=None, + ): + super().__init__( + relation_type, cls_name, EITHER, manager=cardinality, model=model + ) diff --git a/neomodel/cardinality.py b/neomodel/_sync/cardinality.py similarity index 97% rename from neomodel/cardinality.py rename to neomodel/_sync/cardinality.py index 099bf578..89fe0b30 100644 --- a/neomodel/cardinality.py +++ b/neomodel/_sync/cardinality.py @@ -1,8 +1,8 @@ -from neomodel.exceptions import AttemptedCardinalityViolation, CardinalityViolation -from neomodel.relationship_manager import ( # pylint:disable=unused-import +from neomodel._sync.relationship_manager import ( # pylint:disable=unused-import RelationshipManager, ZeroOrMore, ) +from neomodel.exceptions import AttemptedCardinalityViolation, CardinalityViolation class ZeroOrOne(RelationshipManager): diff --git a/neomodel/_sync/core.py b/neomodel/_sync/core.py index a46867f6..443c1a8f 100644 --- a/neomodel/_sync/core.py +++ b/neomodel/_sync/core.py @@ -253,9 +253,7 @@ def begin(self, access_mode=None, **parameters): impersonated_user=self.impersonated_user, **parameters, ) - self._active_transaction: Transaction = ( - self._session.begin_transaction() - ) + self._active_transaction: Transaction = self._session.begin_transaction() @ensure_connection def commit(self): @@ -344,9 +342,9 @@ def _object_resolution(self, object_to_resolve): return self._NODE_CLASS_REGISTRY[rel_type].inflate(object_to_resolve) if isinstance(object_to_resolve, Path): - from neomodel.path import NeomodelPath + from neomodel._async.path import AsyncNeomodelPath - return NeomodelPath(object_to_resolve) + return AsyncNeomodelPath(object_to_resolve) if isinstance(object_to_resolve, list): return self._result_resolution([object_to_resolve]) @@ -835,9 +833,7 @@ def change_neo4j_password(db: Database, user, new_password): db.change_neo4j_password(user, new_password) -def clear_neo4j_database( - db: Database, clear_constraints=False, clear_indexes=False -): +def clear_neo4j_database(db: Database, clear_constraints=False, clear_indexes=False): deprecated( """ This method has been moved to the Database singleton (db for sync, db for async). @@ -1115,9 +1111,9 @@ def nodes(cls): :return: NodeSet :rtype: NodeSet """ - from neomodel.match import NodeSet + from neomodel._async.match import AsyncNodeSet - return NodeSet(cls) + return AsyncNodeSet(cls) @property def element_id(self): @@ -1178,7 +1174,7 @@ def _build_merge_query( "No relation_type is specified on provided relationship" ) - from neomodel.match import _rel_helper + from neomodel._async.match import _rel_helper query_params["source_id"] = relationship.source.element_id query = f"MATCH (source:{relationship.source.__label__}) WHERE {db.get_id_method()}(source) = $source_id\n " diff --git a/neomodel/match.py b/neomodel/_sync/match.py similarity index 95% rename from neomodel/match.py rename to neomodel/_sync/match.py index 65ad99b9..3926f289 100644 --- a/neomodel/match.py +++ b/neomodel/_sync/match.py @@ -4,12 +4,11 @@ from dataclasses import dataclass from typing import Optional -from neomodel._async.core import AsyncStructuredNode, adb +from neomodel._sync.core import StructuredNode, db from neomodel.exceptions import MultipleNodesReturned from neomodel.match_q import Q, QBase from neomodel.properties import AliasProperty - -OUTGOING, INCOMING, EITHER = 1, -1, 0 +from neomodel.util import INCOMING, OUTGOING def _rel_helper( @@ -382,7 +381,7 @@ def build_source(self, source): return self.build_traversal(source) if isinstance(source, NodeSet): if inspect.isclass(source.source) and issubclass( - source.source, AsyncStructuredNode + source.source, StructuredNode ): ident = self.build_label(source.source.__label__.lower(), source.source) else: @@ -402,7 +401,7 @@ def build_source(self, source): ) return ident - if isinstance(source, AsyncStructuredNode): + if isinstance(source, StructuredNode): return self.build_node(source) raise ValueError("Unknown source type " + repr(source)) @@ -502,7 +501,7 @@ def build_node(self, node): place_holder = self._register_place_holder(ident) # Hack to emulate START to lookup a node by id - _node_lookup = f"MATCH ({ident}) WHERE {adb.get_id_method()}({ident})=${place_holder} WITH {ident}" + _node_lookup = f"MATCH ({ident}) WHERE {db.get_id_method()}({ident})=${place_holder} WITH {ident}" self._ast.lookup = _node_lookup self._query_params[place_holder] = node.element_id @@ -679,7 +678,7 @@ def _count(self): # drop additional_return to avoid unexpected result self._ast.additional_return = None query = self.build_query() - results, _ = adb.cypher_query(query, self._query_params) + results, _ = db.cypher_query(query, self._query_params) return int(results[0][0]) def _contains(self, node_element_id): @@ -689,7 +688,7 @@ def _contains(self, node_element_id): self._ast.return_clause = self._ast.additional_return[0] ident = self._ast.return_clause place_holder = self._register_place_holder(ident + "_contains") - self._ast.where.append(f"{adb.get_id_method()}({ident}) = ${place_holder}") + self._ast.where.append(f"{db.get_id_method()}({ident}) = ${place_holder}") self._query_params[place_holder] = node_element_id return self._count() >= 1 @@ -698,15 +697,17 @@ def _execute(self, lazy=False): # inject id() into return or return_set if self._ast.return_clause: self._ast.return_clause = ( - f"{adb.get_id_method()}({self._ast.return_clause})" + f"{db.get_id_method()}({self._ast.return_clause})" ) else: self._ast.additional_return = [ - f"{adb.get_id_method()}({item})" + f"{db.get_id_method()}({item})" for item in self._ast.additional_return ] query = self.build_query() - results, _ = adb.cypher_query(query, self._query_params, resolve_objects=True) + results, _ = db.cypher_query( + query, self._query_params, resolve_objects=True + ) # The following is not as elegant as it could be but had to be copied from the # version prior to cypher_query with the resolve_objects capability. # It seems that certain calls are only supposed to be focusing to the first @@ -725,29 +726,21 @@ class BaseSet: query_cls = QueryBuilder - def all(self, lazy=False): - """ - Return all nodes belonging to the set - :param lazy: False by default, specify True to get nodes with id only without the parameters. - :return: list of nodes - :rtype: list - """ - return self.query_cls(self).build_ast()._execute(lazy) - def __iter__(self): - return (i for i in self.query_cls(self).build_ast()._execute()) + for i in self.query_cls(self).build_ast()._execute(): + yield i def __len__(self): return self.query_cls(self).build_ast()._count() - def __bool__(self): - return self.query_cls(self).build_ast()._count() > 0 + def __abool__(self): + return bool(self.query_cls(self).build_ast()._count() > 0) def __nonzero__(self): - return self.query_cls(self).build_ast()._count() > 0 + return bool(self.query_cls(self).build_ast()._count() > 0) def __contains__(self, obj): - if isinstance(obj, AsyncStructuredNode): + if isinstance(obj, StructuredNode): if hasattr(obj, "element_id") and obj.element_id is not None: return self.query_cls(self).build_ast()._contains(obj.element_id) raise ValueError("Unsaved node: " + repr(obj)) @@ -791,9 +784,9 @@ def __init__(self, source): self.source = source # could be a Traverse object or a node class if isinstance(source, Traversal): self.source_class = source.target_class - elif inspect.isclass(source) and issubclass(source, AsyncStructuredNode): + elif inspect.isclass(source) and issubclass(source, StructuredNode): self.source_class = source - elif isinstance(source, AsyncStructuredNode): + elif isinstance(source, StructuredNode): self.source_class = source.__class__ else: raise ValueError("Bad source for nodeset " + repr(source)) @@ -849,7 +842,7 @@ def first(self, **kwargs): :param kwargs: same syntax as `filter()` :return: node """ - result = result = self._get(limit=1, **kwargs) + result = self._get(limit=1, **kwargs) if result: return result[0] else: @@ -995,9 +988,9 @@ def __init__(self, source, name, definition): if isinstance(source, Traversal): self.source_class = source.target_class - elif inspect.isclass(source) and issubclass(source, AsyncStructuredNode): + elif inspect.isclass(source) and issubclass(source, StructuredNode): self.source_class = source - elif isinstance(source, AsyncStructuredNode): + elif isinstance(source, StructuredNode): self.source_class = source.__class__ elif isinstance(source, NodeSet): self.source_class = source.source_class diff --git a/neomodel/_sync/path.py b/neomodel/_sync/path.py new file mode 100644 index 00000000..6848e903 --- /dev/null +++ b/neomodel/_sync/path.py @@ -0,0 +1,53 @@ +from neo4j.graph import Path + +from neomodel._sync.core import db +from neomodel._sync.relationship import StructuredRel + + +class NeomodelPath(Path): + """ + Represents paths within neomodel. + + This object is instantiated when you include whole paths in your ``cypher_query()`` + result sets and turn ``resolve_objects`` to True. + + That is, any query of the form: + :: + + MATCH p=(:SOME_NODE_LABELS)-[:SOME_REL_LABELS]-(:SOME_OTHER_NODE_LABELS) return p + + ``NeomodelPath`` are simple objects that reference their nodes and relationships, each of which is already + resolved to their neomodel objects if such mapping is possible. + + + :param nodes: Neomodel nodes appearing in the path in order of appearance. + :param relationships: Neomodel relationships appearing in the path in order of appearance. + :type nodes: List[StructuredNode] + :type relationships: List[StructuredRel] + """ + + def __init__(self, a_neopath): + self._nodes = [] + self._relationships = [] + + for a_node in a_neopath.nodes: + self._nodes.append(db._object_resolution(a_node)) + + for a_relationship in a_neopath.relationships: + # This check is required here because if the relationship does not bear data + # then it does not have an entry in the registry. In that case, we instantiate + # an "unspecified" StructuredRel. + rel_type = frozenset([a_relationship.type]) + if rel_type in db._NODE_CLASS_REGISTRY: + new_rel = db._object_resolution(a_relationship) + else: + new_rel = StructuredRel.inflate(a_relationship) + self._relationships.append(new_rel) + + @property + def nodes(self): + return self._nodes + + @property + def relationships(self): + return self._relationships diff --git a/neomodel/_sync/relationship.py b/neomodel/_sync/relationship.py new file mode 100644 index 00000000..63096f5c --- /dev/null +++ b/neomodel/_sync/relationship.py @@ -0,0 +1,171 @@ +from neomodel._sync.core import db +from neomodel.hooks import hooks +from neomodel.properties import Property, PropertyManager + + +class RelationshipMeta(type): + def __new__(mcs, name, bases, dct): + inst = super().__new__(mcs, name, bases, dct) + for key, value in dct.items(): + if issubclass(value.__class__, Property): + if key == "source" or key == "target": + raise ValueError( + "Property names 'source' and 'target' are not allowed as they conflict with neomodel internals." + ) + elif key == "id": + raise ValueError( + """ + Property name 'id' is not allowed as it conflicts with neomodel internals. + Consider using 'uid' or 'identifier' as id is also a Neo4j internal. + """ + ) + elif key == "element_id": + raise ValueError( + """ + Property name 'element_id' is not allowed as it conflicts with neomodel internals. + Consider using 'uid' or 'identifier' as element_id is also a Neo4j internal. + """ + ) + value.name = key + value.owner = inst + + # support for 'magic' properties + if hasattr(value, "setup") and hasattr(value.setup, "__call__"): + value.setup() + return inst + + +StructuredRelBase = RelationshipMeta("RelationshipBase", (PropertyManager,), {}) + + +class StructuredRel(StructuredRelBase): + """ + Base class for relationship objects + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @property + def element_id(self): + return ( + int(self.element_id_property) + if db.database_version.startswith("4") + else self.element_id_property + ) + + @property + def _start_node_element_id(self): + return ( + int(self._start_node_element_id_property) + if db.database_version.startswith("4") + else self._start_node_element_id_property + ) + + @property + def _end_node_element_id(self): + return ( + int(self._end_node_element_id_property) + if db.database_version.startswith("4") + else self._end_node_element_id_property + ) + + # Version 4.4 support - id is deprecated in version 5.x + @property + def id(self): + try: + return int(self.element_id_property) + except (TypeError, ValueError): + raise ValueError( + "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()." + ) + + # Version 4.4 support - id is deprecated in version 5.x + @property + def _start_node_id(self): + try: + return int(self._start_node_element_id_property) + except (TypeError, ValueError): + raise ValueError( + "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()." + ) + + # Version 4.4 support - id is deprecated in version 5.x + @property + def _end_node_id(self): + try: + return int(self._end_node_element_id_property) + except (TypeError, ValueError): + raise ValueError( + "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()." + ) + + @hooks + def save(self): + """ + Save the relationship + + :return: self + """ + props = self.deflate(self.__properties__) + query = f"MATCH ()-[r]->() WHERE {db.get_id_method()}(r)=$self " + query += "".join([f" SET r.{key} = ${key}" for key in props]) + props["self"] = self.element_id + + db.cypher_query(query, props) + + return self + + def start_node(self): + """ + Get start node + + :return: StructuredNode + """ + test = db.cypher_query( + f""" + MATCH (aNode) + WHERE {db.get_id_method()}(aNode)=$start_node_element_id + RETURN aNode + """, + {"start_node_element_id": self._start_node_element_id}, + resolve_objects=True, + ) + return test[0][0][0] + + def end_node(self): + """ + Get end node + + :return: StructuredNode + """ + return db.cypher_query( + f""" + MATCH (aNode) + WHERE {db.get_id_method()}(aNode)=$end_node_element_id + RETURN aNode + """, + {"end_node_element_id": self._end_node_element_id}, + resolve_objects=True, + )[0][0][0] + + @classmethod + def inflate(cls, rel): + """ + Inflate a neo4j_driver relationship object to a neomodel object + :param rel: + :return: StructuredRel + """ + props = {} + for key, prop in cls.defined_properties(aliases=False, rels=False).items(): + if key in rel: + props[key] = prop.inflate(rel[key], obj=rel) + elif prop.has_default: + props[key] = prop.default_value() + else: + props[key] = None + srel = cls(**props) + srel._start_node_element_id_property = rel.start_node.element_id + srel._end_node_element_id_property = rel.end_node.element_id + srel.element_id_property = rel.element_id + return srel diff --git a/neomodel/relationship_manager.py b/neomodel/_sync/relationship_manager.py similarity index 92% rename from neomodel/relationship_manager.py rename to neomodel/_sync/relationship_manager.py index 103559a1..fb23b315 100644 --- a/neomodel/relationship_manager.py +++ b/neomodel/_sync/relationship_manager.py @@ -3,19 +3,17 @@ import sys from importlib import import_module -from neomodel._async.core import adb +from neomodel._sync.core import db +from neomodel._sync.match import NodeSet, Traversal, _rel_helper, _rel_merge_helper +from neomodel._sync.relationship import StructuredRel from neomodel.exceptions import NotConnected, RelationshipClassRedefined -from neomodel.match import ( +from neomodel.util import ( EITHER, INCOMING, OUTGOING, - NodeSet, - Traversal, - _rel_helper, - _rel_merge_helper, + _get_node_properties, + enumerate_traceback, ) -from neomodel.relationship import StructuredRel -from neomodel.util import _get_node_properties, enumerate_traceback # basestring python 3.x fallback try: @@ -120,7 +118,7 @@ def connect(self, node, properties=None): **self.definition, ) q = ( - f"MATCH (them), (us) WHERE {adb.get_id_method()}(them)=$them and {adb.get_id_method()}(us)=$self " + f"MATCH (them), (us) WHERE {db.get_id_method()}(them)=$them and {db.get_id_method()}(us)=$self " "MERGE" + new_rel ) @@ -164,7 +162,7 @@ def relationship(self, node): q = ( "MATCH " + my_rel - + f" WHERE {adb.get_id_method()}(them)=$them and {adb.get_id_method()}(us)=$self RETURN r LIMIT 1" + + f" WHERE {db.get_id_method()}(them)=$them and {db.get_id_method()}(us)=$self RETURN r LIMIT 1" ) rels = self.source.cypher(q, {"them": node.element_id})[0] if not rels: @@ -185,7 +183,7 @@ def all_relationships(self, node): self._check_node(node) my_rel = _rel_helper(lhs="us", rhs="them", ident="r", **self.definition) - q = f"MATCH {my_rel} WHERE {adb.get_id_method()}(them)=$them and {adb.get_id_method()}(us)=$self RETURN r " + q = f"MATCH {my_rel} WHERE {db.get_id_method()}(them)=$them and {db.get_id_method()}(us)=$self RETURN r " rels = self.source.cypher(q, {"them": node.element_id})[0] if not rels: return [] @@ -225,7 +223,7 @@ def reconnect(self, old_node, new_node): # get list of properties on the existing rel result, _ = self.source.cypher( f""" - MATCH (us), (old) WHERE {adb.get_id_method()}(us)=$self and {adb.get_id_method()}(old)=$old + MATCH (us), (old) WHERE {db.get_id_method()}(us)=$self and {db.get_id_method()}(old)=$old MATCH {old_rel} RETURN r """, {"old": old_node.element_id}, @@ -240,7 +238,7 @@ def reconnect(self, old_node, new_node): new_rel = _rel_merge_helper(lhs="us", rhs="new", ident="r2", **self.definition) q = ( "MATCH (us), (old), (new) " - f"WHERE {adb.get_id_method()}(us)=$self and {adb.get_id_method()}(old)=$old and {adb.get_id_method()}(new)=$new " + f"WHERE {db.get_id_method()}(us)=$self and {db.get_id_method()}(old)=$old and {db.get_id_method()}(new)=$new " "MATCH " + old_rel ) q += " MERGE" + new_rel @@ -249,7 +247,9 @@ def reconnect(self, old_node, new_node): q += "".join([f" SET r2.{prop} = r.{prop}" for prop in existing_properties]) q += " WITH r DELETE r" - self.source.cypher(q, {"old": old_node.element_id, "new": new_node.element_id}) + self.source.cypher( + q, {"old": old_node.element_id, "new": new_node.element_id} + ) @check_source def disconnect(self, node): @@ -261,7 +261,7 @@ def disconnect(self, node): """ rel = _rel_helper(lhs="a", rhs="b", ident="r", **self.definition) q = f""" - MATCH (a), (b) WHERE {adb.get_id_method()}(a)=$self and {adb.get_id_method()}(b)=$them + MATCH (a), (b) WHERE {db.get_id_method()}(a)=$self and {db.get_id_method()}(b)=$them MATCH {rel} DELETE r """ self.source.cypher(q, {"them": node.element_id}) @@ -275,7 +275,7 @@ def disconnect_all(self): """ rhs = "b:" + self.definition["node_class"].__label__ rel = _rel_helper(lhs="a", rhs=rhs, ident="r", **self.definition) - q = f"MATCH (a) WHERE {adb.get_id_method()}(a)=$self MATCH " + rel + " DELETE r" + q = f"MATCH (a) WHERE {db.get_id_method()}(a)=$self MATCH " + rel + " DELETE r" self.source.cypher(q) @check_source @@ -428,18 +428,18 @@ def __init__( # In this case, it has to be ensured that the class # that is overriding the relationship is a descendant # of the already existing class. - model_from_registry = adb._NODE_CLASS_REGISTRY[label_set] + model_from_registry = db._NODE_CLASS_REGISTRY[label_set] if not issubclass(model, model_from_registry): is_parent = issubclass(model_from_registry, model) if is_direct_subclass(model, StructuredRel) and not is_parent: raise RelationshipClassRedefined( - relation_type, adb._NODE_CLASS_REGISTRY, model + relation_type, db._NODE_CLASS_REGISTRY, model ) else: - adb._NODE_CLASS_REGISTRY[label_set] = model + db._NODE_CLASS_REGISTRY[label_set] = model except KeyError: # If the mapping does not exist then it is simply created. - adb._NODE_CLASS_REGISTRY[label_set] = model + db._NODE_CLASS_REGISTRY[label_set] = model def _validate_class(self, cls_name, model): if not isinstance(cls_name, (basestring, object)): diff --git a/neomodel/properties.py b/neomodel/properties.py index 737bbfac..3c88d299 100644 --- a/neomodel/properties.py +++ b/neomodel/properties.py @@ -67,7 +67,7 @@ def __init__(self, **kwargs): @property def __properties__(self): - from neomodel.relationship_manager import RelationshipManager + from neomodel._async.relationship_manager import AsyncRelationshipManager return dict( (name, value) @@ -77,7 +77,7 @@ def __properties__(self): and not isinstance( value, ( - RelationshipManager, + AsyncRelationshipManager, AliasProperty, ), ) @@ -101,7 +101,7 @@ def deflate(cls, properties, obj=None, skip_empty=False): @classmethod def defined_properties(cls, aliases=True, properties=True, rels=True): - from neomodel.relationship_manager import RelationshipDefinition + from neomodel._async.relationship_manager import RelationshipDefinition props = {} for baseclass in reversed(cls.__mro__): diff --git a/neomodel/util.py b/neomodel/util.py index 28435eb2..1b88e407 100644 --- a/neomodel/util.py +++ b/neomodel/util.py @@ -1,5 +1,7 @@ import warnings +OUTGOING, INCOMING, EITHER = 1, -1, 0 + def deprecated(message): # pylint:disable=invalid-name diff --git a/test/test_cardinality.py b/test/test_cardinality.py index 8a83c3ee..60ce9023 100644 --- a/test/test_cardinality.py +++ b/test/test_cardinality.py @@ -1,16 +1,16 @@ from pytest import raises from neomodel import ( + AsyncOne, + AsyncOneOrMore, AsyncStructuredNode, + AsyncZeroOrOne, AttemptedCardinalityViolation, CardinalityViolation, IntegerProperty, - One, - OneOrMore, RelationshipTo, StringProperty, ZeroOrMore, - ZeroOrOne, adb, ) @@ -30,9 +30,11 @@ class Car(AsyncStructuredNode): class Monkey(AsyncStructuredNode): name = StringProperty() dryers = RelationshipTo("HairDryer", "OWNS_DRYER", cardinality=ZeroOrMore) - driver = RelationshipTo("ScrewDriver", "HAS_SCREWDRIVER", cardinality=ZeroOrOne) - car = RelationshipTo("Car", "HAS_CAR", cardinality=OneOrMore) - toothbrush = RelationshipTo("ToothBrush", "HAS_TOOTHBRUSH", cardinality=One) + driver = RelationshipTo( + "ScrewDriver", "HAS_SCREWDRIVER", cardinality=AsyncZeroOrOne + ) + car = RelationshipTo("Car", "HAS_CAR", cardinality=AsyncOneOrMore) + toothbrush = RelationshipTo("ToothBrush", "HAS_TOOTHBRUSH", cardinality=AsyncOne) class ToothBrush(AsyncStructuredNode): diff --git a/test/test_database_management.py b/test/test_database_management.py index 1a277d16..545e1dbf 100644 --- a/test/test_database_management.py +++ b/test/test_database_management.py @@ -3,10 +3,10 @@ from neomodel import ( AsyncStructuredNode, + AsyncStructuredRel, IntegerProperty, RelationshipTo, StringProperty, - StructuredRel, ) from neomodel._async.core import adb @@ -15,7 +15,7 @@ class City(AsyncStructuredNode): name = StringProperty() -class InCity(StructuredRel): +class InCity(AsyncStructuredRel): creation_year = IntegerProperty(index=True) diff --git a/test/test_issue283.py b/test/test_issue283.py index ebdeb97d..0efbbc48 100644 --- a/test/test_issue283.py +++ b/test/test_issue283.py @@ -25,7 +25,7 @@ # Set up a very simple model for the tests -class PersonalRelationship(neomodel.StructuredRel): +class PersonalRelationship(neomodel.AsyncStructuredRel): """ A very simple relationship between two basePersons that simply records the date at which an acquaintance was established. @@ -403,7 +403,7 @@ def test_improperly_inherited_relationship(): :return: """ - class NewRelationship(neomodel.StructuredRel): + class NewRelationship(neomodel.AsyncStructuredRel): profile_match_factor = neomodel.FloatProperty() with pytest.raises( diff --git a/test/test_issue600.py b/test/test_issue600.py index d26240f9..a85e5f01 100644 --- a/test/test_issue600.py +++ b/test/test_issue600.py @@ -18,7 +18,7 @@ basestring = str -class Class1(neomodel.StructuredRel): +class Class1(neomodel.AsyncStructuredRel): pass diff --git a/test/test_label_install.py b/test/test_label_install.py index 256ed1bd..7e367af3 100644 --- a/test/test_label_install.py +++ b/test/test_label_install.py @@ -2,9 +2,9 @@ from neomodel import ( AsyncStructuredNode, + AsyncStructuredRel, RelationshipTo, StringProperty, - StructuredRel, UniqueIdProperty, config, ) @@ -26,7 +26,7 @@ class NodeWithRelationship(AsyncStructuredNode): ... -class IndexedRelationship(StructuredRel): +class IndexedRelationship(AsyncStructuredRel): indexed_rel_prop = StringProperty(index=True) @@ -101,7 +101,7 @@ def test_install_label_twice(capsys): if adb.version_is_higher_than("5.7"): - class UniqueIndexRelationship(StructuredRel): + class UniqueIndexRelationship(AsyncStructuredRel): unique_index_rel_prop = StringProperty(unique_index=True) class OtherNodeWithUniqueIndexRelationship(AsyncStructuredNode): @@ -133,7 +133,7 @@ def test_install_labels_db_property(capsys): adb.version_is_higher_than("5.7"), reason="Not supported before 5.7" ) def test_relationship_unique_index_not_supported(): - class UniqueIndexRelationship(StructuredRel): + class UniqueIndexRelationship(AsyncStructuredRel): name = StringProperty(unique_index=True) class TargetNodeForUniqueIndexRelationship(AsyncStructuredNode): @@ -153,7 +153,7 @@ class NodeWithUniqueIndexRelationship(AsyncStructuredNode): @pytest.mark.skipif(not adb.version_is_higher_than("5.7"), reason="Supported from 5.7") def test_relationship_unique_index(): - class UniqueIndexRelationshipBis(StructuredRel): + class UniqueIndexRelationshipBis(AsyncStructuredRel): name = StringProperty(unique_index=True) class TargetNodeForUniqueIndexRelationship(AsyncStructuredNode): diff --git a/test/test_match_api.py b/test/test_match_api.py index 50828523..f1421df8 100644 --- a/test/test_match_api.py +++ b/test/test_match_api.py @@ -5,19 +5,24 @@ from neomodel import ( INCOMING, AsyncStructuredNode, + AsyncStructuredRel, DateTimeProperty, IntegerProperty, Q, RelationshipFrom, RelationshipTo, StringProperty, - StructuredRel, +) +from neomodel._async.match import ( + AsyncNodeSet, + AsyncQueryBuilder, + AsyncTraversal, + Optional, ) from neomodel.exceptions import MultipleNodesReturned -from neomodel.match import NodeSet, Optional, QueryBuilder, Traversal -class SupplierRel(StructuredRel): +class SupplierRel(AsyncStructuredRel): since = DateTimeProperty(default=datetime.now) courier = StringProperty() @@ -30,14 +35,14 @@ class Supplier(AsyncStructuredNode): class Species(AsyncStructuredNode): name = StringProperty() - coffees = RelationshipFrom("Coffee", "COFFEE SPECIES", model=StructuredRel) + coffees = RelationshipFrom("Coffee", "COFFEE SPECIES", model=AsyncStructuredRel) class Coffee(AsyncStructuredNode): name = StringProperty(unique_index=True) price = IntegerProperty() suppliers = RelationshipFrom(Supplier, "COFFEE SUPPLIERS", model=SupplierRel) - species = RelationshipTo(Species, "COFFEE SPECIES", model=StructuredRel) + species = RelationshipTo(Species, "COFFEE SPECIES", model=AsyncStructuredRel) id_ = IntegerProperty() @@ -48,8 +53,8 @@ class Extension(AsyncStructuredNode): def test_filter_exclude_via_labels(): Coffee(name="Java", price=99).save() - node_set = NodeSet(Coffee) - qb = QueryBuilder(node_set).build_ast() + node_set = AsyncNodeSet(Coffee) + qb = AsyncQueryBuilder(node_set).build_ast() results = qb._execute() @@ -62,7 +67,7 @@ def test_filter_exclude_via_labels(): # with filter and exclude Coffee(name="Kenco", price=3).save() node_set = node_set.filter(price__gt=2).exclude(price__gt=6, name="Java") - qb = QueryBuilder(node_set).build_ast() + qb = AsyncQueryBuilder(node_set).build_ast() results = qb._execute() assert "(coffee:Coffee)" in qb._ast.match @@ -76,16 +81,16 @@ def test_simple_has_via_label(): tesco = Supplier(name="Tesco", delivery_cost=2).save() nescafe.suppliers.connect(tesco) - ns = NodeSet(Coffee).has(suppliers=True) - qb = QueryBuilder(ns).build_ast() + ns = AsyncNodeSet(Coffee).has(suppliers=True) + qb = AsyncQueryBuilder(ns).build_ast() results = qb._execute() assert "COFFEE SUPPLIERS" in qb._ast.where[0] assert len(results) == 1 assert results[0].name == "Nescafe" Coffee(name="nespresso", price=99).save() - ns = NodeSet(Coffee).has(suppliers=False) - qb = QueryBuilder(ns).build_ast() + ns = AsyncNodeSet(Coffee).has(suppliers=False) + qb = AsyncQueryBuilder(ns).build_ast() results = qb._execute() assert len(results) > 0 assert "NOT" in qb._ast.where[0] @@ -109,7 +114,9 @@ def test_simple_traverse_with_filter(): tesco = Supplier(name="Sainsburys", delivery_cost=2).save() nescafe.suppliers.connect(tesco) - qb = QueryBuilder(NodeSet(source=nescafe).suppliers.match(since__lt=datetime.now())) + qb = AsyncQueryBuilder( + AsyncNodeSet(source=nescafe).suppliers.match(since__lt=datetime.now()) + ) results = qb.build_ast()._execute() @@ -126,8 +133,8 @@ def test_double_traverse(): nescafe.suppliers.connect(tesco) tesco.coffees.connect(Coffee(name="Decafe", price=2).save()) - ns = NodeSet(NodeSet(source=nescafe).suppliers.match()).coffees.match() - qb = QueryBuilder(ns).build_ast() + ns = AsyncNodeSet(AsyncNodeSet(source=nescafe).suppliers.match()).coffees.match() + qb = AsyncQueryBuilder(ns).build_ast() results = qb._execute() assert len(results) == 2 @@ -137,14 +144,14 @@ def test_double_traverse(): def test_count(): Coffee(name="Nescafe Gold", price=99).save() - count = QueryBuilder(NodeSet(source=Coffee)).build_ast()._count() + count = AsyncQueryBuilder(AsyncNodeSet(source=Coffee)).build_ast()._count() assert count > 0 Coffee(name="Kawa", price=27).save() - node_set = NodeSet(source=Coffee) + node_set = AsyncNodeSet(source=Coffee) node_set.skip = 1 node_set.limit = 1 - count = QueryBuilder(node_set).build_ast()._count() + count = AsyncQueryBuilder(node_set).build_ast()._count() assert count == 1 @@ -226,13 +233,13 @@ def test_order_by(): assert Coffee.nodes.order_by("-price").all()[0].price == 35 ns = Coffee.nodes.order_by("-price") - qb = QueryBuilder(ns).build_ast() + qb = AsyncQueryBuilder(ns).build_ast() assert qb._ast.order_by ns = ns.order_by(None) - qb = QueryBuilder(ns).build_ast() + qb = AsyncQueryBuilder(ns).build_ast() assert not qb._ast.order_by ns = ns.order_by("?") - qb = QueryBuilder(ns).build_ast() + qb = AsyncQueryBuilder(ns).build_ast() assert qb._ast.with_clause == "coffee, rand() as r" assert qb._ast.order_by == "r" @@ -294,7 +301,7 @@ def test_traversal_definition_keys_are_valid(): muckefuck = Coffee(name="Mukkefuck", price=1) with raises(ValueError): - Traversal( + AsyncTraversal( muckefuck, "a_name", { @@ -305,7 +312,7 @@ def test_traversal_definition_keys_are_valid(): }, ) - Traversal( + AsyncTraversal( muckefuck, "a_name", { @@ -456,7 +463,9 @@ def test_traversal_filter_left_hand_statement(): nescafe_gold.suppliers.connect(lidl) lidl_supplier = ( - NodeSet(Coffee.nodes.filter(price=11).suppliers).filter(delivery_cost=3).all() + AsyncNodeSet(Coffee.nodes.filter(price=11).suppliers) + .filter(delivery_cost=3) + .all() ) assert lidl in lidl_supplier diff --git a/test/test_migration_neo4j_5.py b/test/test_migration_neo4j_5.py index a5efe800..ff869545 100644 --- a/test/test_migration_neo4j_5.py +++ b/test/test_migration_neo4j_5.py @@ -2,10 +2,10 @@ from neomodel import ( AsyncStructuredNode, + AsyncStructuredRel, IntegerProperty, RelationshipTo, StringProperty, - StructuredRel, ) from neomodel._async.core import adb @@ -14,7 +14,7 @@ class Album(AsyncStructuredNode): name = StringProperty() -class Released(StructuredRel): +class Released(AsyncStructuredRel): year = IntegerProperty() diff --git a/test/test_models.py b/test/test_models.py index fb773055..20e2bccc 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -6,10 +6,10 @@ from neomodel import ( AsyncStructuredNode, + AsyncStructuredRel, DateProperty, IntegerProperty, StringProperty, - StructuredRel, ) from neomodel._async.core import adb from neomodel.exceptions import RequiredProperty, UniqueProperty @@ -325,21 +325,21 @@ class ReservedPropertiesElementIdNode(AsyncStructuredNode): with raises(ValueError, match=error_match): - class ReservedPropertiesIdRel(StructuredRel): + class ReservedPropertiesIdRel(AsyncStructuredRel): id = StringProperty() with raises(ValueError, match=error_match): - class ReservedPropertiesElementIdRel(StructuredRel): + class ReservedPropertiesElementIdRel(AsyncStructuredRel): element_id = StringProperty() error_match = r"Property names 'source' and 'target' are not allowed as they conflict with neomodel internals." with raises(ValueError, match=error_match): - class ReservedPropertiesSourceRel(StructuredRel): + class ReservedPropertiesSourceRel(AsyncStructuredRel): source = StringProperty() with raises(ValueError, match=error_match): - class ReservedPropertiesTargetRel(StructuredRel): + class ReservedPropertiesTargetRel(AsyncStructuredRel): target = StringProperty() diff --git a/test/test_paths.py b/test/test_paths.py index f6f8bbbc..ba703b7f 100644 --- a/test/test_paths.py +++ b/test/test_paths.py @@ -1,16 +1,16 @@ from neomodel import ( + AsyncNeomodelPath, AsyncStructuredNode, + AsyncStructuredRel, IntegerProperty, - NeomodelPath, RelationshipTo, StringProperty, - StructuredRel, UniqueIdProperty, adb, ) -class PersonLivesInCity(StructuredRel): +class PersonLivesInCity(AsyncStructuredRel): """ Relationship with data that will be instantiated as "stand-alone" """ @@ -75,13 +75,13 @@ def test_path_instantiation(): path_nodes = path_object.nodes path_rels = path_object.relationships - assert type(path_object) is NeomodelPath + assert type(path_object) is AsyncNeomodelPath assert type(path_nodes[0]) is CityOfResidence assert type(path_nodes[1]) is PersonOfInterest assert type(path_nodes[2]) is CountryOfOrigin assert type(path_rels[0]) is PersonLivesInCity - assert type(path_rels[1]) is StructuredRel + assert type(path_rels[1]) is AsyncStructuredRel c1.delete() c2.delete() diff --git a/test/test_relationship_models.py b/test/test_relationship_models.py index 2e07b684..cdb27995 100644 --- a/test/test_relationship_models.py +++ b/test/test_relationship_models.py @@ -5,18 +5,18 @@ from neomodel import ( AsyncStructuredNode, + AsyncStructuredRel, DateTimeProperty, DeflateError, Relationship, RelationshipTo, StringProperty, - StructuredRel, ) HOOKS_CALLED = {"pre_save": 0, "post_save": 0} -class FriendRel(StructuredRel): +class FriendRel(AsyncStructuredRel): since = DateTimeProperty(default=lambda: datetime.now(pytz.utc)) diff --git a/test/test_relationships.py b/test/test_relationships.py index 4c047eaf..6b81d137 100644 --- a/test/test_relationships.py +++ b/test/test_relationships.py @@ -1,15 +1,15 @@ from pytest import raises from neomodel import ( + AsyncOne, AsyncStructuredNode, + AsyncStructuredRel, IntegerProperty, - One, Q, Relationship, RelationshipFrom, RelationshipTo, StringProperty, - StructuredRel, ) from neomodel._async.core import adb @@ -31,7 +31,7 @@ def special_power(self): class Country(AsyncStructuredNode): code = StringProperty(unique_index=True) inhabitant = RelationshipFrom(PersonWithRels, "IS_FROM") - president = RelationshipTo(PersonWithRels, "PRESIDENT", cardinality=One) + president = RelationshipTo(PersonWithRels, "PRESIDENT", cardinality=AsyncOne) class SuperHero(PersonWithRels): @@ -99,10 +99,10 @@ def test_either_direction_connect(): assert int(result[0][0]) == 1 rel = rey.knows.relationship(sakis) - assert isinstance(rel, StructuredRel) + assert isinstance(rel, AsyncStructuredRel) rels = rey.knows.all_relationships(sakis) - assert isinstance(rels[0], StructuredRel) + assert isinstance(rels[0], AsyncStructuredRel) def test_search_and_filter_and_exclude(): diff --git a/test/test_scripts.py b/test/test_scripts.py index f925603e..9099baf8 100644 --- a/test/test_scripts.py +++ b/test/test_scripts.py @@ -4,15 +4,15 @@ from neomodel import ( AsyncStructuredNode, + AsyncStructuredRel, RelationshipTo, StringProperty, - StructuredRel, config, ) from neomodel._async.core import adb -class ScriptsTestRel(StructuredRel): +class ScriptsTestRel(AsyncStructuredRel): some_unique_property = StringProperty( unique_index=adb.version_is_higher_than("5.7") ) From 0c843520ce52e0f29b12e9b206d9bf001e0b6964 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Wed, 20 Dec 2023 15:13:25 +0100 Subject: [PATCH 14/73] Make Relationships async --- neomodel/__init__.py | 17 +++++++++++------ neomodel/_async/relationship_manager.py | 8 ++++---- neomodel/_sync/core.py | 18 +++++++++++------- neomodel/contrib/spatial_properties.py | 4 ++-- neomodel/properties.py | 4 ++-- test/test_batch.py | 8 ++++---- test/test_cardinality.py | 12 +++++++----- test/test_database_management.py | 4 ++-- test/test_issue112.py | 4 ++-- test/test_issue283.py | 6 +++--- test/test_issue600.py | 12 ++++++------ test/test_label_install.py | 10 +++++----- test/test_match_api.py | 16 +++++++++------- test/test_migration_neo4j_5.py | 4 ++-- test/test_paths.py | 8 ++++---- test/test_relationship_models.py | 10 +++++----- test/test_relationships.py | 14 +++++++------- test/test_relative_relationships.py | 4 ++-- test/test_scripts.py | 4 ++-- 19 files changed, 90 insertions(+), 77 deletions(-) diff --git a/neomodel/__init__.py b/neomodel/__init__.py index 0997322d..21d6c668 100644 --- a/neomodel/__init__.py +++ b/neomodel/__init__.py @@ -20,19 +20,24 @@ from neomodel._async.path import AsyncNeomodelPath from neomodel._async.relationship import AsyncStructuredRel from neomodel._async.relationship_manager import ( + AsyncRelationship, + AsyncRelationshipDefinition, + AsyncRelationshipFrom, AsyncRelationshipManager, - NotConnected, - Relationship, - RelationshipDefinition, - RelationshipFrom, - RelationshipTo, + AsyncRelationshipTo, ) from neomodel._sync.cardinality import One, OneOrMore, ZeroOrMore, ZeroOrOne from neomodel._sync.core import StructuredNode from neomodel._sync.match import NodeSet, Traversal from neomodel._sync.path import NeomodelPath from neomodel._sync.relationship import StructuredRel -from neomodel._sync.relationship_manager import RelationshipManager +from neomodel._sync.relationship_manager import ( + Relationship, + RelationshipDefinition, + RelationshipFrom, + RelationshipManager, + RelationshipTo, +) from neomodel.exceptions import * from neomodel.match_q import Q # noqa from neomodel.properties import ( diff --git a/neomodel/_async/relationship_manager.py b/neomodel/_async/relationship_manager.py index 85a7ef21..3bc587c3 100644 --- a/neomodel/_async/relationship_manager.py +++ b/neomodel/_async/relationship_manager.py @@ -390,7 +390,7 @@ def __getitem__(self, key): return self._new_traversal().__getitem__(key) -class RelationshipDefinition: +class AsyncRelationshipDefinition: def __init__( self, relation_type, @@ -503,7 +503,7 @@ class AsyncZeroOrMore(AsyncRelationshipManager): description = "zero or more relationships" -class RelationshipTo(RelationshipDefinition): +class AsyncRelationshipTo(AsyncRelationshipDefinition): def __init__( self, cls_name, @@ -516,7 +516,7 @@ def __init__( ) -class RelationshipFrom(RelationshipDefinition): +class AsyncRelationshipFrom(AsyncRelationshipDefinition): def __init__( self, cls_name, @@ -529,7 +529,7 @@ def __init__( ) -class Relationship(RelationshipDefinition): +class AsyncRelationship(AsyncRelationshipDefinition): def __init__( self, cls_name, diff --git a/neomodel/_sync/core.py b/neomodel/_sync/core.py index 443c1a8f..56f69008 100644 --- a/neomodel/_sync/core.py +++ b/neomodel/_sync/core.py @@ -253,7 +253,9 @@ def begin(self, access_mode=None, **parameters): impersonated_user=self.impersonated_user, **parameters, ) - self._active_transaction: Transaction = self._session.begin_transaction() + self._active_transaction: Transaction = ( + self._session.begin_transaction() + ) @ensure_connection def commit(self): @@ -342,9 +344,9 @@ def _object_resolution(self, object_to_resolve): return self._NODE_CLASS_REGISTRY[rel_type].inflate(object_to_resolve) if isinstance(object_to_resolve, Path): - from neomodel._async.path import AsyncNeomodelPath + from neomodel._sync.path import NeomodelPath - return AsyncNeomodelPath(object_to_resolve) + return NeomodelPath(object_to_resolve) if isinstance(object_to_resolve, list): return self._result_resolution([object_to_resolve]) @@ -833,7 +835,9 @@ def change_neo4j_password(db: Database, user, new_password): db.change_neo4j_password(user, new_password) -def clear_neo4j_database(db: Database, clear_constraints=False, clear_indexes=False): +def clear_neo4j_database( + db: Database, clear_constraints=False, clear_indexes=False +): deprecated( """ This method has been moved to the Database singleton (db for sync, db for async). @@ -1111,9 +1115,9 @@ def nodes(cls): :return: NodeSet :rtype: NodeSet """ - from neomodel._async.match import AsyncNodeSet + from neomodel._sync.match import NodeSet - return AsyncNodeSet(cls) + return NodeSet(cls) @property def element_id(self): @@ -1174,7 +1178,7 @@ def _build_merge_query( "No relation_type is specified on provided relationship" ) - from neomodel._async.match import _rel_helper + from neomodel._sync.match import _rel_helper query_params["source_id"] = relationship.source.element_id query = f"MATCH (source:{relationship.source.__label__}) WHERE {db.get_id_method()}(source) = $source_id\n " diff --git a/neomodel/contrib/spatial_properties.py b/neomodel/contrib/spatial_properties.py index 7a48018b..6b6606d5 100644 --- a/neomodel/contrib/spatial_properties.py +++ b/neomodel/contrib/spatial_properties.py @@ -25,9 +25,9 @@ # If shapely is not installed, its import will fail and the spatial properties will not be available try: - from shapely.geometry import Point as ShapelyPoint from shapely import __version__ as shapely_version - from shapely.coords import CoordinateSequence + from shapely.coords import CoordinateSequence + from shapely.geometry import Point as ShapelyPoint except ImportError as exc: raise ImportError( "NEOMODEL ERROR: Shapely not found. If required, you can install Shapely via " diff --git a/neomodel/properties.py b/neomodel/properties.py index 3c88d299..df80a47e 100644 --- a/neomodel/properties.py +++ b/neomodel/properties.py @@ -101,7 +101,7 @@ def deflate(cls, properties, obj=None, skip_empty=False): @classmethod def defined_properties(cls, aliases=True, properties=True, rels=True): - from neomodel._async.relationship_manager import RelationshipDefinition + from neomodel._async.relationship_manager import AsyncRelationshipDefinition props = {} for baseclass in reversed(cls.__mro__): @@ -115,7 +115,7 @@ def defined_properties(cls, aliases=True, properties=True, rels=True): and isinstance(property, Property) and not isinstance(property, AliasProperty) ) - or (rels and isinstance(property, RelationshipDefinition)) + or (rels and isinstance(property, AsyncRelationshipDefinition)) ) ) return props diff --git a/test/test_batch.py b/test/test_batch.py index fc582509..8280805a 100644 --- a/test/test_batch.py +++ b/test/test_batch.py @@ -1,10 +1,10 @@ from pytest import raises from neomodel import ( + AsyncRelationshipFrom, + AsyncRelationshipTo, AsyncStructuredNode, IntegerProperty, - RelationshipFrom, - RelationshipTo, StringProperty, UniqueIdProperty, config, @@ -99,12 +99,12 @@ def test_batch_index_violation(): class Dog(AsyncStructuredNode): name = StringProperty(required=True) - owner = RelationshipTo("Person", "owner") + owner = AsyncRelationshipTo("Person", "owner") class Person(AsyncStructuredNode): name = StringProperty(unique_index=True) - pets = RelationshipFrom("Dog", "owner") + pets = AsyncRelationshipFrom("Dog", "owner") def test_get_or_create_with_rel(): diff --git a/test/test_cardinality.py b/test/test_cardinality.py index 60ce9023..6be9226f 100644 --- a/test/test_cardinality.py +++ b/test/test_cardinality.py @@ -3,12 +3,12 @@ from neomodel import ( AsyncOne, AsyncOneOrMore, + AsyncRelationshipTo, AsyncStructuredNode, AsyncZeroOrOne, AttemptedCardinalityViolation, CardinalityViolation, IntegerProperty, - RelationshipTo, StringProperty, ZeroOrMore, adb, @@ -29,12 +29,14 @@ class Car(AsyncStructuredNode): class Monkey(AsyncStructuredNode): name = StringProperty() - dryers = RelationshipTo("HairDryer", "OWNS_DRYER", cardinality=ZeroOrMore) - driver = RelationshipTo( + dryers = AsyncRelationshipTo("HairDryer", "OWNS_DRYER", cardinality=ZeroOrMore) + driver = AsyncRelationshipTo( "ScrewDriver", "HAS_SCREWDRIVER", cardinality=AsyncZeroOrOne ) - car = RelationshipTo("Car", "HAS_CAR", cardinality=AsyncOneOrMore) - toothbrush = RelationshipTo("ToothBrush", "HAS_TOOTHBRUSH", cardinality=AsyncOne) + car = AsyncRelationshipTo("Car", "HAS_CAR", cardinality=AsyncOneOrMore) + toothbrush = AsyncRelationshipTo( + "ToothBrush", "HAS_TOOTHBRUSH", cardinality=AsyncOne + ) class ToothBrush(AsyncStructuredNode): diff --git a/test/test_database_management.py b/test/test_database_management.py index 545e1dbf..9791d3a4 100644 --- a/test/test_database_management.py +++ b/test/test_database_management.py @@ -2,10 +2,10 @@ from neo4j.exceptions import AuthError from neomodel import ( + AsyncRelationshipTo, AsyncStructuredNode, AsyncStructuredRel, IntegerProperty, - RelationshipTo, StringProperty, ) from neomodel._async.core import adb @@ -22,7 +22,7 @@ class InCity(AsyncStructuredRel): class Venue(AsyncStructuredNode): name = StringProperty(unique_index=True) creator = StringProperty(index=True) - in_city = RelationshipTo(City, relation_type="IN", model=InCity) + in_city = AsyncRelationshipTo(City, relation_type="IN", model=InCity) def test_clear_database(): diff --git a/test/test_issue112.py b/test/test_issue112.py index c24fe1b2..d20b53ac 100644 --- a/test/test_issue112.py +++ b/test/test_issue112.py @@ -1,8 +1,8 @@ -from neomodel import AsyncStructuredNode, RelationshipTo +from neomodel import AsyncRelationshipTo, AsyncStructuredNode class SomeModel(AsyncStructuredNode): - test = RelationshipTo("SomeModel", "SELF") + test = AsyncRelationshipTo("SomeModel", "SELF") def test_len_relationship(): diff --git a/test/test_issue283.py b/test/test_issue283.py index 0efbbc48..fb5b5f2a 100644 --- a/test/test_issue283.py +++ b/test/test_issue283.py @@ -42,7 +42,7 @@ class BasePerson(neomodel.AsyncStructuredNode): """ name = neomodel.StringProperty(required=True, unique_index=True) - friends_with = neomodel.RelationshipTo( + friends_with = neomodel.AsyncRelationshipTo( "BasePerson", "FRIENDS_WITH", model=PersonalRelationship ) @@ -374,7 +374,7 @@ class ExtendedPersonalRelationship(PersonalRelationship): # Extends SomePerson, establishes "enriched" relationships with any BaseOtherPerson class ExtendedSomePerson(SomePerson): - friends_with = neomodel.RelationshipTo( + friends_with = neomodel.AsyncRelationshipTo( "BaseOtherPerson", "FRIENDS_WITH", model=ExtendedPersonalRelationship, @@ -412,7 +412,7 @@ class NewRelationship(neomodel.AsyncStructuredRel): ): class NewSomePerson(SomePerson): - friends_with = neomodel.RelationshipTo( + friends_with = neomodel.AsyncRelationshipTo( "BaseOtherPerson", "FRIENDS_WITH", model=NewRelationship ) diff --git a/test/test_issue600.py b/test/test_issue600.py index a85e5f01..377dc700 100644 --- a/test/test_issue600.py +++ b/test/test_issue600.py @@ -31,25 +31,25 @@ class SubClass2(Class1): class RelationshipDefinerSecondSibling(neomodel.AsyncStructuredNode): - rel_1 = neomodel.Relationship( + rel_1 = neomodel.AsyncRelationship( "RelationshipDefinerSecondSibling", "SOME_REL_LABEL", model=Class1 ) - rel_2 = neomodel.Relationship( + rel_2 = neomodel.AsyncRelationship( "RelationshipDefinerSecondSibling", "SOME_REL_LABEL", model=SubClass1 ) - rel_3 = neomodel.Relationship( + rel_3 = neomodel.AsyncRelationship( "RelationshipDefinerSecondSibling", "SOME_REL_LABEL", model=SubClass2 ) class RelationshipDefinerParentLast(neomodel.AsyncStructuredNode): - rel_2 = neomodel.Relationship( + rel_2 = neomodel.AsyncRelationship( "RelationshipDefinerParentLast", "SOME_REL_LABEL", model=SubClass1 ) - rel_3 = neomodel.Relationship( + rel_3 = neomodel.AsyncRelationship( "RelationshipDefinerParentLast", "SOME_REL_LABEL", model=SubClass2 ) - rel_1 = neomodel.Relationship( + rel_1 = neomodel.AsyncRelationship( "RelationshipDefinerParentLast", "SOME_REL_LABEL", model=Class1 ) diff --git a/test/test_label_install.py b/test/test_label_install.py index 7e367af3..771eddab 100644 --- a/test/test_label_install.py +++ b/test/test_label_install.py @@ -1,9 +1,9 @@ import pytest from neomodel import ( + AsyncRelationshipTo, AsyncStructuredNode, AsyncStructuredRel, - RelationshipTo, StringProperty, UniqueIdProperty, config, @@ -31,7 +31,7 @@ class IndexedRelationship(AsyncStructuredRel): class OtherNodeWithRelationship(AsyncStructuredNode): - has_rel = RelationshipTo( + has_rel = AsyncRelationshipTo( NodeWithRelationship, "INDEXED_REL", model=IndexedRelationship ) @@ -105,7 +105,7 @@ class UniqueIndexRelationship(AsyncStructuredRel): unique_index_rel_prop = StringProperty(unique_index=True) class OtherNodeWithUniqueIndexRelationship(AsyncStructuredNode): - has_rel = RelationshipTo( + has_rel = AsyncRelationshipTo( NodeWithRelationship, "UNIQUE_INDEX_REL", model=UniqueIndexRelationship ) @@ -144,7 +144,7 @@ class TargetNodeForUniqueIndexRelationship(AsyncStructuredNode): ): class NodeWithUniqueIndexRelationship(AsyncStructuredNode): - has_rel = RelationshipTo( + has_rel = AsyncRelationshipTo( TargetNodeForUniqueIndexRelationship, "UNIQUE_INDEX_REL", model=UniqueIndexRelationship, @@ -160,7 +160,7 @@ class TargetNodeForUniqueIndexRelationship(AsyncStructuredNode): pass class NodeWithUniqueIndexRelationship(AsyncStructuredNode): - has_rel = RelationshipTo( + has_rel = AsyncRelationshipTo( TargetNodeForUniqueIndexRelationship, "UNIQUE_INDEX_REL_BIS", model=UniqueIndexRelationshipBis, diff --git a/test/test_match_api.py b/test/test_match_api.py index f1421df8..8cbe3e1a 100644 --- a/test/test_match_api.py +++ b/test/test_match_api.py @@ -4,13 +4,13 @@ from neomodel import ( INCOMING, + AsyncRelationshipFrom, + AsyncRelationshipTo, AsyncStructuredNode, AsyncStructuredRel, DateTimeProperty, IntegerProperty, Q, - RelationshipFrom, - RelationshipTo, StringProperty, ) from neomodel._async.match import ( @@ -30,24 +30,26 @@ class SupplierRel(AsyncStructuredRel): class Supplier(AsyncStructuredNode): name = StringProperty() delivery_cost = IntegerProperty() - coffees = RelationshipTo("Coffee", "COFFEE SUPPLIERS") + coffees = AsyncRelationshipTo("Coffee", "COFFEE SUPPLIERS") class Species(AsyncStructuredNode): name = StringProperty() - coffees = RelationshipFrom("Coffee", "COFFEE SPECIES", model=AsyncStructuredRel) + coffees = AsyncRelationshipFrom( + "Coffee", "COFFEE SPECIES", model=AsyncStructuredRel + ) class Coffee(AsyncStructuredNode): name = StringProperty(unique_index=True) price = IntegerProperty() - suppliers = RelationshipFrom(Supplier, "COFFEE SUPPLIERS", model=SupplierRel) - species = RelationshipTo(Species, "COFFEE SPECIES", model=AsyncStructuredRel) + suppliers = AsyncRelationshipFrom(Supplier, "COFFEE SUPPLIERS", model=SupplierRel) + species = AsyncRelationshipTo(Species, "COFFEE SPECIES", model=AsyncStructuredRel) id_ = IntegerProperty() class Extension(AsyncStructuredNode): - extension = RelationshipTo("Extension", "extension") + extension = AsyncRelationshipTo("Extension", "extension") def test_filter_exclude_via_labels(): diff --git a/test/test_migration_neo4j_5.py b/test/test_migration_neo4j_5.py index ff869545..ee4c3ed2 100644 --- a/test/test_migration_neo4j_5.py +++ b/test/test_migration_neo4j_5.py @@ -1,10 +1,10 @@ import pytest from neomodel import ( + AsyncRelationshipTo, AsyncStructuredNode, AsyncStructuredRel, IntegerProperty, - RelationshipTo, StringProperty, ) from neomodel._async.core import adb @@ -20,7 +20,7 @@ class Released(AsyncStructuredRel): class Band(AsyncStructuredNode): name = StringProperty() - released = RelationshipTo(Album, relation_type="RELEASED", model=Released) + released = AsyncRelationshipTo(Album, relation_type="RELEASED", model=Released) def test_read_elements_id(): diff --git a/test/test_paths.py b/test/test_paths.py index ba703b7f..9c0000f8 100644 --- a/test/test_paths.py +++ b/test/test_paths.py @@ -1,9 +1,9 @@ from neomodel import ( AsyncNeomodelPath, + AsyncRelationshipTo, AsyncStructuredNode, AsyncStructuredRel, IntegerProperty, - RelationshipTo, StringProperty, UniqueIdProperty, adb, @@ -24,7 +24,7 @@ class CountryOfOrigin(AsyncStructuredNode): class CityOfResidence(AsyncStructuredNode): name = StringProperty(required=True) - country = RelationshipTo(CountryOfOrigin, "FROM_COUNTRY") + country = AsyncRelationshipTo(CountryOfOrigin, "FROM_COUNTRY") class PersonOfInterest(AsyncStructuredNode): @@ -32,8 +32,8 @@ class PersonOfInterest(AsyncStructuredNode): name = StringProperty(unique_index=True) age = IntegerProperty(index=True, default=0) - country = RelationshipTo(CountryOfOrigin, "IS_FROM") - city = RelationshipTo(CityOfResidence, "LIVES_IN", model=PersonLivesInCity) + country = AsyncRelationshipTo(CountryOfOrigin, "IS_FROM") + city = AsyncRelationshipTo(CityOfResidence, "LIVES_IN", model=PersonLivesInCity) def test_path_instantiation(): diff --git a/test/test_relationship_models.py b/test/test_relationship_models.py index cdb27995..89b50c53 100644 --- a/test/test_relationship_models.py +++ b/test/test_relationship_models.py @@ -4,12 +4,12 @@ from pytest import raises from neomodel import ( + AsyncRelationship, + AsyncRelationshipTo, AsyncStructuredNode, AsyncStructuredRel, DateTimeProperty, DeflateError, - Relationship, - RelationshipTo, StringProperty, ) @@ -32,13 +32,13 @@ def post_save(self): class Badger(AsyncStructuredNode): name = StringProperty(unique_index=True) - friend = Relationship("Badger", "FRIEND", model=FriendRel) - hates = RelationshipTo("Stoat", "HATES", model=HatesRel) + friend = AsyncRelationship("Badger", "FRIEND", model=FriendRel) + hates = AsyncRelationshipTo("Stoat", "HATES", model=HatesRel) class Stoat(AsyncStructuredNode): name = StringProperty(unique_index=True) - hates = RelationshipTo("Badger", "HATES", model=HatesRel) + hates = AsyncRelationshipTo("Badger", "HATES", model=HatesRel) def test_either_connect_with_rel_model(): diff --git a/test/test_relationships.py b/test/test_relationships.py index 6b81d137..69196eaf 100644 --- a/test/test_relationships.py +++ b/test/test_relationships.py @@ -2,13 +2,13 @@ from neomodel import ( AsyncOne, + AsyncRelationship, + AsyncRelationshipFrom, + AsyncRelationshipTo, AsyncStructuredNode, AsyncStructuredRel, IntegerProperty, Q, - Relationship, - RelationshipFrom, - RelationshipTo, StringProperty, ) from neomodel._async.core import adb @@ -17,8 +17,8 @@ class PersonWithRels(AsyncStructuredNode): name = StringProperty(unique_index=True) age = IntegerProperty(index=True) - is_from = RelationshipTo("Country", "IS_FROM") - knows = Relationship("PersonWithRels", "KNOWS") + is_from = AsyncRelationshipTo("Country", "IS_FROM") + knows = AsyncRelationship("PersonWithRels", "KNOWS") @property def special_name(self): @@ -30,8 +30,8 @@ def special_power(self): class Country(AsyncStructuredNode): code = StringProperty(unique_index=True) - inhabitant = RelationshipFrom(PersonWithRels, "IS_FROM") - president = RelationshipTo(PersonWithRels, "PRESIDENT", cardinality=AsyncOne) + inhabitant = AsyncRelationshipFrom(PersonWithRels, "IS_FROM") + president = AsyncRelationshipTo(PersonWithRels, "PRESIDENT", cardinality=AsyncOne) class SuperHero(PersonWithRels): diff --git a/test/test_relative_relationships.py b/test/test_relative_relationships.py index 0a82ff57..619d9d9e 100644 --- a/test/test_relative_relationships.py +++ b/test/test_relative_relationships.py @@ -1,11 +1,11 @@ -from neomodel import AsyncStructuredNode, RelationshipTo, StringProperty +from neomodel import AsyncRelationshipTo, AsyncStructuredNode, StringProperty from neomodel.test_relationships import Country class Cat(AsyncStructuredNode): name = StringProperty() # Relationship is defined using a relative class path - is_from = RelationshipTo(".test_relationships.Country", "IS_FROM") + is_from = AsyncRelationshipTo(".test_relationships.Country", "IS_FROM") def test_relative_relationship(): diff --git a/test/test_scripts.py b/test/test_scripts.py index 9099baf8..d054ddb5 100644 --- a/test/test_scripts.py +++ b/test/test_scripts.py @@ -3,9 +3,9 @@ import pytest from neomodel import ( + AsyncRelationshipTo, AsyncStructuredNode, AsyncStructuredRel, - RelationshipTo, StringProperty, config, ) @@ -22,7 +22,7 @@ class ScriptsTestRel(AsyncStructuredRel): class ScriptsTestNode(AsyncStructuredNode): personal_id = StringProperty(unique_index=True) name = StringProperty(index=True) - rel = RelationshipTo("ScriptsTestNode", "REL", model=ScriptsTestRel) + rel = AsyncRelationshipTo("ScriptsTestNode", "REL", model=ScriptsTestRel) def test_neomodel_install_labels(): From a95ea04f4feccdb78aa3f9e7037de8847616f9b7 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Fri, 22 Dec 2023 15:46:16 +0100 Subject: [PATCH 15/73] Rename _async to sync_ ; Update autodoc --- .pre-commit-config.yaml | 2 +- bin/make-unasync | 16 +++-- doc/source/index.rst | 2 + doc/source/module_documentation.rst | 70 +++---------------- doc/source/module_documentation_async.rst | 49 +++++++++++++ doc/source/module_documentation_sync.rst | 49 +++++++++++++ neomodel/__init__.py | 36 +++++----- neomodel/_async/__init__.py | 1 - neomodel/_sync/__init__.py | 1 - neomodel/async_/__init__.py | 1 + neomodel/{_async => async_}/cardinality.py | 2 +- neomodel/{_async => async_}/core.py | 6 +- neomodel/{_async => async_}/match.py | 2 +- neomodel/{_async => async_}/path.py | 4 +- neomodel/{_async => async_}/relationship.py | 2 +- .../relationship_manager.py | 6 +- neomodel/contrib/__init__.py | 3 +- neomodel/contrib/async_/semi_structured.py | 64 +++++++++++++++++ .../contrib/{ => sync_}/semi_structured.py | 4 +- neomodel/integration/numpy.py | 2 +- neomodel/integration/pandas.py | 2 +- neomodel/properties.py | 6 +- neomodel/scripts/neomodel_inspect_database.py | 4 +- neomodel/scripts/neomodel_install_labels.py | 2 +- neomodel/scripts/neomodel_remove_labels.py | 2 +- neomodel/sync_/__init__.py | 1 + neomodel/{_sync => sync_}/cardinality.py | 2 +- neomodel/{_sync => sync_}/core.py | 6 +- neomodel/{_sync => sync_}/match.py | 2 +- neomodel/{_sync => sync_}/path.py | 4 +- neomodel/{_sync => sync_}/relationship.py | 2 +- .../{_sync => sync_}/relationship_manager.py | 11 ++- test/_async_compat/mark_decorator.py | 4 +- test/async_/conftest.py | 4 +- test/async_/test_alias.py | 34 +++++++++ test/async_/test_cypher.py | 4 +- test/sync/conftest.py | 4 +- test/{ => sync}/test_alias.py | 8 ++- test/sync/test_cypher.py | 4 +- test/test_database_management.py | 2 +- test/test_dbms_awareness.py | 2 +- test/test_driver_options.py | 2 +- test/test_indexing.py | 2 +- test/test_label_drop.py | 2 +- test/test_label_install.py | 2 +- test/test_match_api.py | 2 +- test/test_migration_neo4j_5.py | 2 +- test/test_models.py | 2 +- test/test_relationships.py | 2 +- test/test_scripts.py | 2 +- test/test_transactions.py | 2 +- 51 files changed, 310 insertions(+), 142 deletions(-) create mode 100644 doc/source/module_documentation_async.rst create mode 100644 doc/source/module_documentation_sync.rst delete mode 100644 neomodel/_async/__init__.py delete mode 100644 neomodel/_sync/__init__.py create mode 100644 neomodel/async_/__init__.py rename neomodel/{_async => async_}/cardinality.py (98%) rename neomodel/{_async => async_}/core.py (99%) rename neomodel/{_async => async_}/match.py (99%) rename neomodel/{_async => async_}/path.py (94%) rename neomodel/{_async => async_}/relationship.py (99%) rename neomodel/{_async => async_}/relationship_manager.py (99%) create mode 100644 neomodel/contrib/async_/semi_structured.py rename neomodel/contrib/{ => sync_}/semi_structured.py (95%) create mode 100644 neomodel/sync_/__init__.py rename neomodel/{_sync => sync_}/cardinality.py (98%) rename neomodel/{_sync => sync_}/core.py (99%) rename neomodel/{_sync => sync_}/match.py (99%) rename neomodel/{_sync => sync_}/path.py (95%) rename neomodel/{_sync => sync_}/relationship.py (99%) rename neomodel/{_sync => sync_}/relationship_manager.py (98%) create mode 100644 test/async_/test_alias.py rename test/{ => sync}/test_alias.py (81%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4601807f..88cf1dbe 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,4 +9,4 @@ repos: name: unasync entry: bin/make-unasync language: system - files: "^(neomodel/_async|test/async_)/.*" \ No newline at end of file + files: "^(neomodel/async_|test/async_)/.*" \ No newline at end of file diff --git a/bin/make-unasync b/bin/make-unasync index 33ccbf90..e56df09e 100755 --- a/bin/make-unasync +++ b/bin/make-unasync @@ -11,8 +11,10 @@ from pathlib import Path import unasync ROOT_DIR = Path(__file__).parents[1].absolute() -ASYNC_DIR = ROOT_DIR / "neomodel" / "_async" -SYNC_DIR = ROOT_DIR / "neomodel" / "_sync" +ASYNC_DIR = ROOT_DIR / "neomodel" / "async_" +SYNC_DIR = ROOT_DIR / "neomodel" / "sync_" +ASYNC_CONTRIB_DIR = ROOT_DIR / "neomodel" / "contrib" / "async_" +SYNC_CONTRIB_DIR = ROOT_DIR / "neomodel" / "contrib" / "sync_" ASYNC_INTEGRATION_TEST_DIR = ROOT_DIR / "test" / "async_" SYNC_INTEGRATION_TEST_DIR = ROOT_DIR / "test" / "sync" UNASYNC_SUFFIX = ".unasync" @@ -202,9 +204,9 @@ class CustomRule(unasync.Rule): def apply_unasync(files): """Generate sync code from async code.""" - additional_main_replacements = {"adb": "db", "_async": "_sync"} + additional_main_replacements = {"adb": "db", "async_": "sync_"} additional_test_replacements = { - "_async": "_sync", + "async_": "sync_", "adb": "db", "mark_async_test": "mark_sync_test", "mark_async_session_auto_fixture": "mark_sync_session_auto_fixture", @@ -215,6 +217,11 @@ def apply_unasync(files): todir=str(SYNC_DIR), additional_replacements=additional_main_replacements, ), + CustomRule( + fromdir=str(ASYNC_CONTRIB_DIR), + todir=str(SYNC_CONTRIB_DIR), + additional_replacements=additional_main_replacements, + ), CustomRule( fromdir=str(ASYNC_INTEGRATION_TEST_DIR), todir=str(SYNC_INTEGRATION_TEST_DIR), @@ -224,6 +231,7 @@ def apply_unasync(files): if not files: paths = list(ASYNC_DIR.rglob("*")) + paths += list(ASYNC_CONTRIB_DIR.rglob("*")) paths += list(ASYNC_INTEGRATION_TEST_DIR.rglob("*")) else: paths = [ROOT_DIR / Path(f) for f in files] diff --git a/doc/source/index.rst b/doc/source/index.rst index ec3372c9..397e838f 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -59,6 +59,8 @@ Contents configuration extending module_documentation + module_documentation_sync + module_documentation_async Indices and tables ================== diff --git a/doc/source/module_documentation.rst b/doc/source/module_documentation.rst index 16a4acf2..937060a5 100644 --- a/doc/source/module_documentation.rst +++ b/doc/source/module_documentation.rst @@ -1,24 +1,6 @@ -===================== -Modules documentation -===================== - -Database -======== -.. module:: neomodel.util -.. autoclass:: neomodel.util.Database - :members: - :undoc-members: - -Core -==== -.. automodule:: neomodel.core - :members: - -.. _semistructurednode_doc: - -``SemiStructuredNode`` ----------------------- -.. autoclass:: neomodel.contrib.SemiStructuredNode +========================== +Async/sync independent API +========================== Properties ========== @@ -32,43 +14,6 @@ Spatial Properties & Datatypes :members: :show-inheritance: -Relationships -============= -.. automodule:: neomodel.relationship - :members: - :show-inheritance: - -.. automodule:: neomodel.relationship_manager - :members: - :show-inheritance: - -.. automodule:: neomodel.cardinality - :members: - :show-inheritance: - -Paths -===== - -.. automodule:: neomodel.path - :members: - :show-inheritance: - - - - -Match -===== -.. module:: neomodel.match -.. autoclass:: neomodel.match.BaseSet - :members: - :undoc-members: -.. autoclass:: neomodel.match.NodeSet - :members: - :undoc-members: -.. autoclass:: neomodel.match.Traversal - :members: - :undoc-members: - Exceptions ========== @@ -78,16 +23,21 @@ Exceptions :undoc-members: :show-inheritance: + Scripts ======= -.. automodule:: neomodel.scripts.neomodel_install_labels +.. automodule:: neomodel.scripts.neomodel_inspect_database :members: :undoc-members: :show-inheritance: -.. automodule:: neomodel.scripts.neomodel_remove_labels +.. automodule:: neomodel.scripts.neomodel_install_labels :members: :undoc-members: :show-inheritance: +.. automodule:: neomodel.scripts.neomodel_remove_labels + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/doc/source/module_documentation_async.rst b/doc/source/module_documentation_async.rst new file mode 100644 index 00000000..09ce3e2b --- /dev/null +++ b/doc/source/module_documentation_async.rst @@ -0,0 +1,49 @@ +======================= +Async API Documentation +======================= + +Core +==== +.. automodule:: neomodel.async_.core + :members: + +.. _semistructurednode_doc: + +``AsyncSemiStructuredNode`` +--------------------------- +.. autoclass:: neomodel.contrib.AsyncSemiStructuredNode + +Relationships +============= +.. automodule:: neomodel.async_.relationship + :members: + :show-inheritance: + +.. automodule:: neomodel.async_.relationship_manager + :members: + :show-inheritance: + +.. automodule:: neomodel.async_.cardinality + :members: + :show-inheritance: + +Paths +===== + +.. automodule:: neomodel.async_.path + :members: + :show-inheritance: + +Match +===== +.. module:: neomodel.async_.match +.. autoclass:: neomodel.async_.match.AsyncBaseSet + :members: + :undoc-members: +.. autoclass:: neomodel.async_.match.AsyncNodeSet + :members: + :undoc-members: +.. autoclass:: neomodel.async_.match.AsyncTraversal + :members: + :undoc-members: + diff --git a/doc/source/module_documentation_sync.rst b/doc/source/module_documentation_sync.rst new file mode 100644 index 00000000..e39485f2 --- /dev/null +++ b/doc/source/module_documentation_sync.rst @@ -0,0 +1,49 @@ +================= +API Documentation +================= + +Core +==== +.. automodule:: neomodel.sync_.core + :members: + +.. _semistructurednode_doc: + +``SemiStructuredNode`` +--------------------------- +.. autoclass:: neomodel.contrib.SemiStructuredNode + +Relationships +============= +.. automodule:: neomodel.sync_.relationship + :members: + :show-inheritance: + +.. automodule:: neomodel.sync_.relationship_manager + :members: + :show-inheritance: + +.. automodule:: neomodel.sync_.cardinality + :members: + :show-inheritance: + +Paths +===== + +.. automodule:: neomodel.sync_.path + :members: + :show-inheritance: + +Match +===== +.. module:: neomodel.sync_.match +.. autoclass:: neomodel.sync_.match.BaseSet + :members: + :undoc-members: +.. autoclass:: neomodel.sync_.match.NodeSet + :members: + :undoc-members: +.. autoclass:: neomodel.sync_.match.Traversal + :members: + :undoc-members: + diff --git a/neomodel/__init__.py b/neomodel/__init__.py index 21d6c668..4da0c3eb 100644 --- a/neomodel/__init__.py +++ b/neomodel/__init__.py @@ -1,12 +1,12 @@ # pep8: noqa # TODO : Check imports here -from neomodel._async.cardinality import ( +from neomodel.async_.cardinality import ( AsyncOne, AsyncOneOrMore, AsyncZeroOrMore, AsyncZeroOrOne, ) -from neomodel._async.core import ( +from neomodel.async_.core import ( AsyncStructuredNode, change_neo4j_password, clear_neo4j_database, @@ -16,28 +16,16 @@ install_labels, remove_all_labels, ) -from neomodel._async.match import AsyncNodeSet, AsyncTraversal -from neomodel._async.path import AsyncNeomodelPath -from neomodel._async.relationship import AsyncStructuredRel -from neomodel._async.relationship_manager import ( +from neomodel.async_.match import AsyncNodeSet, AsyncTraversal +from neomodel.async_.path import AsyncNeomodelPath +from neomodel.async_.relationship import AsyncStructuredRel +from neomodel.async_.relationship_manager import ( AsyncRelationship, AsyncRelationshipDefinition, AsyncRelationshipFrom, AsyncRelationshipManager, AsyncRelationshipTo, ) -from neomodel._sync.cardinality import One, OneOrMore, ZeroOrMore, ZeroOrOne -from neomodel._sync.core import StructuredNode -from neomodel._sync.match import NodeSet, Traversal -from neomodel._sync.path import NeomodelPath -from neomodel._sync.relationship import StructuredRel -from neomodel._sync.relationship_manager import ( - Relationship, - RelationshipDefinition, - RelationshipFrom, - RelationshipManager, - RelationshipTo, -) from neomodel.exceptions import * from neomodel.match_q import Q # noqa from neomodel.properties import ( @@ -56,6 +44,18 @@ StringProperty, UniqueIdProperty, ) +from neomodel.sync_.cardinality import One, OneOrMore, ZeroOrMore, ZeroOrOne +from neomodel.sync_.core import StructuredNode +from neomodel.sync_.match import NodeSet, Traversal +from neomodel.sync_.path import NeomodelPath +from neomodel.sync_.relationship import StructuredRel +from neomodel.sync_.relationship_manager import ( + Relationship, + RelationshipDefinition, + RelationshipFrom, + RelationshipManager, + RelationshipTo, +) from neomodel.util import EITHER, INCOMING, OUTGOING __author__ = "Robin Edwards" diff --git a/neomodel/_async/__init__.py b/neomodel/_async/__init__.py deleted file mode 100644 index 95bbd58a..00000000 --- a/neomodel/_async/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# from neomodel._async.core import adb diff --git a/neomodel/_sync/__init__.py b/neomodel/_sync/__init__.py deleted file mode 100644 index 95bbd58a..00000000 --- a/neomodel/_sync/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# from neomodel._async.core import adb diff --git a/neomodel/async_/__init__.py b/neomodel/async_/__init__.py new file mode 100644 index 00000000..f1d519e8 --- /dev/null +++ b/neomodel/async_/__init__.py @@ -0,0 +1 @@ +# from neomodel.async_.core import adb diff --git a/neomodel/_async/cardinality.py b/neomodel/async_/cardinality.py similarity index 98% rename from neomodel/_async/cardinality.py rename to neomodel/async_/cardinality.py index 7b1f5cf0..0c3b02cf 100644 --- a/neomodel/_async/cardinality.py +++ b/neomodel/async_/cardinality.py @@ -1,4 +1,4 @@ -from neomodel._async.relationship_manager import ( # pylint:disable=unused-import +from neomodel.async_.relationship_manager import ( # pylint:disable=unused-import AsyncRelationshipManager, AsyncZeroOrMore, ) diff --git a/neomodel/_async/core.py b/neomodel/async_/core.py similarity index 99% rename from neomodel/_async/core.py rename to neomodel/async_/core.py index 348f23a7..93209e7f 100644 --- a/neomodel/_async/core.py +++ b/neomodel/async_/core.py @@ -344,7 +344,7 @@ def _object_resolution(self, object_to_resolve): return self._NODE_CLASS_REGISTRY[rel_type].inflate(object_to_resolve) if isinstance(object_to_resolve, Path): - from neomodel._async.path import AsyncNeomodelPath + from neomodel.async_.path import AsyncNeomodelPath return AsyncNeomodelPath(object_to_resolve) @@ -1115,7 +1115,7 @@ def nodes(cls): :return: NodeSet :rtype: NodeSet """ - from neomodel._async.match import AsyncNodeSet + from neomodel.async_.match import AsyncNodeSet return AsyncNodeSet(cls) @@ -1178,7 +1178,7 @@ def _build_merge_query( "No relation_type is specified on provided relationship" ) - from neomodel._async.match import _rel_helper + from neomodel.async_.match import _rel_helper query_params["source_id"] = relationship.source.element_id query = f"MATCH (source:{relationship.source.__label__}) WHERE {adb.get_id_method()}(source) = $source_id\n " diff --git a/neomodel/_async/match.py b/neomodel/async_/match.py similarity index 99% rename from neomodel/_async/match.py rename to neomodel/async_/match.py index 71d5a653..40954fe5 100644 --- a/neomodel/_async/match.py +++ b/neomodel/async_/match.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from typing import Optional -from neomodel._async.core import AsyncStructuredNode, adb +from neomodel.async_.core import AsyncStructuredNode, adb from neomodel.exceptions import MultipleNodesReturned from neomodel.match_q import Q, QBase from neomodel.properties import AliasProperty diff --git a/neomodel/_async/path.py b/neomodel/async_/path.py similarity index 94% rename from neomodel/_async/path.py rename to neomodel/async_/path.py index cf53ed9d..6128347e 100644 --- a/neomodel/_async/path.py +++ b/neomodel/async_/path.py @@ -1,7 +1,7 @@ from neo4j.graph import Path -from neomodel._async.core import adb -from neomodel._async.relationship import AsyncStructuredRel +from neomodel.async_.core import adb +from neomodel.async_.relationship import AsyncStructuredRel class AsyncNeomodelPath(Path): diff --git a/neomodel/_async/relationship.py b/neomodel/async_/relationship.py similarity index 99% rename from neomodel/_async/relationship.py rename to neomodel/async_/relationship.py index 637f35be..355003d4 100644 --- a/neomodel/_async/relationship.py +++ b/neomodel/async_/relationship.py @@ -1,4 +1,4 @@ -from neomodel._async.core import adb +from neomodel.async_.core import adb from neomodel.hooks import hooks from neomodel.properties import Property, PropertyManager diff --git a/neomodel/_async/relationship_manager.py b/neomodel/async_/relationship_manager.py similarity index 99% rename from neomodel/_async/relationship_manager.py rename to neomodel/async_/relationship_manager.py index 3bc587c3..2a1b95e2 100644 --- a/neomodel/_async/relationship_manager.py +++ b/neomodel/async_/relationship_manager.py @@ -3,14 +3,14 @@ import sys from importlib import import_module -from neomodel._async.core import adb -from neomodel._async.match import ( +from neomodel.async_.core import adb +from neomodel.async_.match import ( AsyncNodeSet, AsyncTraversal, _rel_helper, _rel_merge_helper, ) -from neomodel._async.relationship import AsyncStructuredRel +from neomodel.async_.relationship import AsyncStructuredRel from neomodel.exceptions import NotConnected, RelationshipClassRedefined from neomodel.util import ( EITHER, diff --git a/neomodel/contrib/__init__.py b/neomodel/contrib/__init__.py index 15a59660..a852965d 100644 --- a/neomodel/contrib/__init__.py +++ b/neomodel/contrib/__init__.py @@ -1 +1,2 @@ -from neomodel.semi_structured import SemiStructuredNode +from neomodel.contrib.async_.semi_structured import AsyncSemiStructuredNode +from neomodel.contrib.sync_.semi_structured import SemiStructuredNode diff --git a/neomodel/contrib/async_/semi_structured.py b/neomodel/contrib/async_/semi_structured.py new file mode 100644 index 00000000..c333ae0e --- /dev/null +++ b/neomodel/contrib/async_/semi_structured.py @@ -0,0 +1,64 @@ +from neomodel.async_.core import AsyncStructuredNode +from neomodel.exceptions import DeflateConflict, InflateConflict +from neomodel.util import _get_node_properties + + +class AsyncSemiStructuredNode(AsyncStructuredNode): + """ + A base class allowing properties to be stored on a node that aren't + specified in its definition. Conflicting properties are signaled with the + :class:`DeflateConflict` exception:: + + class Person(AsyncSemiStructuredNode): + name = StringProperty() + age = IntegerProperty() + + def hello(self): + print("Hi my names " + self.name) + + tim = await Person(name='Tim', age=8, weight=11).save() + tim.hello = "Hi" + await tim.save() # DeflateConflict + """ + + __abstract_node__ = True + + @classmethod + def inflate(cls, node): + # support lazy loading + if isinstance(node, str) or isinstance(node, int): + snode = cls() + snode.element_id_property = node + else: + props = {} + node_properties = {} + for key, prop in cls.__all_properties__: + node_properties = _get_node_properties(node) + if key in node_properties: + props[key] = prop.inflate(node_properties[key], node) + elif prop.has_default: + props[key] = prop.default_value() + else: + props[key] = None + # handle properties not defined on the class + for free_key in (x for x in node_properties if x not in props): + if hasattr(cls, free_key): + raise InflateConflict( + cls, free_key, node_properties[free_key], node.element_id + ) + props[free_key] = node_properties[free_key] + + snode = cls(**props) + snode.element_id_property = node.element_id + + return snode + + @classmethod + def deflate(cls, node_props, obj=None, skip_empty=False): + deflated = super().deflate(node_props, obj, skip_empty=skip_empty) + for key in [k for k in node_props if k not in deflated]: + if hasattr(cls, key) and (getattr(cls, key).required or not skip_empty): + raise DeflateConflict(cls, key, deflated[key], obj.element_id) + + node_props.update(deflated) + return node_props diff --git a/neomodel/contrib/semi_structured.py b/neomodel/contrib/sync_/semi_structured.py similarity index 95% rename from neomodel/contrib/semi_structured.py rename to neomodel/contrib/sync_/semi_structured.py index 869763dd..86a5a140 100644 --- a/neomodel/contrib/semi_structured.py +++ b/neomodel/contrib/sync_/semi_structured.py @@ -1,9 +1,9 @@ -from neomodel._async.core import AsyncStructuredNode from neomodel.exceptions import DeflateConflict, InflateConflict +from neomodel.sync_.core import StructuredNode from neomodel.util import _get_node_properties -class SemiStructuredNode(AsyncStructuredNode): +class SemiStructuredNode(StructuredNode): """ A base class allowing properties to be stored on a node that aren't specified in its definition. Conflicting properties are signaled with the diff --git a/neomodel/integration/numpy.py b/neomodel/integration/numpy.py index 5dc6da80..14bae3df 100644 --- a/neomodel/integration/numpy.py +++ b/neomodel/integration/numpy.py @@ -7,7 +7,7 @@ Example: - >>> from neomodel._async import db + >>> from neomodel.async_ import db >>> from neomodel.integration.numpy import to_nparray >>> db.set_connection('bolt://neo4j:secret@localhost:7687') >>> df = to_nparray(db.cypher_query("MATCH (u:User) RETURN u.email AS email, u.name AS name")) diff --git a/neomodel/integration/pandas.py b/neomodel/integration/pandas.py index 845c8e50..2f809ade 100644 --- a/neomodel/integration/pandas.py +++ b/neomodel/integration/pandas.py @@ -7,7 +7,7 @@ Example: - >>> from neomodel._async import db + >>> from neomodel.async_ import db >>> from neomodel.integration.pandas import to_dataframe >>> db.set_connection('bolt://neo4j:secret@localhost:7687') >>> df = to_dataframe(db.cypher_query("MATCH (u:User) RETURN u.email AS email, u.name AS name")) diff --git a/neomodel/properties.py b/neomodel/properties.py index df80a47e..51b59449 100644 --- a/neomodel/properties.py +++ b/neomodel/properties.py @@ -67,7 +67,7 @@ def __init__(self, **kwargs): @property def __properties__(self): - from neomodel._async.relationship_manager import AsyncRelationshipManager + from neomodel.async_.relationship_manager import AsyncRelationshipManager return dict( (name, value) @@ -101,7 +101,7 @@ def deflate(cls, properties, obj=None, skip_empty=False): @classmethod def defined_properties(cls, aliases=True, properties=True, rels=True): - from neomodel._async.relationship_manager import AsyncRelationshipDefinition + from neomodel.async_.relationship_manager import AsyncRelationshipDefinition props = {} for baseclass in reversed(cls.__mro__): @@ -467,7 +467,7 @@ def deflate(self, value): class DateTimeFormatProperty(Property): """ - Store a datetime by custome format + Store a datetime by custom format :param default_now: If ``True``, the creation time (Local) will be used as default. Defaults to ``False``. :param format: Date format string, default is %Y-%m-%d diff --git a/neomodel/scripts/neomodel_inspect_database.py b/neomodel/scripts/neomodel_inspect_database.py index 35b99d42..f5254e53 100644 --- a/neomodel/scripts/neomodel_inspect_database.py +++ b/neomodel/scripts/neomodel_inspect_database.py @@ -17,6 +17,8 @@ If a file is specified, the tool will write the class definitions to that file. If no file is specified, the tool will print the class definitions to stdout. + + Note : this script only has a synchronous mode. options: -h, --help show this help message and exit @@ -33,7 +35,7 @@ import textwrap from os import environ -from neomodel._sync.core import db +from neomodel.sync_.core import db IMPORTS = [] diff --git a/neomodel/scripts/neomodel_install_labels.py b/neomodel/scripts/neomodel_install_labels.py index c0d9c82a..8e553396 100755 --- a/neomodel/scripts/neomodel_install_labels.py +++ b/neomodel/scripts/neomodel_install_labels.py @@ -32,7 +32,7 @@ from importlib import import_module from os import environ, path -from neomodel._sync.core import db +from neomodel.sync_.core import db def load_python_module_or_file(name): diff --git a/neomodel/scripts/neomodel_remove_labels.py b/neomodel/scripts/neomodel_remove_labels.py index 2272c7fa..14199b0b 100755 --- a/neomodel/scripts/neomodel_remove_labels.py +++ b/neomodel/scripts/neomodel_remove_labels.py @@ -27,7 +27,7 @@ from argparse import ArgumentParser, RawDescriptionHelpFormatter from os import environ -from neomodel._sync.core import db +from neomodel.sync_.core import db def main(): diff --git a/neomodel/sync_/__init__.py b/neomodel/sync_/__init__.py new file mode 100644 index 00000000..f1d519e8 --- /dev/null +++ b/neomodel/sync_/__init__.py @@ -0,0 +1 @@ +# from neomodel.async_.core import adb diff --git a/neomodel/_sync/cardinality.py b/neomodel/sync_/cardinality.py similarity index 98% rename from neomodel/_sync/cardinality.py rename to neomodel/sync_/cardinality.py index 89fe0b30..b8b7b10e 100644 --- a/neomodel/_sync/cardinality.py +++ b/neomodel/sync_/cardinality.py @@ -1,4 +1,4 @@ -from neomodel._sync.relationship_manager import ( # pylint:disable=unused-import +from neomodel.sync_.relationship_manager import ( # pylint:disable=unused-import RelationshipManager, ZeroOrMore, ) diff --git a/neomodel/_sync/core.py b/neomodel/sync_/core.py similarity index 99% rename from neomodel/_sync/core.py rename to neomodel/sync_/core.py index 56f69008..13144f42 100644 --- a/neomodel/_sync/core.py +++ b/neomodel/sync_/core.py @@ -344,7 +344,7 @@ def _object_resolution(self, object_to_resolve): return self._NODE_CLASS_REGISTRY[rel_type].inflate(object_to_resolve) if isinstance(object_to_resolve, Path): - from neomodel._sync.path import NeomodelPath + from neomodel.sync_.path import NeomodelPath return NeomodelPath(object_to_resolve) @@ -1115,7 +1115,7 @@ def nodes(cls): :return: NodeSet :rtype: NodeSet """ - from neomodel._sync.match import NodeSet + from neomodel.sync_.match import NodeSet return NodeSet(cls) @@ -1178,7 +1178,7 @@ def _build_merge_query( "No relation_type is specified on provided relationship" ) - from neomodel._sync.match import _rel_helper + from neomodel.sync_.match import _rel_helper query_params["source_id"] = relationship.source.element_id query = f"MATCH (source:{relationship.source.__label__}) WHERE {db.get_id_method()}(source) = $source_id\n " diff --git a/neomodel/_sync/match.py b/neomodel/sync_/match.py similarity index 99% rename from neomodel/_sync/match.py rename to neomodel/sync_/match.py index 3926f289..206754ec 100644 --- a/neomodel/_sync/match.py +++ b/neomodel/sync_/match.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from typing import Optional -from neomodel._sync.core import StructuredNode, db +from neomodel.sync_.core import StructuredNode, db from neomodel.exceptions import MultipleNodesReturned from neomodel.match_q import Q, QBase from neomodel.properties import AliasProperty diff --git a/neomodel/_sync/path.py b/neomodel/sync_/path.py similarity index 95% rename from neomodel/_sync/path.py rename to neomodel/sync_/path.py index 6848e903..62a49fe7 100644 --- a/neomodel/_sync/path.py +++ b/neomodel/sync_/path.py @@ -1,7 +1,7 @@ from neo4j.graph import Path -from neomodel._sync.core import db -from neomodel._sync.relationship import StructuredRel +from neomodel.sync_.core import db +from neomodel.sync_.relationship import StructuredRel class NeomodelPath(Path): diff --git a/neomodel/_sync/relationship.py b/neomodel/sync_/relationship.py similarity index 99% rename from neomodel/_sync/relationship.py rename to neomodel/sync_/relationship.py index 63096f5c..4e5a7a71 100644 --- a/neomodel/_sync/relationship.py +++ b/neomodel/sync_/relationship.py @@ -1,4 +1,4 @@ -from neomodel._sync.core import db +from neomodel.sync_.core import db from neomodel.hooks import hooks from neomodel.properties import Property, PropertyManager diff --git a/neomodel/_sync/relationship_manager.py b/neomodel/sync_/relationship_manager.py similarity index 98% rename from neomodel/_sync/relationship_manager.py rename to neomodel/sync_/relationship_manager.py index fb23b315..64512c56 100644 --- a/neomodel/_sync/relationship_manager.py +++ b/neomodel/sync_/relationship_manager.py @@ -3,9 +3,14 @@ import sys from importlib import import_module -from neomodel._sync.core import db -from neomodel._sync.match import NodeSet, Traversal, _rel_helper, _rel_merge_helper -from neomodel._sync.relationship import StructuredRel +from neomodel.sync_.core import db +from neomodel.sync_.match import ( + NodeSet, + Traversal, + _rel_helper, + _rel_merge_helper, +) +from neomodel.sync_.relationship import StructuredRel from neomodel.exceptions import NotConnected, RelationshipClassRedefined from neomodel.util import ( EITHER, diff --git a/test/_async_compat/mark_decorator.py b/test/_async_compat/mark_decorator.py index a8c5eead..77f9c5fb 100644 --- a/test/_async_compat/mark_decorator.py +++ b/test/_async_compat/mark_decorator.py @@ -1,8 +1,8 @@ import pytest -import pytest_asyncio +import pytestasync_io mark_async_test = pytest.mark.asyncio -mark_async_session_auto_fixture = pytest_asyncio.fixture(scope="session", autouse=True) +mark_async_session_auto_fixture = pytestasync_io.fixture(scope="session", autouse=True) mark_sync_session_auto_fixture = pytest.fixture(scope="session", autouse=True) diff --git a/test/async_/conftest.py b/test/async_/conftest.py index e82da39b..0819380a 100644 --- a/test/async_/conftest.py +++ b/test/async_/conftest.py @@ -1,12 +1,12 @@ import asyncio import os import warnings -from test._async_compat import mark_async_session_auto_fixture +from test.async__compat import mark_async_session_auto_fixture import pytest from neomodel import config -from neomodel._async.core import adb +from neomodel.async_.core import adb @mark_async_session_auto_fixture diff --git a/test/async_/test_alias.py b/test/async_/test_alias.py new file mode 100644 index 00000000..5218d41c --- /dev/null +++ b/test/async_/test_alias.py @@ -0,0 +1,34 @@ +from test.async__compat import mark_async_test + +from neomodel import AliasProperty, AsyncStructuredNode, StringProperty + + +class MagicProperty(AliasProperty): + def setup(self): + self.owner.setup_hook_called = True + + +class AliasTestNode(AsyncStructuredNode): + name = StringProperty(unique_index=True) + full_name = AliasProperty(to="name") + long_name = MagicProperty(to="name") + + +@mark_async_test +async def test_property_setup_hook(): + tim = await AliasTestNode(long_name="tim").save() + assert AliasTestNode.setup_hook_called + assert tim.name == "tim" + + +@mark_async_test +async def test_alias(): + jim = await AliasTestNode(full_name="Jim").save() + assert jim.name == "Jim" + assert jim.full_name == "Jim" + assert "full_name" not in AliasTestNode.deflate(jim.__properties__) + jim = await AliasTestNode.nodes.get(full_name="Jim") + assert jim + assert jim.name == "Jim" + assert jim.full_name == "Jim" + assert "full_name" not in AliasTestNode.deflate(jim.__properties__) diff --git a/test/async_/test_cypher.py b/test/async_/test_cypher.py index 31aa2f68..3c909a52 100644 --- a/test/async_/test_cypher.py +++ b/test/async_/test_cypher.py @@ -1,5 +1,5 @@ import builtins -from test._async_compat import mark_async_test +from test.async__compat import mark_async_test import pytest from neo4j.exceptions import ClientError as CypherError @@ -7,7 +7,7 @@ from pandas import DataFrame, Series from neomodel import AsyncStructuredNode, StringProperty -from neomodel._async.core import adb +from neomodel.async_.core import adb class User2(AsyncStructuredNode): diff --git a/test/sync/conftest.py b/test/sync/conftest.py index 867906b4..3b8cf640 100644 --- a/test/sync/conftest.py +++ b/test/sync/conftest.py @@ -1,12 +1,12 @@ import asyncio import os import warnings -from test._async_compat import mark_sync_session_auto_fixture +from test._compat import mark_sync_session_auto_fixture import pytest from neomodel import config -from neomodel._sync.core import db +from neomodel.sync_.core import db @mark_sync_session_auto_fixture diff --git a/test/test_alias.py b/test/sync/test_alias.py similarity index 81% rename from test/test_alias.py rename to test/sync/test_alias.py index 6f810b03..d3f72baa 100644 --- a/test/test_alias.py +++ b/test/sync/test_alias.py @@ -1,4 +1,6 @@ -from neomodel import AliasProperty, AsyncStructuredNode, StringProperty +from test._compat import mark_sync_test + +from neomodel import AliasProperty, StructuredNode, StringProperty class MagicProperty(AliasProperty): @@ -6,18 +8,20 @@ def setup(self): self.owner.setup_hook_called = True -class AliasTestNode(AsyncStructuredNode): +class AliasTestNode(StructuredNode): name = StringProperty(unique_index=True) full_name = AliasProperty(to="name") long_name = MagicProperty(to="name") +@mark_sync_test def test_property_setup_hook(): tim = AliasTestNode(long_name="tim").save() assert AliasTestNode.setup_hook_called assert tim.name == "tim" +@mark_sync_test def test_alias(): jim = AliasTestNode(full_name="Jim").save() assert jim.name == "Jim" diff --git a/test/sync/test_cypher.py b/test/sync/test_cypher.py index a76e3ba8..5fefead3 100644 --- a/test/sync/test_cypher.py +++ b/test/sync/test_cypher.py @@ -1,5 +1,5 @@ import builtins -from test._async_compat import mark_sync_test +from test._compat import mark_sync_test import pytest from neo4j.exceptions import ClientError as CypherError @@ -7,7 +7,7 @@ from pandas import DataFrame, Series from neomodel import StructuredNode, StringProperty -from neomodel._sync.core import db +from neomodel.sync_.core import db class User2(StructuredNode): diff --git a/test/test_database_management.py b/test/test_database_management.py index 9791d3a4..af1f6d2e 100644 --- a/test/test_database_management.py +++ b/test/test_database_management.py @@ -8,7 +8,7 @@ IntegerProperty, StringProperty, ) -from neomodel._async.core import adb +from neomodel.async_.core import adb class City(AsyncStructuredNode): diff --git a/test/test_dbms_awareness.py b/test/test_dbms_awareness.py index 02fee179..dc2bf01b 100644 --- a/test/test_dbms_awareness.py +++ b/test/test_dbms_awareness.py @@ -1,6 +1,6 @@ from pytest import mark -from neomodel._async.core import adb +from neomodel.async_.core import adb from neomodel.util import version_tag_to_integer diff --git a/test/test_driver_options.py b/test/test_driver_options.py index e2fba00f..12123931 100644 --- a/test/test_driver_options.py +++ b/test/test_driver_options.py @@ -2,7 +2,7 @@ from neo4j.exceptions import ClientError from pytest import raises -from neomodel._async.core import adb +from neomodel.async_.core import adb from neomodel.exceptions import FeatureNotSupported diff --git a/test/test_indexing.py b/test/test_indexing.py index 5f1df506..88311679 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -7,7 +7,7 @@ StringProperty, UniqueProperty, ) -from neomodel._async.core import adb +from neomodel.async_.core import adb from neomodel.exceptions import ConstraintValidationFailed diff --git a/test/test_label_drop.py b/test/test_label_drop.py index aadbad0c..1e8f7112 100644 --- a/test/test_label_drop.py +++ b/test/test_label_drop.py @@ -1,7 +1,7 @@ from neo4j.exceptions import ClientError from neomodel import AsyncStructuredNode, StringProperty, config -from neomodel._async.core import adb +from neomodel.async_.core import adb config.AUTO_INSTALL_LABELS = True diff --git a/test/test_label_install.py b/test/test_label_install.py index 771eddab..0061ed71 100644 --- a/test/test_label_install.py +++ b/test/test_label_install.py @@ -8,7 +8,7 @@ UniqueIdProperty, config, ) -from neomodel._async.core import adb +from neomodel.async_.core import adb from neomodel.exceptions import ConstraintValidationFailed, FeatureNotSupported config.AUTO_INSTALL_LABELS = False diff --git a/test/test_match_api.py b/test/test_match_api.py index 8cbe3e1a..618528cf 100644 --- a/test/test_match_api.py +++ b/test/test_match_api.py @@ -13,7 +13,7 @@ Q, StringProperty, ) -from neomodel._async.match import ( +from neomodel.async_.match import ( AsyncNodeSet, AsyncQueryBuilder, AsyncTraversal, diff --git a/test/test_migration_neo4j_5.py b/test/test_migration_neo4j_5.py index ee4c3ed2..c61dd312 100644 --- a/test/test_migration_neo4j_5.py +++ b/test/test_migration_neo4j_5.py @@ -7,7 +7,7 @@ IntegerProperty, StringProperty, ) -from neomodel._async.core import adb +from neomodel.async_.core import adb class Album(AsyncStructuredNode): diff --git a/test/test_models.py b/test/test_models.py index 20e2bccc..b781bb92 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -11,7 +11,7 @@ IntegerProperty, StringProperty, ) -from neomodel._async.core import adb +from neomodel.async_.core import adb from neomodel.exceptions import RequiredProperty, UniqueProperty diff --git a/test/test_relationships.py b/test/test_relationships.py index 69196eaf..fa6ff01d 100644 --- a/test/test_relationships.py +++ b/test/test_relationships.py @@ -11,7 +11,7 @@ Q, StringProperty, ) -from neomodel._async.core import adb +from neomodel.async_.core import adb class PersonWithRels(AsyncStructuredNode): diff --git a/test/test_scripts.py b/test/test_scripts.py index d054ddb5..77dee66e 100644 --- a/test/test_scripts.py +++ b/test/test_scripts.py @@ -9,7 +9,7 @@ StringProperty, config, ) -from neomodel._async.core import adb +from neomodel.async_.core import adb class ScriptsTestRel(AsyncStructuredRel): diff --git a/test/test_transactions.py b/test/test_transactions.py index 4bdec8af..83623821 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -4,7 +4,7 @@ from pytest import raises from neomodel import AsyncStructuredNode, StringProperty, UniqueProperty -from neomodel._async.core import adb +from neomodel.async_.core import adb class APerson(AsyncStructuredNode): From c2883241502e1d7dc4b34e597d46b1a7ee3f5e18 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Fri, 22 Dec 2023 16:13:47 +0100 Subject: [PATCH 16/73] Fix async_ replacements --- test/_async_compat/mark_decorator.py | 4 ++-- test/async_/conftest.py | 2 +- test/async_/test_alias.py | 2 +- test/async_/test_cypher.py | 2 +- test/sync/conftest.py | 2 +- test/sync/test_alias.py | 4 ++-- test/sync/test_cypher.py | 4 ++-- 7 files changed, 10 insertions(+), 10 deletions(-) diff --git a/test/_async_compat/mark_decorator.py b/test/_async_compat/mark_decorator.py index 77f9c5fb..a8c5eead 100644 --- a/test/_async_compat/mark_decorator.py +++ b/test/_async_compat/mark_decorator.py @@ -1,8 +1,8 @@ import pytest -import pytestasync_io +import pytest_asyncio mark_async_test = pytest.mark.asyncio -mark_async_session_auto_fixture = pytestasync_io.fixture(scope="session", autouse=True) +mark_async_session_auto_fixture = pytest_asyncio.fixture(scope="session", autouse=True) mark_sync_session_auto_fixture = pytest.fixture(scope="session", autouse=True) diff --git a/test/async_/conftest.py b/test/async_/conftest.py index 0819380a..92b65ff2 100644 --- a/test/async_/conftest.py +++ b/test/async_/conftest.py @@ -1,7 +1,7 @@ import asyncio import os import warnings -from test.async__compat import mark_async_session_auto_fixture +from test._async_compat import mark_async_session_auto_fixture import pytest diff --git a/test/async_/test_alias.py b/test/async_/test_alias.py index 5218d41c..3b0f6529 100644 --- a/test/async_/test_alias.py +++ b/test/async_/test_alias.py @@ -1,4 +1,4 @@ -from test.async__compat import mark_async_test +from test._async_compat import mark_async_test from neomodel import AliasProperty, AsyncStructuredNode, StringProperty diff --git a/test/async_/test_cypher.py b/test/async_/test_cypher.py index 3c909a52..eee735ea 100644 --- a/test/async_/test_cypher.py +++ b/test/async_/test_cypher.py @@ -1,5 +1,5 @@ import builtins -from test.async__compat import mark_async_test +from test._async_compat import mark_async_test import pytest from neo4j.exceptions import ClientError as CypherError diff --git a/test/sync/conftest.py b/test/sync/conftest.py index 3b8cf640..356d9438 100644 --- a/test/sync/conftest.py +++ b/test/sync/conftest.py @@ -1,7 +1,7 @@ import asyncio import os import warnings -from test._compat import mark_sync_session_auto_fixture +from test._async_compat import mark_sync_session_auto_fixture import pytest diff --git a/test/sync/test_alias.py b/test/sync/test_alias.py index d3f72baa..f266eb82 100644 --- a/test/sync/test_alias.py +++ b/test/sync/test_alias.py @@ -1,6 +1,6 @@ -from test._compat import mark_sync_test +from test._async_compat import mark_sync_test -from neomodel import AliasProperty, StructuredNode, StringProperty +from neomodel import AliasProperty, StringProperty, StructuredNode class MagicProperty(AliasProperty): diff --git a/test/sync/test_cypher.py b/test/sync/test_cypher.py index 5fefead3..5cd431d8 100644 --- a/test/sync/test_cypher.py +++ b/test/sync/test_cypher.py @@ -1,12 +1,12 @@ import builtins -from test._compat import mark_sync_test +from test._async_compat import mark_sync_test import pytest from neo4j.exceptions import ClientError as CypherError from numpy import ndarray from pandas import DataFrame, Series -from neomodel import StructuredNode, StringProperty +from neomodel import StringProperty, StructuredNode from neomodel.sync_.core import db From cd72533c81d112767031868b13352c7b5b0a802e Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Fri, 29 Dec 2023 09:42:46 +0100 Subject: [PATCH 17/73] Some more fixes and test --- neomodel/async_/core.py | 4 +-- neomodel/async_/match.py | 17 +++++++-- test/async_/conftest.py | 3 +- test/{ => async_}/test_batch.py | 62 +++++++++++++++++++++------------ 4 files changed, 57 insertions(+), 29 deletions(-) rename test/{ => async_}/test_batch.py (59%) diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py index 93209e7f..ab8faeda 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/core.py @@ -1295,7 +1295,7 @@ async def create_or_update(cls, *props, **kwargs): # fetch and build instance for each result results = await adb.cypher_query(query, params) - return [cls.inflate(r[0]) async for r in results[0]] + return [cls.inflate(r[0]) for r in results[0]] async def cypher(self, query, params=None): """ @@ -1364,7 +1364,7 @@ async def get_or_create(cls, *props, **kwargs): # fetch and build instance for each result results = await adb.cypher_query(query, params) - return [cls.inflate(r[0]) async for r in results[0]] + return [cls.inflate(r[0]) for r in results[0]] @classmethod def inflate(cls, node): diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 40954fe5..17f40162 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -726,6 +726,15 @@ class AsyncBaseSet: query_cls = AsyncQueryBuilder + async def all(self, lazy=False): + """ + Return all nodes belonging to the set + :param lazy: False by default, specify True to get nodes with id only without the parameters. + :return: list of nodes + :rtype: list + """ + return await self.query_cls(self).build_ast()._execute(lazy) + async def __aiter__(self): async for i in await self.query_cls(self).build_ast()._execute(): yield i @@ -734,10 +743,12 @@ async def __len__(self): return await self.query_cls(self).build_ast()._count() async def __abool__(self): - return bool(await self.query_cls(self).build_ast()._count() > 0) + _count = await self.query_cls(self).build_ast()._count() + return _count > 0 - async def __nonzero__(self): - return bool(await self.query_cls(self).build_ast()._count() > 0) + async def __anonzero__(self): + _count = await self.query_cls(self).build_ast()._count() + return _count > 0 def __contains__(self, obj): if isinstance(obj, AsyncStructuredNode): diff --git a/test/async_/conftest.py b/test/async_/conftest.py index 92b65ff2..1b9b76a9 100644 --- a/test/async_/conftest.py +++ b/test/async_/conftest.py @@ -23,7 +23,6 @@ async def setup_neo4j_session(request, event_loop): config.DATABASE_URL = os.environ.get( "NEO4J_BOLT_URL", "bolt://neo4j:foobarbaz@localhost:7687" ) - config.AUTO_INSTALL_LABELS = True # Clear the database if required database_is_populated, _ = await adb.cypher_query( @@ -36,6 +35,8 @@ async def setup_neo4j_session(request, event_loop): await adb.clear_neo4j_database(clear_constraints=True, clear_indexes=True) + await adb.install_all_labels() + await adb.cypher_query( "CREATE OR REPLACE USER troygreene SET PASSWORD 'foobarbaz' CHANGE NOT REQUIRED" ) diff --git a/test/test_batch.py b/test/async_/test_batch.py similarity index 59% rename from test/test_batch.py rename to test/async_/test_batch.py index 8280805a..b465477f 100644 --- a/test/test_batch.py +++ b/test/async_/test_batch.py @@ -1,3 +1,5 @@ +from test._async_compat import mark_async_test + from pytest import raises from neomodel import ( @@ -20,12 +22,17 @@ class UniqueUser(AsyncStructuredNode): age = IntegerProperty() -def test_unique_id_property_batch(): - users = UniqueUser.create({"name": "bob", "age": 2}, {"name": "ben", "age": 3}) +@mark_async_test +async def test_unique_id_property_batch(): + users = await UniqueUser.create( + {"name": "bob", "age": 2}, {"name": "ben", "age": 3} + ) assert users[0].uid != users[1].uid - users = UniqueUser.get_or_create({"uid": users[0].uid}, {"name": "bill", "age": 4}) + users = await UniqueUser.get_or_create( + {"uid": users[0].uid}, {"name": "bill", "age": 4} + ) assert users[0].name == "bob" assert users[1].uid @@ -36,8 +43,9 @@ class Customer(AsyncStructuredNode): age = IntegerProperty(index=True) -def test_batch_create(): - users = Customer.create( +@mark_async_test +async def test_batch_create(): + users = await Customer.create( {"email": "jim1@aol.com", "age": 11}, {"email": "jim2@aol.com", "age": 7}, {"email": "jim3@aol.com", "age": 9}, @@ -48,11 +56,12 @@ def test_batch_create(): assert users[0].age == 11 assert users[1].age == 7 assert users[1].email == "jim2@aol.com" - assert Customer.nodes.get(email="jim1@aol.com") + assert await Customer.nodes.get(email="jim1@aol.com") -def test_batch_create_or_update(): - users = Customer.create_or_update( +@mark_async_test +async def test_batch_create_or_update(): + users = await Customer.create_or_update( {"email": "merge1@aol.com", "age": 11}, {"email": "merge2@aol.com"}, {"email": "merge3@aol.com", "age": 1}, @@ -60,35 +69,39 @@ def test_batch_create_or_update(): ) assert len(users) == 4 assert users[1] == users[3] - assert Customer.nodes.get(email="merge1@aol.com").age == 11 + merge_1: Customer = await Customer.nodes.get(email="merge1@aol.com") + assert merge_1.age == 11 - more_users = Customer.create_or_update( + more_users = await Customer.create_or_update( {"email": "merge1@aol.com", "age": 22}, {"email": "merge4@aol.com", "age": None}, ) assert len(more_users) == 2 assert users[0] == more_users[0] - assert Customer.nodes.get(email="merge1@aol.com").age == 22 + merge_1 = await Customer.nodes.get(email="merge1@aol.com") + assert merge_1.age == 22 -def test_batch_validation(): +@mark_async_test +async def test_batch_validation(): # test validation in batch create with raises(DeflateError): - Customer.create( + await Customer.create( {"email": "jim1@aol.com", "age": "x"}, ) -def test_batch_index_violation(): - for u in Customer.nodes.all(): - u.delete() +@mark_async_test +async def test_batch_index_violation(): + for u in await Customer.nodes.all(): + await u.delete() - users = Customer.create( + users = await Customer.create( {"email": "jim6@aol.com", "age": 3}, ) assert users with raises(UniqueProperty): - Customer.create( + await Customer.create( {"email": "jim6@aol.com", "age": 3}, {"email": "jim7@aol.com", "age": 5}, ) @@ -107,12 +120,15 @@ class Person(AsyncStructuredNode): pets = AsyncRelationshipFrom("Dog", "owner") -def test_get_or_create_with_rel(): - bob = Person.get_or_create({"name": "Bob"})[0] - bobs_gizmo = Dog.get_or_create({"name": "Gizmo"}, relationship=bob.pets) +@mark_async_test +async def test_get_or_create_with_rel(): + create_bob = await Person.get_or_create({"name": "Bob"}) + bob = create_bob[0] + bobs_gizmo = await Dog.get_or_create({"name": "Gizmo"}, relationship=bob.pets) - tim = Person.get_or_create({"name": "Tim"})[0] - tims_gizmo = Dog.get_or_create({"name": "Gizmo"}, relationship=tim.pets) + create_tim = await Person.get_or_create({"name": "Tim"}) + tim = create_tim[0] + tims_gizmo = await Dog.get_or_create({"name": "Gizmo"}, relationship=tim.pets) # not the same gizmo assert bobs_gizmo[0] != tims_gizmo[0] From c9b877e5c3b3e9e2fd9d06eda9cdad8a4ca77001 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Fri, 29 Dec 2023 14:34:43 +0100 Subject: [PATCH 18/73] Fix donder methods --- bin/make-unasync | 9 ++- neomodel/async_/match.py | 17 +++-- neomodel/sync_/match.py | 32 +++++++-- test/async_/test_batch.py | 2 +- test/sync/conftest.py | 3 +- test/sync/test_batch.py | 134 ++++++++++++++++++++++++++++++++++++++ 6 files changed, 183 insertions(+), 14 deletions(-) create mode 100644 test/sync/test_batch.py diff --git a/bin/make-unasync b/bin/make-unasync index e56df09e..4f6a8ec2 100755 --- a/bin/make-unasync +++ b/bin/make-unasync @@ -204,9 +204,16 @@ class CustomRule(unasync.Rule): def apply_unasync(files): """Generate sync code from async code.""" - additional_main_replacements = {"adb": "db", "async_": "sync_"} + additional_main_replacements = { + "adb": "db", + "async_": "sync_", + "check_bool": "__bool__", + "check_non_zero": "__nonzero__", + } additional_test_replacements = { "async_": "sync_", + "check_bool": "__bool__", + "check_non_zero": "__nonzero__", "adb": "db", "mark_async_test": "mark_sync_test", "mark_async_session_auto_fixture": "mark_sync_session_auto_fixture", diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 17f40162..7f092577 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -742,13 +742,22 @@ async def __aiter__(self): async def __len__(self): return await self.query_cls(self).build_ast()._count() - async def __abool__(self): + async def check_bool(self): + """ + Override for __bool__ dunder method. + :return: True if the set contains any nodes, False otherwise + :rtype: bool + """ _count = await self.query_cls(self).build_ast()._count() return _count > 0 - async def __anonzero__(self): - _count = await self.query_cls(self).build_ast()._count() - return _count > 0 + async def check_non_zero(self): + """ + Override for __bool__ dunder method. + :return: True if the set contains any node, False otherwise + :rtype: bool + """ + return await self.check_bool() def __contains__(self, obj): if isinstance(obj, AsyncStructuredNode): diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 206754ec..deb92f9b 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -4,10 +4,10 @@ from dataclasses import dataclass from typing import Optional -from neomodel.sync_.core import StructuredNode, db from neomodel.exceptions import MultipleNodesReturned from neomodel.match_q import Q, QBase from neomodel.properties import AliasProperty +from neomodel.sync_.core import StructuredNode, db from neomodel.util import INCOMING, OUTGOING @@ -705,9 +705,7 @@ def _execute(self, lazy=False): for item in self._ast.additional_return ] query = self.build_query() - results, _ = db.cypher_query( - query, self._query_params, resolve_objects=True - ) + results, _ = db.cypher_query(query, self._query_params, resolve_objects=True) # The following is not as elegant as it could be but had to be copied from the # version prior to cypher_query with the resolve_objects capability. # It seems that certain calls are only supposed to be focusing to the first @@ -726,6 +724,15 @@ class BaseSet: query_cls = QueryBuilder + def all(self, lazy=False): + """ + Return all nodes belonging to the set + :param lazy: False by default, specify True to get nodes with id only without the parameters. + :return: list of nodes + :rtype: list + """ + return self.query_cls(self).build_ast()._execute(lazy) + def __iter__(self): for i in self.query_cls(self).build_ast()._execute(): yield i @@ -733,11 +740,22 @@ def __iter__(self): def __len__(self): return self.query_cls(self).build_ast()._count() - def __abool__(self): - return bool(self.query_cls(self).build_ast()._count() > 0) + def __bool__(self): + """ + Override for __bool__ dunder method. + :return: True if the set contains any nodes, False otherwise + :rtype: bool + """ + _count = self.query_cls(self).build_ast()._count() + return _count > 0 def __nonzero__(self): - return bool(self.query_cls(self).build_ast()._count() > 0) + """ + Override for __bool__ dunder method. + :return: True if the set contains any node, False otherwise + :rtype: bool + """ + return self.__bool__() def __contains__(self, obj): if isinstance(obj, StructuredNode): diff --git a/test/async_/test_batch.py b/test/async_/test_batch.py index b465477f..3b5d76d1 100644 --- a/test/async_/test_batch.py +++ b/test/async_/test_batch.py @@ -107,7 +107,7 @@ async def test_batch_index_violation(): ) # not found - assert not Customer.nodes.filter(email="jim7@aol.com") + assert not await Customer.nodes.filter(email="jim7@aol.com").check_bool() class Dog(AsyncStructuredNode): diff --git a/test/sync/conftest.py b/test/sync/conftest.py index 356d9438..0f3beb8d 100644 --- a/test/sync/conftest.py +++ b/test/sync/conftest.py @@ -23,7 +23,6 @@ def setup_neo4j_session(request, event_loop): config.DATABASE_URL = os.environ.get( "NEO4J_BOLT_URL", "bolt://neo4j:foobarbaz@localhost:7687" ) - config.AUTO_INSTALL_LABELS = True # Clear the database if required database_is_populated, _ = db.cypher_query( @@ -36,6 +35,8 @@ def setup_neo4j_session(request, event_loop): db.clear_neo4j_database(clear_constraints=True, clear_indexes=True) + db.install_all_labels() + db.cypher_query( "CREATE OR REPLACE USER troygreene SET PASSWORD 'foobarbaz' CHANGE NOT REQUIRED" ) diff --git a/test/sync/test_batch.py b/test/sync/test_batch.py new file mode 100644 index 00000000..6823d8fd --- /dev/null +++ b/test/sync/test_batch.py @@ -0,0 +1,134 @@ +from test._async_compat import mark_sync_test + +from pytest import raises + +from neomodel import ( + IntegerProperty, + RelationshipFrom, + RelationshipTo, + StringProperty, + StructuredNode, + UniqueIdProperty, + config, +) +from neomodel.exceptions import DeflateError, UniqueProperty + +config.AUTO_INSTALL_LABELS = True + + +class UniqueUser(StructuredNode): + uid = UniqueIdProperty() + name = StringProperty() + age = IntegerProperty() + + +@mark_sync_test +def test_unique_id_property_batch(): + users = UniqueUser.create( + {"name": "bob", "age": 2}, {"name": "ben", "age": 3} + ) + + assert users[0].uid != users[1].uid + + users = UniqueUser.get_or_create( + {"uid": users[0].uid}, {"name": "bill", "age": 4} + ) + + assert users[0].name == "bob" + assert users[1].uid + + +class Customer(StructuredNode): + email = StringProperty(unique_index=True, required=True) + age = IntegerProperty(index=True) + + +@mark_sync_test +def test_batch_create(): + users = Customer.create( + {"email": "jim1@aol.com", "age": 11}, + {"email": "jim2@aol.com", "age": 7}, + {"email": "jim3@aol.com", "age": 9}, + {"email": "jim4@aol.com", "age": 7}, + {"email": "jim5@aol.com", "age": 99}, + ) + assert len(users) == 5 + assert users[0].age == 11 + assert users[1].age == 7 + assert users[1].email == "jim2@aol.com" + assert Customer.nodes.get(email="jim1@aol.com") + + +@mark_sync_test +def test_batch_create_or_update(): + users = Customer.create_or_update( + {"email": "merge1@aol.com", "age": 11}, + {"email": "merge2@aol.com"}, + {"email": "merge3@aol.com", "age": 1}, + {"email": "merge2@aol.com", "age": 2}, + ) + assert len(users) == 4 + assert users[1] == users[3] + merge_1: Customer = Customer.nodes.get(email="merge1@aol.com") + assert merge_1.age == 11 + + more_users = Customer.create_or_update( + {"email": "merge1@aol.com", "age": 22}, + {"email": "merge4@aol.com", "age": None}, + ) + assert len(more_users) == 2 + assert users[0] == more_users[0] + merge_1 = Customer.nodes.get(email="merge1@aol.com") + assert merge_1.age == 22 + + +@mark_sync_test +def test_batch_validation(): + # test validation in batch create + with raises(DeflateError): + Customer.create( + {"email": "jim1@aol.com", "age": "x"}, + ) + + +@mark_sync_test +def test_batch_index_violation(): + for u in Customer.nodes.all(): + u.delete() + + users = Customer.create( + {"email": "jim6@aol.com", "age": 3}, + ) + assert users + with raises(UniqueProperty): + Customer.create( + {"email": "jim6@aol.com", "age": 3}, + {"email": "jim7@aol.com", "age": 5}, + ) + + # not found + assert not Customer.nodes.filter(email="jim7@aol.com").__bool__() + + +class Dog(StructuredNode): + name = StringProperty(required=True) + owner = RelationshipTo("Person", "owner") + + +class Person(StructuredNode): + name = StringProperty(unique_index=True) + pets = RelationshipFrom("Dog", "owner") + + +@mark_sync_test +def test_get_or_create_with_rel(): + create_bob = Person.get_or_create({"name": "Bob"}) + bob = create_bob[0] + bobs_gizmo = Dog.get_or_create({"name": "Gizmo"}, relationship=bob.pets) + + create_tim = Person.get_or_create({"name": "Tim"}) + tim = create_tim[0] + tims_gizmo = Dog.get_or_create({"name": "Gizmo"}, relationship=tim.pets) + + # not the same gizmo + assert bobs_gizmo[0] != tims_gizmo[0] From d4ea57989b0befdb96d5211856af34cd6ff23160 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Fri, 29 Dec 2023 16:08:34 +0100 Subject: [PATCH 19/73] Extract PropertyManager for async reasons --- doc/source/module_documentation_async.rst | 6 + doc/source/module_documentation_sync.rst | 6 + neomodel/__init__.py | 1 + neomodel/async_/core.py | 5 +- neomodel/async_/property_manager.py | 109 ++++++++++++++ neomodel/async_/relationship.py | 5 +- neomodel/contrib/spatial_properties.py | 170 ++++++++++++++-------- neomodel/match_q.py | 6 +- neomodel/properties.py | 108 +------------- neomodel/sync_/cardinality.py | 2 +- neomodel/sync_/core.py | 11 +- neomodel/sync_/property_manager.py | 109 ++++++++++++++ neomodel/sync_/relationship.py | 5 +- neomodel/sync_/relationship_manager.py | 13 +- 14 files changed, 360 insertions(+), 196 deletions(-) create mode 100644 neomodel/async_/property_manager.py create mode 100644 neomodel/sync_/property_manager.py diff --git a/doc/source/module_documentation_async.rst b/doc/source/module_documentation_async.rst index 09ce3e2b..2150b235 100644 --- a/doc/source/module_documentation_async.rst +++ b/doc/source/module_documentation_async.rst @@ -27,6 +27,12 @@ Relationships :members: :show-inheritance: +Property Manager +================ +.. automodule:: neomodel.async_.property_manager + :members: + :show-inheritance: + Paths ===== diff --git a/doc/source/module_documentation_sync.rst b/doc/source/module_documentation_sync.rst index e39485f2..9eb642fe 100644 --- a/doc/source/module_documentation_sync.rst +++ b/doc/source/module_documentation_sync.rst @@ -27,6 +27,12 @@ Relationships :members: :show-inheritance: +Property Manager +================ +.. automodule:: neomodel.sync_.property_manager + :members: + :show-inheritance: + Paths ===== diff --git a/neomodel/__init__.py b/neomodel/__init__.py index 4da0c3eb..399fb0f9 100644 --- a/neomodel/__init__.py +++ b/neomodel/__init__.py @@ -48,6 +48,7 @@ from neomodel.sync_.core import StructuredNode from neomodel.sync_.match import NodeSet, Traversal from neomodel.sync_.path import NeomodelPath +from neomodel.sync_.property_manager import PropertyManager from neomodel.sync_.relationship import StructuredRel from neomodel.sync_.relationship_manager import ( Relationship, diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py index ab8faeda..c99ce260 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/core.py @@ -22,6 +22,7 @@ from neo4j.graph import Node, Path, Relationship from neomodel import config +from neomodel.async_.property_manager import AsyncPropertyManager from neomodel.exceptions import ( ConstraintValidationFailed, DoesNotExist, @@ -32,7 +33,7 @@ UniqueProperty, ) from neomodel.hooks import hooks -from neomodel.properties import Property, PropertyManager +from neomodel.properties import Property from neomodel.util import ( _get_node_properties, _UnsavedNode, @@ -1064,7 +1065,7 @@ def build_class_registry(cls): raise NodeClassAlreadyDefined(cls, adb._NODE_CLASS_REGISTRY) -NodeBase = NodeMeta("NodeBase", (PropertyManager,), {"__abstract_node__": True}) +NodeBase = NodeMeta("NodeBase", (AsyncPropertyManager,), {"__abstract_node__": True}) class AsyncStructuredNode(NodeBase): diff --git a/neomodel/async_/property_manager.py b/neomodel/async_/property_manager.py new file mode 100644 index 00000000..b9401dab --- /dev/null +++ b/neomodel/async_/property_manager.py @@ -0,0 +1,109 @@ +import types + +from neomodel.exceptions import RequiredProperty +from neomodel.properties import AliasProperty, Property + + +def display_for(key): + def display_choice(self): + return getattr(self.__class__, key).choices[getattr(self, key)] + + return display_choice + + +class AsyncPropertyManager: + """ + Common methods for handling properties on node and relationship objects. + """ + + def __init__(self, **kwargs): + properties = getattr(self, "__all_properties__", None) + if properties is None: + properties = self.defined_properties(rels=False, aliases=False).items() + for name, property in properties: + if kwargs.get(name) is None: + if getattr(property, "has_default", False): + setattr(self, name, property.default_value()) + else: + setattr(self, name, None) + else: + setattr(self, name, kwargs[name]) + + if getattr(property, "choices", None): + setattr( + self, + f"get_{name}_display", + types.MethodType(display_for(name), self), + ) + + if name in kwargs: + del kwargs[name] + + aliases = getattr(self, "__all_aliases__", None) + if aliases is None: + aliases = self.defined_properties( + aliases=True, rels=False, properties=False + ).items() + for name, property in aliases: + if name in kwargs: + setattr(self, name, kwargs[name]) + del kwargs[name] + + # undefined properties (for magic @prop.setters etc) + for name, property in kwargs.items(): + setattr(self, name, property) + + @property + def __properties__(self): + from neomodel.async_.relationship_manager import AsyncRelationshipManager + + return dict( + (name, value) + for name, value in vars(self).items() + if not name.startswith("_") + and not callable(value) + and not isinstance( + value, + ( + AsyncRelationshipManager, + AliasProperty, + ), + ) + ) + + @classmethod + def deflate(cls, properties, obj=None, skip_empty=False): + # deflate dict ready to be stored + deflated = {} + for name, property in cls.defined_properties(aliases=False, rels=False).items(): + db_property = property.db_property or name + if properties.get(name) is not None: + deflated[db_property] = property.deflate(properties[name], obj) + elif property.has_default: + deflated[db_property] = property.deflate(property.default_value(), obj) + elif property.required: + raise RequiredProperty(name, cls) + elif not skip_empty: + deflated[db_property] = None + return deflated + + @classmethod + def defined_properties(cls, aliases=True, properties=True, rels=True): + from neomodel.async_.relationship_manager import AsyncRelationshipDefinition + + props = {} + for baseclass in reversed(cls.__mro__): + props.update( + dict( + (name, property) + for name, property in vars(baseclass).items() + if (aliases and isinstance(property, AliasProperty)) + or ( + properties + and isinstance(property, Property) + and not isinstance(property, AliasProperty) + ) + or (rels and isinstance(property, AsyncRelationshipDefinition)) + ) + ) + return props diff --git a/neomodel/async_/relationship.py b/neomodel/async_/relationship.py index 355003d4..eab91249 100644 --- a/neomodel/async_/relationship.py +++ b/neomodel/async_/relationship.py @@ -1,6 +1,7 @@ from neomodel.async_.core import adb +from neomodel.async_.property_manager import AsyncPropertyManager from neomodel.hooks import hooks -from neomodel.properties import Property, PropertyManager +from neomodel.properties import Property class RelationshipMeta(type): @@ -35,7 +36,7 @@ def __new__(mcs, name, bases, dct): return inst -StructuredRelBase = RelationshipMeta("RelationshipBase", (PropertyManager,), {}) +StructuredRelBase = RelationshipMeta("RelationshipBase", (AsyncPropertyManager,), {}) class AsyncStructuredRel(StructuredRelBase): diff --git a/neomodel/contrib/spatial_properties.py b/neomodel/contrib/spatial_properties.py index 6b6606d5..982e6921 100644 --- a/neomodel/contrib/spatial_properties.py +++ b/neomodel/contrib/spatial_properties.py @@ -27,7 +27,7 @@ try: from shapely import __version__ as shapely_version from shapely.coords import CoordinateSequence - from shapely.geometry import Point as ShapelyPoint + from shapely.geometry import Point as ShapelyPoint except ImportError as exc: raise ImportError( "NEOMODEL ERROR: Shapely not found. If required, you can install Shapely via " @@ -53,10 +53,11 @@ # Taking into account the Shapely 2.0.0 changes in the way POINT objects are initialisd. if int("".join(shapely_version.split(".")[0:3])) < 200: + class NeomodelPoint(ShapelyPoint): """ Abstracts the Point spatial data type of Neo4j. - + Note: At the time of writing, Neo4j supports 2 main variants of Point: 1. A generic point defined over a Cartesian plane @@ -65,12 +66,12 @@ class NeomodelPoint(ShapelyPoint): * The minimum data to define a point is longitude, latitude [,Height] and the crs is then assumed to be "wgs-84". """ - + # def __init__(self, *args, crs=None, x=None, y=None, z=None, latitude=None, longitude=None, height=None, **kwargs): def __init__(self, *args, **kwargs): """ Creates a NeomodelPoint. - + :param args: Positional arguments to emulate the behaviour of Shapely's Point (and specifically the copy constructor) :type args: list @@ -91,7 +92,7 @@ def __init__(self, *args, **kwargs): :param kwargs: Dictionary of keyword arguments :type kwargs: dict """ - + # Python2.7 Workaround for the order that the arguments get passed to the functions crs = kwargs.pop("crs", None) x = kwargs.pop("x", None) @@ -100,14 +101,16 @@ def __init__(self, *args, **kwargs): longitude = kwargs.pop("longitude", None) latitude = kwargs.pop("latitude", None) height = kwargs.pop("height", None) - + _x, _y, _z = None, None, None - + # CRS validity check is common to both types of constructors that follow if crs is not None and crs not in ACCEPTABLE_CRS: - raise ValueError(f"Invalid CRS({crs}). Expected one of {','.join(ACCEPTABLE_CRS)}") + raise ValueError( + f"Invalid CRS({crs}). Expected one of {','.join(ACCEPTABLE_CRS)}" + ) self._crs = crs - + # If positional arguments have been supplied, then this is a possible call to the copy constructor or # initialisation by a coordinate iterable as per ShapelyPoint constructor. if len(args) > 0: @@ -115,7 +118,9 @@ def __init__(self, *args, **kwargs): if isinstance(args[0], (tuple, list)): # Check dimensionality of tuple if len(args[0]) < 2 or len(args[0]) > 3: - raise ValueError(f"Invalid vector dimensions. Expected 2 or 3, received {len(args[0])}") + raise ValueError( + f"Invalid vector dimensions. Expected 2 or 3, received {len(args[0])}" + ) x = args[0][0] y = args[0][1] if len(args[0]) == 3: @@ -143,11 +148,15 @@ def __init__(self, *args, **kwargs): if crs is None: self._crs = "cartesian-3d" else: - raise ValueError(f"Invalid vector dimensions. Expected 2 or 3, received {len(args[0].coords[0])}") + raise ValueError( + f"Invalid vector dimensions. Expected 2 or 3, received {len(args[0].coords[0])}" + ) return else: - raise TypeError(f"Invalid object passed to copy constructor. Expected NeomodelPoint or shapely Point, received {type(args[0])}") - + raise TypeError( + f"Invalid object passed to copy constructor. Expected NeomodelPoint or shapely Point, received {type(args[0])}" + ) + # Initialisation is either via x,y[,z] XOR longitude,latitude[,height]. Specifying both leads to an error. if any(i is not None for i in [x, y, z]) and any( i is not None for i in [latitude, longitude, height] @@ -157,14 +166,14 @@ def __init__(self, *args, **kwargs): "A Point can be defined either by x,y,z coordinates OR latitude,longitude,height but not " "a combination of these terms" ) - + # Specifying no initialisation argument at this point in the constructor is flagged as an error if all(i is None for i in [x, y, z, latitude, longitude, height]): raise ValueError( "Invalid instantiation via no arguments. " "A Point needs default values either in x,y,z or longitude, latitude, height coordinates" ) - + # Geographical Point Initialisation if latitude is not None and longitude is not None: if height is not None: @@ -176,7 +185,7 @@ def __init__(self, *args, **kwargs): self._crs = "wgs-84" _x = longitude _y = latitude - + # Geometrical Point Initialisation if x is not None and y is not None: if z is not None: @@ -188,22 +197,26 @@ def __init__(self, *args, **kwargs): self._crs = "cartesian" _x = x _y = y - + if _z is None: if "-3d" not in self._crs: super().__init__((float(_x), float(_y)), **kwargs) else: - raise ValueError(f"Invalid vector dimensions(2) for given CRS({self._crs}).") + raise ValueError( + f"Invalid vector dimensions(2) for given CRS({self._crs})." + ) else: if "-3d" in self._crs: super().__init__((float(_x), float(_y), float(_z)), **kwargs) else: - raise ValueError(f"Invalid vector dimensions(3) for given CRS({self._crs}).") - + raise ValueError( + f"Invalid vector dimensions(3) for given CRS({self._crs})." + ) + @property def crs(self): return self._crs - + @property def x(self): if not self._crs.startswith("cartesian"): @@ -211,7 +224,7 @@ def x(self): f'Invalid coordinate ("x") for points defined over {self.crs}' ) return super().x - + @property def y(self): if not self._crs.startswith("cartesian"): @@ -219,7 +232,7 @@ def y(self): f'Invalid coordinate ("y") for points defined over {self.crs}' ) return super().y - + @property def z(self): if self._crs != "cartesian-3d": @@ -227,7 +240,7 @@ def z(self): f'Invalid coordinate ("z") for points defined over {self.crs}' ) return super().z - + @property def latitude(self): if not self._crs.startswith("wgs-84"): @@ -235,7 +248,7 @@ def latitude(self): f'Invalid coordinate ("latitude") for points defined over {self.crs}' ) return super().y - + @property def longitude(self): if not self._crs.startswith("wgs-84"): @@ -243,7 +256,7 @@ def longitude(self): f'Invalid coordinate ("longitude") for points defined over {self.crs}' ) return super().x - + @property def height(self): if self._crs != "wgs-84-3d": @@ -251,21 +264,22 @@ def height(self): f'Invalid coordinate ("height") for points defined over {self.crs}' ) return super().z - + # The following operations are necessary here due to the way queries (and more importantly their parameters) get # combined and evaluated in neomodel. Specifically, query expressions get duplicated with deep copies and any valid # datatype values should also implement these operations. def __copy__(self): return NeomodelPoint(self) - + def __deepcopy__(self, memo): return NeomodelPoint(self) else: + class NeomodelPoint: """ Abstracts the Point spatial data type of Neo4j. - + Note: At the time of writing, Neo4j supports 2 main variants of Point: 1. A generic point defined over a Cartesian plane @@ -274,12 +288,12 @@ class NeomodelPoint: * The minimum data to define a point is longitude, latitude [,Height] and the crs is then assumed to be "wgs-84". """ - + # def __init__(self, *args, crs=None, x=None, y=None, z=None, latitude=None, longitude=None, height=None, **kwargs): def __init__(self, *args, **kwargs): """ Creates a NeomodelPoint. - + :param args: Positional arguments to emulate the behaviour of Shapely's Point (and specifically the copy constructor) :type args: list @@ -300,7 +314,7 @@ def __init__(self, *args, **kwargs): :param kwargs: Dictionary of keyword arguments :type kwargs: dict """ - + # Python2.7 Workaround for the order that the arguments get passed to the functions crs = kwargs.pop("crs", None) x = kwargs.pop("x", None) @@ -309,14 +323,16 @@ def __init__(self, *args, **kwargs): longitude = kwargs.pop("longitude", None) latitude = kwargs.pop("latitude", None) height = kwargs.pop("height", None) - + _x, _y, _z = None, None, None - + # CRS validity check is common to both types of constructors that follow if crs is not None and crs not in ACCEPTABLE_CRS: - raise ValueError(f"Invalid CRS({crs}). Expected one of {','.join(ACCEPTABLE_CRS)}") + raise ValueError( + f"Invalid CRS({crs}). Expected one of {','.join(ACCEPTABLE_CRS)}" + ) self._crs = crs - + # If positional arguments have been supplied, then this is a possible call to the copy constructor or # initialisation by a coordinate iterable as per ShapelyPoint constructor. if len(args) > 0: @@ -324,7 +340,9 @@ def __init__(self, *args, **kwargs): if isinstance(args[0], (tuple, list)): # Check dimensionality of tuple if len(args[0]) < 2 or len(args[0]) > 3: - raise ValueError(f"Invalid vector dimensions. Expected 2 or 3, received {len(args[0])}") + raise ValueError( + f"Invalid vector dimensions. Expected 2 or 3, received {len(args[0])}" + ) x = args[0][0] y = args[0][1] if len(args[0]) == 3: @@ -354,11 +372,15 @@ def __init__(self, *args, **kwargs): if crs is None: self._crs = "cartesian-3d" else: - raise ValueError(f"Invalid vector dimensions. Expected 2 or 3, received {len(args[0].coords[0])}") + raise ValueError( + f"Invalid vector dimensions. Expected 2 or 3, received {len(args[0].coords[0])}" + ) return else: - raise TypeError(f"Invalid object passed to copy constructor. Expected NeomodelPoint or shapely Point, received {type(args[0])}") - + raise TypeError( + f"Invalid object passed to copy constructor. Expected NeomodelPoint or shapely Point, received {type(args[0])}" + ) + # Initialisation is either via x,y[,z] XOR longitude,latitude[,height]. Specifying both leads to an error. if any(i is not None for i in [x, y, z]) and any( i is not None for i in [latitude, longitude, height] @@ -368,14 +390,14 @@ def __init__(self, *args, **kwargs): "A Point can be defined either by x,y,z coordinates OR latitude,longitude,height but not " "a combination of these terms" ) - + # Specifying no initialisation argument at this point in the constructor is flagged as an error if all(i is None for i in [x, y, z, latitude, longitude, height]): raise ValueError( "Invalid instantiation via no arguments. " "A Point needs default values either in x,y,z or longitude, latitude, height coordinates" ) - + # Geographical Point Initialisation if latitude is not None and longitude is not None: if height is not None: @@ -387,7 +409,7 @@ def __init__(self, *args, **kwargs): self._crs = "wgs-84" _x = longitude _y = latitude - + # Geometrical Point Initialisation if x is not None and y is not None: if z is not None: @@ -399,23 +421,28 @@ def __init__(self, *args, **kwargs): self._crs = "cartesian" _x = x _y = y - + if _z is None: if "-3d" not in self._crs: self._shapely_point = ShapelyPoint((float(_x), float(_y))) else: - raise ValueError(f"Invalid vector dimensions(2) for given CRS({self._crs}).") + raise ValueError( + f"Invalid vector dimensions(2) for given CRS({self._crs})." + ) else: if "-3d" in self._crs: - self._shapely_point = ShapelyPoint((float(_x), float(_y), float(_z))) + self._shapely_point = ShapelyPoint( + (float(_x), float(_y), float(_z)) + ) else: - raise ValueError(f"Invalid vector dimensions(3) for given CRS({self._crs}).") + raise ValueError( + f"Invalid vector dimensions(3) for given CRS({self._crs})." + ) - @property def crs(self): return self._crs - + @property def x(self): if not self._crs.startswith("cartesian"): @@ -423,7 +450,7 @@ def x(self): f'Invalid coordinate ("x") for points defined over {self.crs}' ) return self._shapely_point.x - + @property def y(self): if not self._crs.startswith("cartesian"): @@ -431,7 +458,7 @@ def y(self): f'Invalid coordinate ("y") for points defined over {self.crs}' ) return self._shapely_point.y - + @property def z(self): if self._crs != "cartesian-3d": @@ -439,7 +466,7 @@ def z(self): f'Invalid coordinate ("z") for points defined over {self.crs}' ) return self._shapely_point.z - + @property def latitude(self): if not self._crs.startswith("wgs-84"): @@ -447,7 +474,7 @@ def latitude(self): f'Invalid coordinate ("latitude") for points defined over {self.crs}' ) return self._shapely_point.y - + @property def longitude(self): if not self._crs.startswith("wgs-84"): @@ -455,7 +482,7 @@ def longitude(self): f'Invalid coordinate ("longitude") for points defined over {self.crs}' ) return self._shapely_point.x - + @property def height(self): if self._crs != "wgs-84-3d": @@ -463,13 +490,13 @@ def height(self): f'Invalid coordinate ("height") for points defined over {self.crs}' ) return self._shapely_point.z - + # The following operations are necessary here due to the way queries (and more importantly their parameters) get # combined and evaluated in neomodel. Specifically, query expressions get duplicated with deep copies and any valid # datatype values should also implement these operations. def __copy__(self): return NeomodelPoint(self) - + def __deepcopy__(self, memo): return NeomodelPoint(self) @@ -484,7 +511,9 @@ def __eq__(self, other): Compare objects by value """ if not isinstance(other, (ShapelyPoint, NeomodelPoint)): - raise ValueException(f"NeomodelPoint equality comparison expected NeomodelPoint or Shapely Point, received {type(other)}") + raise ValueException( + f"NeomodelPoint equality comparison expected NeomodelPoint or Shapely Point, received {type(other)}" + ) else: if isinstance(other, ShapelyPoint): return self.coords[0] == other.coords[0] @@ -517,12 +546,19 @@ def __init__(self, *args, **kwargs): crs = None if crs is None or (crs not in ACCEPTABLE_CRS): - raise ValueError(f"Invalid CRS({crs}). Point properties require CRS to be one of {','.join(ACCEPTABLE_CRS)}") + raise ValueError( + f"Invalid CRS({crs}). Point properties require CRS to be one of {','.join(ACCEPTABLE_CRS)}" + ) # If a default value is passed and it is not a callable, then make sure it is in the right type - if "default" in kwargs and not hasattr(kwargs["default"], "__call__") and not isinstance(kwargs["default"], NeomodelPoint): - raise TypeError(f"Invalid default value. Expected NeomodelPoint, received {type(kwargs['default'])}" - ) + if ( + "default" in kwargs + and not hasattr(kwargs["default"], "__call__") + and not isinstance(kwargs["default"], NeomodelPoint) + ): + raise TypeError( + f"Invalid default value. Expected NeomodelPoint, received {type(kwargs['default'])}" + ) super().__init__(*args, **kwargs) self._crs = crs @@ -544,10 +580,14 @@ def inflate(self, value): try: value_point_crs = SRID_TO_CRS[value.srid] except KeyError as e: - raise ValueError(f"Invalid SRID to inflate. Expected one of {SRID_TO_CRS.keys()}, received {value.srid}") from e + raise ValueError( + f"Invalid SRID to inflate. Expected one of {SRID_TO_CRS.keys()}, received {value.srid}" + ) from e if self._crs != value_point_crs: - raise ValueError(f"Invalid CRS. Expected POINT defined over {self._crs}, received {value_point_crs}") + raise ValueError( + f"Invalid CRS. Expected POINT defined over {self._crs}, received {value_point_crs}" + ) # cartesian if value.srid == 7203: return NeomodelPoint(x=value.x, y=value.y) @@ -581,7 +621,9 @@ def deflate(self, value): ) if value.crs != self._crs: - raise ValueError(f"Invalid CRS. Expected NeomodelPoint defined over {self._crs}, received NeomodelPoint defined over {value.crs}") + raise ValueError( + f"Invalid CRS. Expected NeomodelPoint defined over {self._crs}, received NeomodelPoint defined over {value.crs}" + ) if value.crs == "cartesian-3d": return neo4j.spatial.CartesianPoint((value.x, value.y, value.z)) diff --git a/neomodel/match_q.py b/neomodel/match_q.py index 7e76a23b..4e45588c 100644 --- a/neomodel/match_q.py +++ b/neomodel/match_q.py @@ -69,7 +69,11 @@ def _new_instance(cls, children=None, connector=None, negated=False): return obj def __str__(self): - return f"(NOT ({self.connector}: {', '.join(str(c) for c in self.children)}))" if self.negated else f"({self.connector}: {', '.join(str(c) for c in self.children)})" + return ( + f"(NOT ({self.connector}: {', '.join(str(c) for c in self.children)}))" + if self.negated + else f"({self.connector}: {', '.join(str(c) for c in self.children)})" + ) def __repr__(self): return f"<{self.__class__.__name__}: {self}>" diff --git a/neomodel/properties.py b/neomodel/properties.py index 51b59449..cac92dfe 100644 --- a/neomodel/properties.py +++ b/neomodel/properties.py @@ -2,7 +2,6 @@ import json import re import sys -import types import uuid from datetime import date, datetime @@ -10,117 +9,12 @@ import pytz from neomodel import config -from neomodel.exceptions import DeflateError, InflateError, RequiredProperty +from neomodel.exceptions import DeflateError, InflateError if sys.version_info >= (3, 0): Unicode = str -def display_for(key): - def display_choice(self): - return getattr(self.__class__, key).choices[getattr(self, key)] - - return display_choice - - -class PropertyManager: - """ - Common methods for handling properties on node and relationship objects. - """ - - def __init__(self, **kwargs): - properties = getattr(self, "__all_properties__", None) - if properties is None: - properties = self.defined_properties(rels=False, aliases=False).items() - for name, property in properties: - if kwargs.get(name) is None: - if getattr(property, "has_default", False): - setattr(self, name, property.default_value()) - else: - setattr(self, name, None) - else: - setattr(self, name, kwargs[name]) - - if getattr(property, "choices", None): - setattr( - self, - f"get_{name}_display", - types.MethodType(display_for(name), self), - ) - - if name in kwargs: - del kwargs[name] - - aliases = getattr(self, "__all_aliases__", None) - if aliases is None: - aliases = self.defined_properties( - aliases=True, rels=False, properties=False - ).items() - for name, property in aliases: - if name in kwargs: - setattr(self, name, kwargs[name]) - del kwargs[name] - - # undefined properties (for magic @prop.setters etc) - for name, property in kwargs.items(): - setattr(self, name, property) - - @property - def __properties__(self): - from neomodel.async_.relationship_manager import AsyncRelationshipManager - - return dict( - (name, value) - for name, value in vars(self).items() - if not name.startswith("_") - and not callable(value) - and not isinstance( - value, - ( - AsyncRelationshipManager, - AliasProperty, - ), - ) - ) - - @classmethod - def deflate(cls, properties, obj=None, skip_empty=False): - # deflate dict ready to be stored - deflated = {} - for name, property in cls.defined_properties(aliases=False, rels=False).items(): - db_property = property.db_property or name - if properties.get(name) is not None: - deflated[db_property] = property.deflate(properties[name], obj) - elif property.has_default: - deflated[db_property] = property.deflate(property.default_value(), obj) - elif property.required: - raise RequiredProperty(name, cls) - elif not skip_empty: - deflated[db_property] = None - return deflated - - @classmethod - def defined_properties(cls, aliases=True, properties=True, rels=True): - from neomodel.async_.relationship_manager import AsyncRelationshipDefinition - - props = {} - for baseclass in reversed(cls.__mro__): - props.update( - dict( - (name, property) - for name, property in vars(baseclass).items() - if (aliases and isinstance(property, AliasProperty)) - or ( - properties - and isinstance(property, Property) - and not isinstance(property, AliasProperty) - ) - or (rels and isinstance(property, AsyncRelationshipDefinition)) - ) - ) - return props - - def validator(fn): fn_name = fn.func_name if hasattr(fn, "func_name") else fn.__name__ if fn_name == "inflate": diff --git a/neomodel/sync_/cardinality.py b/neomodel/sync_/cardinality.py index b8b7b10e..4757eba5 100644 --- a/neomodel/sync_/cardinality.py +++ b/neomodel/sync_/cardinality.py @@ -1,8 +1,8 @@ +from neomodel.exceptions import AttemptedCardinalityViolation, CardinalityViolation from neomodel.sync_.relationship_manager import ( # pylint:disable=unused-import RelationshipManager, ZeroOrMore, ) -from neomodel.exceptions import AttemptedCardinalityViolation, CardinalityViolation class ZeroOrOne(RelationshipManager): diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index 13144f42..20697612 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -32,7 +32,8 @@ UniqueProperty, ) from neomodel.hooks import hooks -from neomodel.properties import Property, PropertyManager +from neomodel.properties import Property +from neomodel.sync_.property_manager import PropertyManager from neomodel.util import ( _get_node_properties, _UnsavedNode, @@ -253,9 +254,7 @@ def begin(self, access_mode=None, **parameters): impersonated_user=self.impersonated_user, **parameters, ) - self._active_transaction: Transaction = ( - self._session.begin_transaction() - ) + self._active_transaction: Transaction = self._session.begin_transaction() @ensure_connection def commit(self): @@ -835,9 +834,7 @@ def change_neo4j_password(db: Database, user, new_password): db.change_neo4j_password(user, new_password) -def clear_neo4j_database( - db: Database, clear_constraints=False, clear_indexes=False -): +def clear_neo4j_database(db: Database, clear_constraints=False, clear_indexes=False): deprecated( """ This method has been moved to the Database singleton (db for sync, db for async). diff --git a/neomodel/sync_/property_manager.py b/neomodel/sync_/property_manager.py new file mode 100644 index 00000000..85452f0b --- /dev/null +++ b/neomodel/sync_/property_manager.py @@ -0,0 +1,109 @@ +import types + +from neomodel.exceptions import RequiredProperty +from neomodel.properties import AliasProperty, Property + + +def display_for(key): + def display_choice(self): + return getattr(self.__class__, key).choices[getattr(self, key)] + + return display_choice + + +class PropertyManager: + """ + Common methods for handling properties on node and relationship objects. + """ + + def __init__(self, **kwargs): + properties = getattr(self, "__all_properties__", None) + if properties is None: + properties = self.defined_properties(rels=False, aliases=False).items() + for name, property in properties: + if kwargs.get(name) is None: + if getattr(property, "has_default", False): + setattr(self, name, property.default_value()) + else: + setattr(self, name, None) + else: + setattr(self, name, kwargs[name]) + + if getattr(property, "choices", None): + setattr( + self, + f"get_{name}_display", + types.MethodType(display_for(name), self), + ) + + if name in kwargs: + del kwargs[name] + + aliases = getattr(self, "__all_aliases__", None) + if aliases is None: + aliases = self.defined_properties( + aliases=True, rels=False, properties=False + ).items() + for name, property in aliases: + if name in kwargs: + setattr(self, name, kwargs[name]) + del kwargs[name] + + # undefined properties (for magic @prop.setters etc) + for name, property in kwargs.items(): + setattr(self, name, property) + + @property + def __properties__(self): + from neomodel.sync_.relationship_manager import RelationshipManager + + return dict( + (name, value) + for name, value in vars(self).items() + if not name.startswith("_") + and not callable(value) + and not isinstance( + value, + ( + RelationshipManager, + AliasProperty, + ), + ) + ) + + @classmethod + def deflate(cls, properties, obj=None, skip_empty=False): + # deflate dict ready to be stored + deflated = {} + for name, property in cls.defined_properties(aliases=False, rels=False).items(): + db_property = property.db_property or name + if properties.get(name) is not None: + deflated[db_property] = property.deflate(properties[name], obj) + elif property.has_default: + deflated[db_property] = property.deflate(property.default_value(), obj) + elif property.required: + raise RequiredProperty(name, cls) + elif not skip_empty: + deflated[db_property] = None + return deflated + + @classmethod + def defined_properties(cls, aliases=True, properties=True, rels=True): + from neomodel.sync_.relationship_manager import RelationshipDefinition + + props = {} + for baseclass in reversed(cls.__mro__): + props.update( + dict( + (name, property) + for name, property in vars(baseclass).items() + if (aliases and isinstance(property, AliasProperty)) + or ( + properties + and isinstance(property, Property) + and not isinstance(property, AliasProperty) + ) + or (rels and isinstance(property, RelationshipDefinition)) + ) + ) + return props diff --git a/neomodel/sync_/relationship.py b/neomodel/sync_/relationship.py index 4e5a7a71..c246df38 100644 --- a/neomodel/sync_/relationship.py +++ b/neomodel/sync_/relationship.py @@ -1,6 +1,7 @@ -from neomodel.sync_.core import db from neomodel.hooks import hooks -from neomodel.properties import Property, PropertyManager +from neomodel.properties import Property +from neomodel.sync_.core import db +from neomodel.sync_.property_manager import PropertyManager class RelationshipMeta(type): diff --git a/neomodel/sync_/relationship_manager.py b/neomodel/sync_/relationship_manager.py index 64512c56..cd43a5cc 100644 --- a/neomodel/sync_/relationship_manager.py +++ b/neomodel/sync_/relationship_manager.py @@ -3,15 +3,10 @@ import sys from importlib import import_module +from neomodel.exceptions import NotConnected, RelationshipClassRedefined from neomodel.sync_.core import db -from neomodel.sync_.match import ( - NodeSet, - Traversal, - _rel_helper, - _rel_merge_helper, -) +from neomodel.sync_.match import NodeSet, Traversal, _rel_helper, _rel_merge_helper from neomodel.sync_.relationship import StructuredRel -from neomodel.exceptions import NotConnected, RelationshipClassRedefined from neomodel.util import ( EITHER, INCOMING, @@ -252,9 +247,7 @@ def reconnect(self, old_node, new_node): q += "".join([f" SET r2.{prop} = r.{prop}" for prop in existing_properties]) q += " WITH r DELETE r" - self.source.cypher( - q, {"old": old_node.element_id, "new": new_node.element_id} - ) + self.source.cypher(q, {"old": old_node.element_id, "new": new_node.element_id}) @check_source def disconnect(self, node): From 8f90b4b65de23d69950ff5dfd12c6f6e2a58a23b Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Fri, 29 Dec 2023 16:47:35 +0100 Subject: [PATCH 20/73] Fix cardinality and test --- bin/make-unasync | 4 +- neomodel/async_/cardinality.py | 30 ++-- neomodel/async_/match.py | 5 +- neomodel/async_/relationship_manager.py | 16 +- neomodel/contrib/sync_/semi_structured.py | 2 +- neomodel/sync_/cardinality.py | 6 +- neomodel/sync_/core.py | 2 +- neomodel/sync_/match.py | 5 +- neomodel/sync_/relationship.py | 4 +- neomodel/sync_/relationship_manager.py | 9 +- test/async_/test_cardinality.py | 187 ++++++++++++++++++++++ test/sync/test_alias.py | 2 +- test/sync/test_batch.py | 4 +- test/{ => sync}/test_cardinality.py | 71 ++++---- test/sync/test_cypher.py | 2 +- 15 files changed, 279 insertions(+), 70 deletions(-) create mode 100644 test/async_/test_cardinality.py rename test/{ => sync}/test_cardinality.py (69%) diff --git a/bin/make-unasync b/bin/make-unasync index 4f6a8ec2..7b796bfa 100755 --- a/bin/make-unasync +++ b/bin/make-unasync @@ -208,12 +208,12 @@ def apply_unasync(files): "adb": "db", "async_": "sync_", "check_bool": "__bool__", - "check_non_zero": "__nonzero__", + "check_nonzero": "__nonzero__", } additional_test_replacements = { "async_": "sync_", "check_bool": "__bool__", - "check_non_zero": "__nonzero__", + "check_nonzero": "__nonzero__", "adb": "db", "mark_async_test": "mark_sync_test", "mark_async_session_auto_fixture": "mark_sync_session_auto_fixture", diff --git a/neomodel/async_/cardinality.py b/neomodel/async_/cardinality.py index 0c3b02cf..ff1eb779 100644 --- a/neomodel/async_/cardinality.py +++ b/neomodel/async_/cardinality.py @@ -10,21 +10,21 @@ class AsyncZeroOrOne(AsyncRelationshipManager): description = "zero or one relationship" - def single(self): + async def single(self): """ Return the associated node. :return: node """ - nodes = super().all() + nodes = await super().all() if len(nodes) == 1: return nodes[0] if len(nodes) > 1: raise CardinalityViolation(self, len(nodes)) return None - def all(self): - node = self.single() + async def all(self): + node = await self.single() return [node] if node else [] async def connect(self, node, properties=None): @@ -37,7 +37,7 @@ async def connect(self, node, properties=None): :type: dict :return: True / rel instance """ - if len(self): + if await super().__len__(): raise AttemptedCardinalityViolation( f"Node already has {self} can't connect more" ) @@ -49,24 +49,24 @@ class AsyncOneOrMore(AsyncRelationshipManager): description = "one or more relationships" - def single(self): + async def single(self): """ Fetch one of the related nodes :return: Node """ - nodes = super().all() + nodes = await super().all() if nodes: return nodes[0] raise CardinalityViolation(self, "none") - def all(self): + async def all(self): """ Returns all related nodes. :return: [node1, node2...] """ - nodes = super().all() + nodes = await super().all() if nodes: return nodes raise CardinalityViolation(self, "none") @@ -77,7 +77,7 @@ async def disconnect(self, node): :param node: :return: """ - if super().__len__() < 2: + if await super().__len__() < 2: raise AttemptedCardinalityViolation("One or more expected") return await super().disconnect(node) @@ -89,26 +89,26 @@ class AsyncOne(AsyncRelationshipManager): description = "one relationship" - def single(self): + async def single(self): """ Return the associated node. :return: node """ - nodes = super().all() + nodes = await super().all() if nodes: if len(nodes) == 1: return nodes[0] raise CardinalityViolation(self, len(nodes)) raise CardinalityViolation(self, "none") - def all(self): + async def all(self): """ Return single node in an array :return: [node] """ - return [self.single()] + return [await self.single()] async def disconnect(self, node): raise AttemptedCardinalityViolation( @@ -130,6 +130,6 @@ async def connect(self, node, properties=None): """ if not hasattr(self.source, "element_id") or self.source.element_id is None: raise ValueError("Node has not been saved cannot connect!") - if len(self): + if await super().__len__(): raise AttemptedCardinalityViolation("Node already has one relationship") return await super().connect(node, properties) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 7f092577..2c8388ac 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -751,7 +751,7 @@ async def check_bool(self): _count = await self.query_cls(self).build_ast()._count() return _count > 0 - async def check_non_zero(self): + async def check_nonzero(self): """ Override for __bool__ dunder method. :return: True if the set contains any node, False otherwise @@ -783,7 +783,8 @@ async def __getitem__(self, key): self.skip = key self.limit = 1 - return await self.query_cls(self).build_ast()._execute()[0] + _items = await self.query_cls(self).build_ast()._execute() + return _items[0] return None diff --git a/neomodel/async_/relationship_manager.py b/neomodel/async_/relationship_manager.py index 2a1b95e2..7e4e421a 100644 --- a/neomodel/async_/relationship_manager.py +++ b/neomodel/async_/relationship_manager.py @@ -343,14 +343,14 @@ def is_connected(self, node): """ return self._new_traversal().__contains__(node) - def single(self): + async def single(self): """ Get a single related node or none. :return: StructuredNode """ try: - return self[0] + return await self[0] except IndexError: pass @@ -363,25 +363,25 @@ def match(self, **kwargs): """ return self._new_traversal().match(**kwargs) - def all(self): + async def all(self): """ Return all related nodes. :return: list """ - return self._new_traversal().all() + return await self._new_traversal().all() - def __iter__(self): - return self._new_traversal().__iter__() + async def __aiter__(self): + return self._new_traversal().__aiter__() def __len__(self): return self._new_traversal().__len__() def __bool__(self): - return self._new_traversal().__bool__() + return self._new_traversal().check_bool() def __nonzero__(self): - return self._new_traversal().__nonzero__() + return self._new_traversal().check_nonzero() def __contains__(self, obj): return self._new_traversal().__contains__(obj) diff --git a/neomodel/contrib/sync_/semi_structured.py b/neomodel/contrib/sync_/semi_structured.py index 86a5a140..dcf7873a 100644 --- a/neomodel/contrib/sync_/semi_structured.py +++ b/neomodel/contrib/sync_/semi_structured.py @@ -1,5 +1,5 @@ -from neomodel.exceptions import DeflateConflict, InflateConflict from neomodel.sync_.core import StructuredNode +from neomodel.exceptions import DeflateConflict, InflateConflict from neomodel.util import _get_node_properties diff --git a/neomodel/sync_/cardinality.py b/neomodel/sync_/cardinality.py index 4757eba5..bf00f78d 100644 --- a/neomodel/sync_/cardinality.py +++ b/neomodel/sync_/cardinality.py @@ -1,8 +1,8 @@ -from neomodel.exceptions import AttemptedCardinalityViolation, CardinalityViolation from neomodel.sync_.relationship_manager import ( # pylint:disable=unused-import RelationshipManager, ZeroOrMore, ) +from neomodel.exceptions import AttemptedCardinalityViolation, CardinalityViolation class ZeroOrOne(RelationshipManager): @@ -37,7 +37,7 @@ def connect(self, node, properties=None): :type: dict :return: True / rel instance """ - if len(self): + if super().__len__(): raise AttemptedCardinalityViolation( f"Node already has {self} can't connect more" ) @@ -130,6 +130,6 @@ def connect(self, node, properties=None): """ if not hasattr(self.source, "element_id") or self.source.element_id is None: raise ValueError("Node has not been saved cannot connect!") - if len(self): + if super().__len__(): raise AttemptedCardinalityViolation("Node already has one relationship") return super().connect(node, properties) diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index 20697612..8e17bdde 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -22,6 +22,7 @@ from neo4j.graph import Node, Path, Relationship from neomodel import config +from neomodel.sync_.property_manager import PropertyManager from neomodel.exceptions import ( ConstraintValidationFailed, DoesNotExist, @@ -33,7 +34,6 @@ ) from neomodel.hooks import hooks from neomodel.properties import Property -from neomodel.sync_.property_manager import PropertyManager from neomodel.util import ( _get_node_properties, _UnsavedNode, diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index deb92f9b..0a9c74b6 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -4,10 +4,10 @@ from dataclasses import dataclass from typing import Optional +from neomodel.sync_.core import StructuredNode, db from neomodel.exceptions import MultipleNodesReturned from neomodel.match_q import Q, QBase from neomodel.properties import AliasProperty -from neomodel.sync_.core import StructuredNode, db from neomodel.util import INCOMING, OUTGOING @@ -781,7 +781,8 @@ def __getitem__(self, key): self.skip = key self.limit = 1 - return self.query_cls(self).build_ast()._execute()[0] + _items = self.query_cls(self).build_ast()._execute() + return _items[0] return None diff --git a/neomodel/sync_/relationship.py b/neomodel/sync_/relationship.py index c246df38..1845b534 100644 --- a/neomodel/sync_/relationship.py +++ b/neomodel/sync_/relationship.py @@ -1,7 +1,7 @@ -from neomodel.hooks import hooks -from neomodel.properties import Property from neomodel.sync_.core import db from neomodel.sync_.property_manager import PropertyManager +from neomodel.hooks import hooks +from neomodel.properties import Property class RelationshipMeta(type): diff --git a/neomodel/sync_/relationship_manager.py b/neomodel/sync_/relationship_manager.py index cd43a5cc..887b7b9d 100644 --- a/neomodel/sync_/relationship_manager.py +++ b/neomodel/sync_/relationship_manager.py @@ -3,10 +3,15 @@ import sys from importlib import import_module -from neomodel.exceptions import NotConnected, RelationshipClassRedefined from neomodel.sync_.core import db -from neomodel.sync_.match import NodeSet, Traversal, _rel_helper, _rel_merge_helper +from neomodel.sync_.match import ( + NodeSet, + Traversal, + _rel_helper, + _rel_merge_helper, +) from neomodel.sync_.relationship import StructuredRel +from neomodel.exceptions import NotConnected, RelationshipClassRedefined from neomodel.util import ( EITHER, INCOMING, diff --git a/test/async_/test_cardinality.py b/test/async_/test_cardinality.py new file mode 100644 index 00000000..bafe919d --- /dev/null +++ b/test/async_/test_cardinality.py @@ -0,0 +1,187 @@ +from test._async_compat import mark_async_test +from pytest import raises + +from neomodel import ( + AsyncOne, + AsyncOneOrMore, + AsyncRelationshipTo, + AsyncStructuredNode, + AsyncZeroOrOne, + AttemptedCardinalityViolation, + CardinalityViolation, + IntegerProperty, + StringProperty, + AsyncZeroOrMore, +) + +from neomodel.async_.core import adb + + +class HairDryer(AsyncStructuredNode): + version = IntegerProperty() + + +class ScrewDriver(AsyncStructuredNode): + version = IntegerProperty() + + +class Car(AsyncStructuredNode): + version = IntegerProperty() + + +class Monkey(AsyncStructuredNode): + name = StringProperty() + dryers = AsyncRelationshipTo("HairDryer", "OWNS_DRYER", cardinality=AsyncZeroOrMore) + driver = AsyncRelationshipTo( + "ScrewDriver", "HAS_SCREWDRIVER", cardinality=AsyncZeroOrOne + ) + car = AsyncRelationshipTo("Car", "HAS_CAR", cardinality=AsyncOneOrMore) + toothbrush = AsyncRelationshipTo( + "ToothBrush", "HAS_TOOTHBRUSH", cardinality=AsyncOne + ) + + +class ToothBrush(AsyncStructuredNode): + name = StringProperty() + + +@mark_async_test +async def test_cardinality_zero_or_more(): + m = await Monkey(name="tim").save() + assert await m.dryers.all() == [] + single_dryer = await m.driver.single() + assert single_dryer is None + h = await HairDryer(version=1).save() + + await m.dryers.connect(h) + assert len(await m.dryers.all()) == 1 + single_dryer = await m.dryers.single() + assert single_dryer.version == 1 + + await m.dryers.disconnect(h) + assert await m.dryers.all() == [] + single_dryer = await m.driver.single() + assert single_dryer is None + + h2 = await HairDryer(version=2).save() + await m.dryers.connect(h) + await m.dryers.connect(h2) + await m.dryers.disconnect_all() + assert await m.dryers.all() == [] + single_dryer = await m.driver.single() + assert single_dryer is None + + +@mark_async_test +async def test_cardinality_zero_or_one(): + m = await Monkey(name="bob").save() + assert await m.driver.all() == [] + single_driver = await m.driver.single() + assert await m.driver.single() is None + h = await ScrewDriver(version=1).save() + + await m.driver.connect(h) + assert len(await m.driver.all()) == 1 + single_driver = await m.driver.single() + assert single_driver.version == 1 + + j = await ScrewDriver(version=2).save() + with raises(AttemptedCardinalityViolation): + await m.driver.connect(j) + + await m.driver.reconnect(h, j) + single_driver = await m.driver.single() + assert single_driver.version == 2 + + # Forcing creation of a second ToothBrush to go around + # AttemptedCardinalityViolation + await adb.cypher_query( + """ + MATCH (m:Monkey WHERE m.name="bob") + CREATE (s:ScrewDriver {version:3}) + WITH m, s + CREATE (m)-[:HAS_SCREWDRIVER]->(s) + """ + ) + with raises( + CardinalityViolation, match=r"CardinalityViolation: Expected: .*, got: 2." + ): + await m.driver.all() + + +@mark_async_test +async def test_cardinality_one_or_more(): + m = await Monkey(name="jerry").save() + + with raises(CardinalityViolation): + await m.car.all() + + with raises(CardinalityViolation): + await m.car.single() + + c = await Car(version=2).save() + await m.car.connect(c) + single_car = await m.car.single() + assert single_car.version == 2 + + cars = await m.car.all() + assert len(cars) == 1 + + with raises(AttemptedCardinalityViolation): + await m.car.disconnect(c) + + d = await Car(version=3).save() + await m.car.connect(d) + cars = await m.car.all() + assert len(cars) == 2 + + await m.car.disconnect(d) + cars = await m.car.all() + assert len(cars) == 1 + + +@mark_async_test +async def test_cardinality_one(): + m = await Monkey(name="jerry").save() + + with raises( + CardinalityViolation, match=r"CardinalityViolation: Expected: .*, got: none." + ): + await m.toothbrush.all() + + with raises(CardinalityViolation): + await m.toothbrush.single() + + b = await ToothBrush(name="Jim").save() + await m.toothbrush.connect(b) + single_toothbrush = await m.toothbrush.single() + assert single_toothbrush.name == "Jim" + + x = await ToothBrush(name="Jim").save() + with raises(AttemptedCardinalityViolation): + await m.toothbrush.connect(x) + + with raises(AttemptedCardinalityViolation): + await m.toothbrush.disconnect(b) + + with raises(AttemptedCardinalityViolation): + await m.toothbrush.disconnect_all() + + # Forcing creation of a second ToothBrush to go around + # AttemptedCardinalityViolation + await adb.cypher_query( + """ + MATCH (m:Monkey WHERE m.name="jerry") + CREATE (t:ToothBrush {name:"Jim"}) + WITH m, t + CREATE (m)-[:HAS_TOOTHBRUSH]->(t) + """ + ) + with raises( + CardinalityViolation, match=r"CardinalityViolation: Expected: .*, got: 2." + ): + await m.toothbrush.all() + + jp = Monkey(name="Jean-Pierre") + with raises(ValueError, match="Node has not been saved cannot connect!"): + await jp.toothbrush.connect(b) diff --git a/test/sync/test_alias.py b/test/sync/test_alias.py index f266eb82..c0d084a2 100644 --- a/test/sync/test_alias.py +++ b/test/sync/test_alias.py @@ -1,6 +1,6 @@ from test._async_compat import mark_sync_test -from neomodel import AliasProperty, StringProperty, StructuredNode +from neomodel import AliasProperty, StructuredNode, StringProperty class MagicProperty(AliasProperty): diff --git a/test/sync/test_batch.py b/test/sync/test_batch.py index 6823d8fd..8d5586e9 100644 --- a/test/sync/test_batch.py +++ b/test/sync/test_batch.py @@ -3,11 +3,11 @@ from pytest import raises from neomodel import ( - IntegerProperty, RelationshipFrom, RelationshipTo, - StringProperty, StructuredNode, + IntegerProperty, + StringProperty, UniqueIdProperty, config, ) diff --git a/test/test_cardinality.py b/test/sync/test_cardinality.py similarity index 69% rename from test/test_cardinality.py rename to test/sync/test_cardinality.py index 6be9226f..f3c27360 100644 --- a/test/test_cardinality.py +++ b/test/sync/test_cardinality.py @@ -1,90 +1,101 @@ +from test._async_compat import mark_sync_test from pytest import raises from neomodel import ( - AsyncOne, - AsyncOneOrMore, - AsyncRelationshipTo, - AsyncStructuredNode, - AsyncZeroOrOne, + One, + OneOrMore, + RelationshipTo, + StructuredNode, + ZeroOrOne, AttemptedCardinalityViolation, CardinalityViolation, IntegerProperty, StringProperty, ZeroOrMore, - adb, ) +from neomodel.sync_.core import db -class HairDryer(AsyncStructuredNode): + +class HairDryer(StructuredNode): version = IntegerProperty() -class ScrewDriver(AsyncStructuredNode): +class ScrewDriver(StructuredNode): version = IntegerProperty() -class Car(AsyncStructuredNode): +class Car(StructuredNode): version = IntegerProperty() -class Monkey(AsyncStructuredNode): +class Monkey(StructuredNode): name = StringProperty() - dryers = AsyncRelationshipTo("HairDryer", "OWNS_DRYER", cardinality=ZeroOrMore) - driver = AsyncRelationshipTo( - "ScrewDriver", "HAS_SCREWDRIVER", cardinality=AsyncZeroOrOne + dryers = RelationshipTo("HairDryer", "OWNS_DRYER", cardinality=ZeroOrMore) + driver = RelationshipTo( + "ScrewDriver", "HAS_SCREWDRIVER", cardinality=ZeroOrOne ) - car = AsyncRelationshipTo("Car", "HAS_CAR", cardinality=AsyncOneOrMore) - toothbrush = AsyncRelationshipTo( - "ToothBrush", "HAS_TOOTHBRUSH", cardinality=AsyncOne + car = RelationshipTo("Car", "HAS_CAR", cardinality=OneOrMore) + toothbrush = RelationshipTo( + "ToothBrush", "HAS_TOOTHBRUSH", cardinality=One ) -class ToothBrush(AsyncStructuredNode): +class ToothBrush(StructuredNode): name = StringProperty() +@mark_sync_test def test_cardinality_zero_or_more(): m = Monkey(name="tim").save() assert m.dryers.all() == [] - assert m.dryers.single() is None + single_dryer = m.driver.single() + assert single_dryer is None h = HairDryer(version=1).save() m.dryers.connect(h) assert len(m.dryers.all()) == 1 - assert m.dryers.single().version == 1 + single_dryer = m.dryers.single() + assert single_dryer.version == 1 m.dryers.disconnect(h) assert m.dryers.all() == [] - assert m.dryers.single() is None + single_dryer = m.driver.single() + assert single_dryer is None h2 = HairDryer(version=2).save() m.dryers.connect(h) m.dryers.connect(h2) m.dryers.disconnect_all() assert m.dryers.all() == [] - assert m.dryers.single() is None + single_dryer = m.driver.single() + assert single_dryer is None +@mark_sync_test def test_cardinality_zero_or_one(): m = Monkey(name="bob").save() assert m.driver.all() == [] + single_driver = m.driver.single() assert m.driver.single() is None h = ScrewDriver(version=1).save() m.driver.connect(h) assert len(m.driver.all()) == 1 - assert m.driver.single().version == 1 + single_driver = m.driver.single() + assert single_driver.version == 1 j = ScrewDriver(version=2).save() with raises(AttemptedCardinalityViolation): m.driver.connect(j) m.driver.reconnect(h, j) - assert m.driver.single().version == 2 + single_driver = m.driver.single() + assert single_driver.version == 2 # Forcing creation of a second ToothBrush to go around # AttemptedCardinalityViolation - adb.cypher_query( + db.cypher_query( """ MATCH (m:Monkey WHERE m.name="bob") CREATE (s:ScrewDriver {version:3}) @@ -98,6 +109,7 @@ def test_cardinality_zero_or_one(): m.driver.all() +@mark_sync_test def test_cardinality_one_or_more(): m = Monkey(name="jerry").save() @@ -109,7 +121,8 @@ def test_cardinality_one_or_more(): c = Car(version=2).save() m.car.connect(c) - assert m.car.single().version == 2 + single_car = m.car.single() + assert single_car.version == 2 cars = m.car.all() assert len(cars) == 1 @@ -127,6 +140,7 @@ def test_cardinality_one_or_more(): assert len(cars) == 1 +@mark_sync_test def test_cardinality_one(): m = Monkey(name="jerry").save() @@ -140,9 +154,10 @@ def test_cardinality_one(): b = ToothBrush(name="Jim").save() m.toothbrush.connect(b) - assert m.toothbrush.single().name == "Jim" + single_toothbrush = m.toothbrush.single() + assert single_toothbrush.name == "Jim" - x = ToothBrush(name="Jim").save + x = ToothBrush(name="Jim").save() with raises(AttemptedCardinalityViolation): m.toothbrush.connect(x) @@ -154,7 +169,7 @@ def test_cardinality_one(): # Forcing creation of a second ToothBrush to go around # AttemptedCardinalityViolation - adb.cypher_query( + db.cypher_query( """ MATCH (m:Monkey WHERE m.name="jerry") CREATE (t:ToothBrush {name:"Jim"}) diff --git a/test/sync/test_cypher.py b/test/sync/test_cypher.py index 5cd431d8..d49eeae6 100644 --- a/test/sync/test_cypher.py +++ b/test/sync/test_cypher.py @@ -6,7 +6,7 @@ from numpy import ndarray from pandas import DataFrame, Series -from neomodel import StringProperty, StructuredNode +from neomodel import StructuredNode, StringProperty from neomodel.sync_.core import db From 7559516da1fd3873921763fcc372b0796b361349 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Fri, 29 Dec 2023 17:04:50 +0100 Subject: [PATCH 21/73] More tests --- test/async_/test_connection.py | 151 +++++++++++++++++++++++++++++ test/{ => sync}/test_connection.py | 66 ++++++++----- 2 files changed, 191 insertions(+), 26 deletions(-) create mode 100644 test/async_/test_connection.py rename test/{ => sync}/test_connection.py (66%) diff --git a/test/async_/test_connection.py b/test/async_/test_connection.py new file mode 100644 index 00000000..b82ca0f4 --- /dev/null +++ b/test/async_/test_connection.py @@ -0,0 +1,151 @@ +import os + +from test._async_compat import mark_async_test +from test.conftest import NEO4J_PASSWORD, NEO4J_URL, NEO4J_USERNAME +import pytest +from neo4j import AsyncGraphDatabase, AsyncDriver +from neo4j.debug import watch + +from neomodel import AsyncStructuredNode, StringProperty, config +from neomodel.async_.core import adb + + +@mark_async_test +@pytest.fixture(autouse=True) +async def setup_teardown(): + yield + # Teardown actions after tests have run + # Reconnect to initial URL for potential subsequent tests + await adb.close_connection() + await adb.set_connection(url=config.DATABASE_URL) + + +@pytest.fixture(autouse=True, scope="session") +def neo4j_logging(): + with watch("neo4j"): + yield + + +@mark_async_test +async def get_current_database_name() -> str: + """ + Fetches the name of the currently active database from the Neo4j database. + + Returns: + - str: The name of the current database. + """ + results, meta = await adb.cypher_query("CALL db.info") + results_as_dict = [dict(zip(meta, row)) for row in results] + + return results_as_dict[0]["name"] + + +class Pastry(AsyncStructuredNode): + name = StringProperty(unique_index=True) + + +@mark_async_test +async def test_set_connection_driver_works(): + # Verify that current connection is up + assert await Pastry(name="Chocolatine").save() + await adb.close_connection() + + # Test connection using a driver + await adb.set_connection( + driver=AsyncGraphDatabase().driver( + NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD) + ) + ) + assert await Pastry(name="Croissant").save() + + +@mark_async_test +async def test_config_driver_works(): + # Verify that current connection is up + assert await Pastry(name="Chausson aux pommes").save() + await adb.close_connection() + + # Test connection using a driver defined in config + driver: AsyncDriver = AsyncGraphDatabase().driver( + NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD) + ) + + config.DRIVER = driver + assert await Pastry(name="Grignette").save() + + # Clear config + # No need to close connection - pytest teardown will do it + config.DRIVER = None + + +@mark_async_test +@pytest.mark.skipif( + adb.database_edition != "enterprise", + reason="Skipping test for community edition - no multi database in CE", +) +async def test_connect_to_non_default_database(): + database_name = "pastries" + await adb.cypher_query(f"CREATE DATABASE {database_name} IF NOT EXISTS") + await adb.close_connection() + + # Set database name in url - for url init only + await adb.set_connection(url=f"{config.DATABASE_URL}/{database_name}") + assert await get_current_database_name() == "pastries" + + await adb.close_connection() + + # Set database name in config - for both url and driver init + config.DATABASE_NAME = database_name + + # url init + await adb.set_connection(url=config.DATABASE_URL) + assert await get_current_database_name() == "pastries" + + await adb.close_connection() + + # driver init + await adb.set_connection( + driver=AsyncGraphDatabase().driver( + NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD) + ) + ) + assert await get_current_database_name() == "pastries" + + # Clear config + # No need to close connection - pytest teardown will do it + config.DATABASE_NAME = None + + +@mark_async_test +@pytest.mark.parametrize( + "url", ["bolt://user:password", "http://user:password@localhost:7687"] +) +async def test_wrong_url_format(url): + with pytest.raises( + ValueError, + match=rf"Expecting url format: bolt://user:password@localhost:7687 got {url}", + ): + await adb.set_connection(url=url) + + +@mark_async_test +@pytest.mark.parametrize("protocol", ["neo4j+s", "neo4j+ssc", "bolt+s", "bolt+ssc"]) +async def test_connect_to_aura(protocol): + cypher_return = "hello world" + default_cypher_query = f"RETURN '{cypher_return}'" + await adb.close_connection() + + await _set_connection(protocol=protocol) + result, _ = await adb.cypher_query(default_cypher_query) + + assert len(result) > 0 + assert result[0][0] == cypher_return + + +async def _set_connection(protocol): + AURA_TEST_DB_USER = os.environ["AURA_TEST_DB_USER"] + AURA_TEST_DB_PASSWORD = os.environ["AURA_TEST_DB_PASSWORD"] + AURA_TEST_DB_HOSTNAME = os.environ["AURA_TEST_DB_HOSTNAME"] + + database_url = f"{protocol}://{AURA_TEST_DB_USER}:{AURA_TEST_DB_PASSWORD}@{AURA_TEST_DB_HOSTNAME}" + await adb.set_connection(url=database_url) diff --git a/test/test_connection.py b/test/sync/test_connection.py similarity index 66% rename from test/test_connection.py rename to test/sync/test_connection.py index 4bb0f091..4f4e57c7 100644 --- a/test/test_connection.py +++ b/test/sync/test_connection.py @@ -1,21 +1,23 @@ import os -import time +from test._async_compat import mark_sync_test +from test.conftest import NEO4J_PASSWORD, NEO4J_URL, NEO4J_USERNAME import pytest -from neo4j import GraphDatabase +from neo4j import GraphDatabase, Driver from neo4j.debug import watch -from neomodel import AsyncStructuredNode, StringProperty, adb, config -from neomodel.conftest import NEO4J_PASSWORD, NEO4J_URL, NEO4J_USERNAME +from neomodel import StructuredNode, StringProperty, config +from neomodel.sync_.core import db +@mark_sync_test @pytest.fixture(autouse=True) def setup_teardown(): yield # Teardown actions after tests have run # Reconnect to initial URL for potential subsequent tests - adb.close_connection() - adb.set_connection(url=config.DATABASE_URL) + db.close_connection() + db.set_connection(url=config.DATABASE_URL) @pytest.fixture(autouse=True, scope="session") @@ -24,6 +26,7 @@ def neo4j_logging(): yield +@mark_sync_test def get_current_database_name() -> str: """ Fetches the name of the currently active database from the Neo4j database. @@ -31,35 +34,41 @@ def get_current_database_name() -> str: Returns: - str: The name of the current database. """ - results, meta = adb.cypher_query("CALL db.info") + results, meta = db.cypher_query("CALL db.info") results_as_dict = [dict(zip(meta, row)) for row in results] return results_as_dict[0]["name"] -class Pastry(AsyncStructuredNode): +class Pastry(StructuredNode): name = StringProperty(unique_index=True) +@mark_sync_test def test_set_connection_driver_works(): # Verify that current connection is up assert Pastry(name="Chocolatine").save() - adb.close_connection() + db.close_connection() # Test connection using a driver - adb.set_connection( - driver=GraphDatabase().driver(NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) + db.set_connection( + driver=GraphDatabase().driver( + NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD) + ) ) assert Pastry(name="Croissant").save() +@mark_sync_test def test_config_driver_works(): # Verify that current connection is up assert Pastry(name="Chausson aux pommes").save() - adb.close_connection() + db.close_connection() # Test connection using a driver defined in config - driver = GraphDatabase().driver(NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) + driver: Driver = GraphDatabase().driver( + NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD) + ) config.DRIVER = driver assert Pastry(name="Grignette").save() @@ -69,33 +78,36 @@ def test_config_driver_works(): config.DRIVER = None +@mark_sync_test @pytest.mark.skipif( - adb.database_edition != "enterprise", + db.database_edition != "enterprise", reason="Skipping test for community edition - no multi database in CE", ) def test_connect_to_non_default_database(): database_name = "pastries" - adb.cypher_query(f"CREATE DATABASE {database_name} IF NOT EXISTS") - adb.close_connection() + db.cypher_query(f"CREATE DATABASE {database_name} IF NOT EXISTS") + db.close_connection() # Set database name in url - for url init only - adb.set_connection(url=f"{config.DATABASE_URL}/{database_name}") + db.set_connection(url=f"{config.DATABASE_URL}/{database_name}") assert get_current_database_name() == "pastries" - adb.close_connection() + db.close_connection() # Set database name in config - for both url and driver init config.DATABASE_NAME = database_name # url init - adb.set_connection(url=config.DATABASE_URL) + db.set_connection(url=config.DATABASE_URL) assert get_current_database_name() == "pastries" - adb.close_connection() + db.close_connection() # driver init - adb.set_connection( - driver=GraphDatabase().driver(NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) + db.set_connection( + driver=GraphDatabase().driver( + NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD) + ) ) assert get_current_database_name() == "pastries" @@ -104,6 +116,7 @@ def test_connect_to_non_default_database(): config.DATABASE_NAME = None +@mark_sync_test @pytest.mark.parametrize( "url", ["bolt://user:password", "http://user:password@localhost:7687"] ) @@ -112,17 +125,18 @@ def test_wrong_url_format(url): ValueError, match=rf"Expecting url format: bolt://user:password@localhost:7687 got {url}", ): - adb.set_connection(url=url) + db.set_connection(url=url) +@mark_sync_test @pytest.mark.parametrize("protocol", ["neo4j+s", "neo4j+ssc", "bolt+s", "bolt+ssc"]) def test_connect_to_aura(protocol): cypher_return = "hello world" default_cypher_query = f"RETURN '{cypher_return}'" - adb.close_connection() + db.close_connection() _set_connection(protocol=protocol) - result, _ = adb.cypher_query(default_cypher_query) + result, _ = db.cypher_query(default_cypher_query) assert len(result) > 0 assert result[0][0] == cypher_return @@ -134,4 +148,4 @@ def _set_connection(protocol): AURA_TEST_DB_HOSTNAME = os.environ["AURA_TEST_DB_HOSTNAME"] database_url = f"{protocol}://{AURA_TEST_DB_USER}:{AURA_TEST_DB_PASSWORD}@{AURA_TEST_DB_HOSTNAME}" - adb.set_connection(url=database_url) + db.set_connection(url=database_url) From 9fd375aee45efe59959446aa1636dfa126417d6f Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Tue, 2 Jan 2024 09:59:18 +0100 Subject: [PATCH 22/73] More tests --- neomodel/async_/relationship_manager.py | 3 +- neomodel/contrib/sync_/semi_structured.py | 2 +- neomodel/sync_/cardinality.py | 2 +- neomodel/sync_/core.py | 2 +- neomodel/sync_/match.py | 2 +- neomodel/sync_/relationship.py | 4 +- neomodel/sync_/relationship_manager.py | 12 ++-- test/async_/test_database_management.py | 79 +++++++++++++++++++++ test/{ => async_}/test_dbms_awareness.py | 4 +- test/{ => sync}/test_database_management.py | 53 +++++++------- test/sync/test_dbms_awareness.py | 32 +++++++++ test/test_driver_options.py | 50 ------------- 12 files changed, 153 insertions(+), 92 deletions(-) create mode 100644 test/async_/test_database_management.py rename test/{ => async_}/test_dbms_awareness.py (92%) rename test/{ => sync}/test_database_management.py (50%) create mode 100644 test/sync/test_dbms_awareness.py delete mode 100644 test/test_driver_options.py diff --git a/neomodel/async_/relationship_manager.py b/neomodel/async_/relationship_manager.py index 7e4e421a..35d378ba 100644 --- a/neomodel/async_/relationship_manager.py +++ b/neomodel/async_/relationship_manager.py @@ -133,7 +133,8 @@ async def connect(self, node, properties=None): await self.source.cypher(q, params) return True - rel_ = await self.source.cypher(q + " RETURN r", params)[0][0][0] + results = await self.source.cypher(q + " RETURN r", params) + rel_ = results[0][0][0] rel_instance = self._set_start_end_cls(rel_model.inflate(rel_), node) if hasattr(rel_instance, "post_save"): diff --git a/neomodel/contrib/sync_/semi_structured.py b/neomodel/contrib/sync_/semi_structured.py index dcf7873a..86a5a140 100644 --- a/neomodel/contrib/sync_/semi_structured.py +++ b/neomodel/contrib/sync_/semi_structured.py @@ -1,5 +1,5 @@ -from neomodel.sync_.core import StructuredNode from neomodel.exceptions import DeflateConflict, InflateConflict +from neomodel.sync_.core import StructuredNode from neomodel.util import _get_node_properties diff --git a/neomodel/sync_/cardinality.py b/neomodel/sync_/cardinality.py index bf00f78d..716d173f 100644 --- a/neomodel/sync_/cardinality.py +++ b/neomodel/sync_/cardinality.py @@ -1,8 +1,8 @@ +from neomodel.exceptions import AttemptedCardinalityViolation, CardinalityViolation from neomodel.sync_.relationship_manager import ( # pylint:disable=unused-import RelationshipManager, ZeroOrMore, ) -from neomodel.exceptions import AttemptedCardinalityViolation, CardinalityViolation class ZeroOrOne(RelationshipManager): diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index 8e17bdde..20697612 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -22,7 +22,6 @@ from neo4j.graph import Node, Path, Relationship from neomodel import config -from neomodel.sync_.property_manager import PropertyManager from neomodel.exceptions import ( ConstraintValidationFailed, DoesNotExist, @@ -34,6 +33,7 @@ ) from neomodel.hooks import hooks from neomodel.properties import Property +from neomodel.sync_.property_manager import PropertyManager from neomodel.util import ( _get_node_properties, _UnsavedNode, diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 0a9c74b6..0609aa05 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -4,10 +4,10 @@ from dataclasses import dataclass from typing import Optional -from neomodel.sync_.core import StructuredNode, db from neomodel.exceptions import MultipleNodesReturned from neomodel.match_q import Q, QBase from neomodel.properties import AliasProperty +from neomodel.sync_.core import StructuredNode, db from neomodel.util import INCOMING, OUTGOING diff --git a/neomodel/sync_/relationship.py b/neomodel/sync_/relationship.py index 1845b534..c246df38 100644 --- a/neomodel/sync_/relationship.py +++ b/neomodel/sync_/relationship.py @@ -1,7 +1,7 @@ -from neomodel.sync_.core import db -from neomodel.sync_.property_manager import PropertyManager from neomodel.hooks import hooks from neomodel.properties import Property +from neomodel.sync_.core import db +from neomodel.sync_.property_manager import PropertyManager class RelationshipMeta(type): diff --git a/neomodel/sync_/relationship_manager.py b/neomodel/sync_/relationship_manager.py index 887b7b9d..54dcb2be 100644 --- a/neomodel/sync_/relationship_manager.py +++ b/neomodel/sync_/relationship_manager.py @@ -3,15 +3,10 @@ import sys from importlib import import_module +from neomodel.exceptions import NotConnected, RelationshipClassRedefined from neomodel.sync_.core import db -from neomodel.sync_.match import ( - NodeSet, - Traversal, - _rel_helper, - _rel_merge_helper, -) +from neomodel.sync_.match import NodeSet, Traversal, _rel_helper, _rel_merge_helper from neomodel.sync_.relationship import StructuredRel -from neomodel.exceptions import NotConnected, RelationshipClassRedefined from neomodel.util import ( EITHER, INCOMING, @@ -133,7 +128,8 @@ def connect(self, node, properties=None): self.source.cypher(q, params) return True - rel_ = self.source.cypher(q + " RETURN r", params)[0][0][0] + results = self.source.cypher(q + " RETURN r", params) + rel_ = results[0][0][0] rel_instance = self._set_start_end_cls(rel_model.inflate(rel_), node) if hasattr(rel_instance, "post_save"): diff --git a/test/async_/test_database_management.py b/test/async_/test_database_management.py new file mode 100644 index 00000000..6d2ace9f --- /dev/null +++ b/test/async_/test_database_management.py @@ -0,0 +1,79 @@ +import pytest +from test._async_compat import mark_async_test +from neo4j.exceptions import AuthError + +from neomodel import ( + AsyncRelationshipTo, + AsyncStructuredNode, + AsyncStructuredRel, + IntegerProperty, + StringProperty, +) +from neomodel.async_.core import adb + + +class City(AsyncStructuredNode): + name = StringProperty() + + +class InCity(AsyncStructuredRel): + creation_year = IntegerProperty(index=True) + + +class Venue(AsyncStructuredNode): + name = StringProperty(unique_index=True) + creator = StringProperty(index=True) + in_city = AsyncRelationshipTo(City, relation_type="IN", model=InCity) + + +@mark_async_test +async def test_clear_database(): + venue = await Venue(name="Royal Albert Hall", creator="Queen Victoria").save() + city = await City(name="London").save() + await venue.in_city.connect(city) + + # Clear only the data + await adb.clear_neo4j_database() + database_is_populated, _ = await adb.cypher_query( + "MATCH (a) return count(a)>0 as database_is_populated" + ) + + assert database_is_populated[0][0] is False + + indexes = await adb.list_indexes(exclude_token_lookup=True) + constraints = await adb.list_constraints() + assert len(indexes) > 0 + assert len(constraints) > 0 + + # Clear constraints and indexes too + await adb.clear_neo4j_database(clear_constraints=True, clear_indexes=True) + + indexes = await adb.list_indexes(exclude_token_lookup=True) + constraints = await adb.list_constraints() + assert len(indexes) == 0 + assert len(constraints) == 0 + + +@mark_async_test +async def test_change_password(): + prev_password = "foobarbaz" + new_password = "newpassword" + prev_url = f"bolt://neo4j:{prev_password}@localhost:7687" + new_url = f"bolt://neo4j:{new_password}@localhost:7687" + + await adb.change_neo4j_password("neo4j", new_password) + await adb.close_connection() + + await adb.set_connection(url=new_url) + await adb.close_connection() + + with pytest.raises(AuthError): + await adb.set_connection(url=prev_url) + + await adb.close_connection() + + await adb.set_connection(url=new_url) + await adb.change_neo4j_password("neo4j", prev_password) + await adb.close_connection() + + await adb.set_connection(url=prev_url) diff --git a/test/test_dbms_awareness.py b/test/async_/test_dbms_awareness.py similarity index 92% rename from test/test_dbms_awareness.py rename to test/async_/test_dbms_awareness.py index dc2bf01b..80e83559 100644 --- a/test/test_dbms_awareness.py +++ b/test/async_/test_dbms_awareness.py @@ -7,7 +7,7 @@ @mark.skipif( adb.database_version != "5.7.0", reason="Testing a specific database version" ) -def test_version_awareness(): +async def test_version_awareness(): assert adb.database_version == "5.7.0" assert adb.version_is_higher_than("5.7") assert adb.version_is_higher_than("5.6.0") @@ -17,7 +17,7 @@ def test_version_awareness(): assert not adb.version_is_higher_than("5.8") -def test_edition_awareness(): +async def test_edition_awareness(): if adb.database_edition == "enterprise": assert adb.edition_is_enterprise() else: diff --git a/test/test_database_management.py b/test/sync/test_database_management.py similarity index 50% rename from test/test_database_management.py rename to test/sync/test_database_management.py index af1f6d2e..9b3f8bf2 100644 --- a/test/test_database_management.py +++ b/test/sync/test_database_management.py @@ -1,76 +1,79 @@ import pytest +from test._async_compat import mark_sync_test from neo4j.exceptions import AuthError from neomodel import ( - AsyncRelationshipTo, - AsyncStructuredNode, - AsyncStructuredRel, + RelationshipTo, + StructuredNode, + StructuredRel, IntegerProperty, StringProperty, ) -from neomodel.async_.core import adb +from neomodel.sync_.core import db -class City(AsyncStructuredNode): +class City(StructuredNode): name = StringProperty() -class InCity(AsyncStructuredRel): +class InCity(StructuredRel): creation_year = IntegerProperty(index=True) -class Venue(AsyncStructuredNode): +class Venue(StructuredNode): name = StringProperty(unique_index=True) creator = StringProperty(index=True) - in_city = AsyncRelationshipTo(City, relation_type="IN", model=InCity) + in_city = RelationshipTo(City, relation_type="IN", model=InCity) +@mark_sync_test def test_clear_database(): venue = Venue(name="Royal Albert Hall", creator="Queen Victoria").save() city = City(name="London").save() venue.in_city.connect(city) # Clear only the data - adb.clear_neo4j_database() - database_is_populated, _ = adb.cypher_query( + db.clear_neo4j_database() + database_is_populated, _ = db.cypher_query( "MATCH (a) return count(a)>0 as database_is_populated" ) assert database_is_populated[0][0] is False - indexes = adb.list_indexes(exclude_token_lookup=True) - constraints = adb.list_constraints() + indexes = db.list_indexes(exclude_token_lookup=True) + constraints = db.list_constraints() assert len(indexes) > 0 assert len(constraints) > 0 # Clear constraints and indexes too - adb.clear_neo4j_database(clear_constraints=True, clear_indexes=True) + db.clear_neo4j_database(clear_constraints=True, clear_indexes=True) - indexes = adb.list_indexes(exclude_token_lookup=True) - constraints = adb.list_constraints() + indexes = db.list_indexes(exclude_token_lookup=True) + constraints = db.list_constraints() assert len(indexes) == 0 assert len(constraints) == 0 +@mark_sync_test def test_change_password(): prev_password = "foobarbaz" new_password = "newpassword" prev_url = f"bolt://neo4j:{prev_password}@localhost:7687" new_url = f"bolt://neo4j:{new_password}@localhost:7687" - adb.change_neo4j_password("neo4j", new_password) - adb.close_connection() + db.change_neo4j_password("neo4j", new_password) + db.close_connection() - adb.set_connection(url=new_url) - adb.close_connection() + db.set_connection(url=new_url) + db.close_connection() with pytest.raises(AuthError): - adb.set_connection(url=prev_url) + db.set_connection(url=prev_url) - adb.close_connection() + db.close_connection() - adb.set_connection(url=new_url) - adb.change_neo4j_password("neo4j", prev_password) - adb.close_connection() + db.set_connection(url=new_url) + db.change_neo4j_password("neo4j", prev_password) + db.close_connection() - adb.set_connection(url=prev_url) + db.set_connection(url=prev_url) diff --git a/test/sync/test_dbms_awareness.py b/test/sync/test_dbms_awareness.py new file mode 100644 index 00000000..d8ef2c9e --- /dev/null +++ b/test/sync/test_dbms_awareness.py @@ -0,0 +1,32 @@ +from pytest import mark + +from neomodel.sync_.core import db +from neomodel.util import version_tag_to_integer + + +@mark.skipif( + db.database_version != "5.7.0", reason="Testing a specific database version" +) +def test_version_awareness(): + assert db.database_version == "5.7.0" + assert db.version_is_higher_than("5.7") + assert db.version_is_higher_than("5.6.0") + assert db.version_is_higher_than("5") + assert db.version_is_higher_than("4") + + assert not db.version_is_higher_than("5.8") + + +def test_edition_awareness(): + if db.database_edition == "enterprise": + assert db.edition_is_enterprise() + else: + assert not db.edition_is_enterprise() + + +def test_version_tag_to_integer(): + assert version_tag_to_integer("5.7.1") == 50701 + assert version_tag_to_integer("5.1") == 50100 + assert version_tag_to_integer("5") == 50000 + assert version_tag_to_integer("5.14.1") == 51401 + assert version_tag_to_integer("5.14-aura") == 51400 diff --git a/test/test_driver_options.py b/test/test_driver_options.py deleted file mode 100644 index 12123931..00000000 --- a/test/test_driver_options.py +++ /dev/null @@ -1,50 +0,0 @@ -import pytest -from neo4j.exceptions import ClientError -from pytest import raises - -from neomodel.async_.core import adb -from neomodel.exceptions import FeatureNotSupported - - -@pytest.mark.skipif( - not adb.edition_is_enterprise(), reason="Skipping test for community edition" -) -def test_impersonate(): - with adb.impersonate(user="troygreene"): - results, _ = adb.cypher_query("RETURN 'Doo Wacko !'") - assert results[0][0] == "Doo Wacko !" - - -@pytest.mark.skipif( - not adb.edition_is_enterprise(), reason="Skipping test for community edition" -) -def test_impersonate_unauthorized(): - with adb.impersonate(user="unknownuser"): - with raises(ClientError): - _ = adb.cypher_query("RETURN 'Gabagool'") - - -@pytest.mark.skipif( - not adb.edition_is_enterprise(), reason="Skipping test for community edition" -) -def test_impersonate_multiple_transactions(): - with adb.impersonate(user="troygreene"): - with adb.transaction: - results, _ = adb.cypher_query("RETURN 'Doo Wacko !'") - assert results[0][0] == "Doo Wacko !" - - with adb.transaction: - results, _ = adb.cypher_query("SHOW CURRENT USER") - assert results[0][0] == "troygreene" - - results, _ = adb.cypher_query("SHOW CURRENT USER") - assert results[0][0] == "neo4j" - - -@pytest.mark.skipif( - adb.edition_is_enterprise(), reason="Skipping test for enterprise edition" -) -def test_impersonate_community(): - with raises(FeatureNotSupported): - with adb.impersonate(user="troygreene"): - _ = adb.cypher_query("RETURN 'Gabagoogoo'") From b17f868546f51a55eb6925070650fe7e089dabba Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Tue, 2 Jan 2024 10:07:33 +0100 Subject: [PATCH 23/73] Fix non-preserved order --- test/async_/test_cypher.py | 4 ++-- test/sync/test_cypher.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/async_/test_cypher.py b/test/async_/test_cypher.py index eee735ea..c13e0b2a 100644 --- a/test/async_/test_cypher.py +++ b/test/async_/test_cypher.py @@ -146,10 +146,10 @@ async def test_numpy_integration(): array = to_ndarray( await adb.cypher_query( - "MATCH (a:UserNP) RETURN a.name AS name, a.email AS email" + "MATCH (a:UserNP) RETURN a.name AS name, a.email AS email ORDER BY name" ) ) assert isinstance(array, ndarray) assert array.shape == (2, 2) - assert array[0][0] == "jimly" + assert array[0][0] == "jimlu" diff --git a/test/sync/test_cypher.py b/test/sync/test_cypher.py index d49eeae6..ac1b026c 100644 --- a/test/sync/test_cypher.py +++ b/test/sync/test_cypher.py @@ -146,10 +146,10 @@ def test_numpy_integration(): array = to_ndarray( db.cypher_query( - "MATCH (a:UserNP) RETURN a.name AS name, a.email AS email" + "MATCH (a:UserNP) RETURN a.name AS name, a.email AS email ORDER BY name" ) ) assert isinstance(array, ndarray) assert array.shape == (2, 2) - assert array[0][0] == "jimly" + assert array[0][0] == "jimlu" From db05e4f7cf17dc63c911b6944846124d774166bf Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Tue, 2 Jan 2024 11:22:59 +0100 Subject: [PATCH 24/73] More tests --- neomodel/async_/core.py | 3 +- neomodel/sync_/core.py | 1 + test/async_/test_dbms_awareness.py | 6 +- test/async_/test_driver_options.py | 55 +++ test/{ => async_}/test_exceptions.py | 6 +- test/{ => async_}/test_hooks.py | 8 +- test/async_/test_indexing.py | 89 +++++ test/async_/test_issue112.py | 18 + test/async_/test_issue283.py | 522 +++++++++++++++++++++++++++ test/async_/test_issue600.py | 86 +++++ test/sync/test_dbms_awareness.py | 2 + test/sync/test_driver_options.py | 55 +++ test/sync/test_exceptions.py | 31 ++ test/sync/test_hooks.py | 34 ++ test/{ => sync}/test_indexing.py | 27 +- test/sync/test_issue112.py | 18 + test/{ => sync}/test_issue283.py | 217 +++++++---- test/{ => sync}/test_issue600.py | 41 +-- test/test_issue112.py | 16 - test/test_label_drop.py | 46 --- 20 files changed, 1113 insertions(+), 168 deletions(-) create mode 100644 test/async_/test_driver_options.py rename test/{ => async_}/test_exceptions.py (84%) rename test/{ => async_}/test_hooks.py (81%) create mode 100644 test/async_/test_indexing.py create mode 100644 test/async_/test_issue112.py create mode 100644 test/async_/test_issue283.py create mode 100644 test/async_/test_issue600.py create mode 100644 test/sync/test_driver_options.py create mode 100644 test/sync/test_exceptions.py create mode 100644 test/sync/test_hooks.py rename test/{ => sync}/test_indexing.py (77%) create mode 100644 test/sync/test_issue112.py rename test/{ => sync}/test_issue283.py (68%) rename test/{ => sync}/test_issue600.py (60%) delete mode 100644 test/test_issue112.py delete mode 100644 test/test_label_drop.py diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py index c99ce260..42890261 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/core.py @@ -554,7 +554,8 @@ def version_is_higher_than(self, version_tag: str) -> bool: version_tag ) - def edition_is_enterprise(self) -> bool: + @ensure_connection + async def edition_is_enterprise(self) -> bool: """Returns true if the database edition is enterprise Returns: diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index 20697612..8b0e1a7e 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -552,6 +552,7 @@ def version_is_higher_than(self, version_tag: str) -> bool: version_tag ) + @ensure_connection def edition_is_enterprise(self) -> bool: """Returns true if the database edition is enterprise diff --git a/test/async_/test_dbms_awareness.py b/test/async_/test_dbms_awareness.py index 80e83559..7399666a 100644 --- a/test/async_/test_dbms_awareness.py +++ b/test/async_/test_dbms_awareness.py @@ -1,4 +1,5 @@ from pytest import mark +from test._async_compat import mark_async_test from neomodel.async_.core import adb from neomodel.util import version_tag_to_integer @@ -17,11 +18,12 @@ async def test_version_awareness(): assert not adb.version_is_higher_than("5.8") +@mark_async_test async def test_edition_awareness(): if adb.database_edition == "enterprise": - assert adb.edition_is_enterprise() + assert await adb.edition_is_enterprise() else: - assert not adb.edition_is_enterprise() + assert not await adb.edition_is_enterprise() def test_version_tag_to_integer(): diff --git a/test/async_/test_driver_options.py b/test/async_/test_driver_options.py new file mode 100644 index 00000000..64d5b85d --- /dev/null +++ b/test/async_/test_driver_options.py @@ -0,0 +1,55 @@ +import pytest +from test._async_compat import mark_async_test +from neo4j.exceptions import ClientError +from pytest import raises + +from neomodel.async_.core import adb +from neomodel.exceptions import FeatureNotSupported + + +@mark_async_test +@pytest.mark.skipif( + not adb.edition_is_enterprise(), reason="Skipping test for community edition" +) +async def test_impersonate(): + with adb.impersonate(user="troygreene"): + results, _ = await adb.cypher_query("RETURN 'Doo Wacko !'") + assert results[0][0] == "Doo Wacko !" + + +@mark_async_test +@pytest.mark.skipif( + not adb.edition_is_enterprise(), reason="Skipping test for community edition" +) +async def test_impersonate_unauthorized(): + with adb.impersonate(user="unknownuser"): + with raises(ClientError): + _ = await adb.cypher_query("RETURN 'Gabagool'") + + +@mark_async_test +@pytest.mark.skipif( + not adb.edition_is_enterprise(), reason="Skipping test for community edition" +) +async def test_impersonate_multiple_transactions(): + with adb.impersonate(user="troygreene"): + with adb.transaction: + results, _ = await adb.cypher_query("RETURN 'Doo Wacko !'") + assert results[0][0] == "Doo Wacko !" + + with adb.transaction: + results, _ = await adb.cypher_query("SHOW CURRENT USER") + assert results[0][0] == "troygreene" + + results, _ = await adb.cypher_query("SHOW CURRENT USER") + assert results[0][0] == "neo4j" + + +@mark_async_test +@pytest.mark.skipif( + adb.edition_is_enterprise(), reason="Skipping test for enterprise edition" +) +async def test_impersonate_community(): + with raises(FeatureNotSupported): + with adb.impersonate(user="troygreene"): + _ = await adb.cypher_query("RETURN 'Gabagoogoo'") diff --git a/test/test_exceptions.py b/test/async_/test_exceptions.py similarity index 84% rename from test/test_exceptions.py rename to test/async_/test_exceptions.py index c6976515..e948f76c 100644 --- a/test/test_exceptions.py +++ b/test/async_/test_exceptions.py @@ -1,4 +1,5 @@ import pickle +from test._async_compat import mark_async_test from neomodel import AsyncStructuredNode, DoesNotExist, StringProperty @@ -7,9 +8,10 @@ class EPerson(AsyncStructuredNode): name = StringProperty(unique_index=True) -def test_object_does_not_exist(): +@mark_async_test +async def test_object_does_not_exist(): try: - EPerson.nodes.get(name="johnny") + await EPerson.nodes.get(name="johnny") except EPerson.DoesNotExist as e: pickle_instance = pickle.dumps(e) assert pickle_instance diff --git a/test/test_hooks.py b/test/async_/test_hooks.py similarity index 81% rename from test/test_hooks.py rename to test/async_/test_hooks.py index 8fb9b8e5..12643e78 100644 --- a/test/test_hooks.py +++ b/test/async_/test_hooks.py @@ -1,3 +1,4 @@ +from test._async_compat import mark_async_test from neomodel import AsyncStructuredNode, StringProperty HOOKS_CALLED = {} @@ -22,9 +23,10 @@ def post_delete(self): HOOKS_CALLED["post_delete"] = 1 -def test_hooks(): - ht = HookTest(name="k").save() - ht.delete() +@mark_async_test +async def test_hooks(): + ht = await HookTest(name="k").save() + await ht.delete() assert "pre_save" in HOOKS_CALLED assert "post_save" in HOOKS_CALLED assert "post_create" in HOOKS_CALLED diff --git a/test/async_/test_indexing.py b/test/async_/test_indexing.py new file mode 100644 index 00000000..6fd51488 --- /dev/null +++ b/test/async_/test_indexing.py @@ -0,0 +1,89 @@ +import pytest +from pytest import raises +from test._async_compat import mark_async_test + +from neomodel import ( + AsyncStructuredNode, + IntegerProperty, + StringProperty, + UniqueProperty, +) +from neomodel.async_.core import adb +from neomodel.exceptions import ConstraintValidationFailed + + +class Human(AsyncStructuredNode): + name = StringProperty(unique_index=True) + age = IntegerProperty(index=True) + + +@mark_async_test +async def test_unique_error(): + await adb.install_labels(Human) + await Human(name="j1m", age=13).save() + try: + await Human(name="j1m", age=14).save() + except UniqueProperty as e: + assert str(e).find("j1m") + assert str(e).find("name") + else: + assert False, "UniqueProperty not raised." + + +@mark_async_test +@pytest.mark.skipif( + not adb.edition_is_enterprise(), reason="Skipping test for community edition" +) +async def test_existence_constraint_error(): + await adb.cypher_query( + "CREATE CONSTRAINT test_existence_constraint FOR (n:Human) REQUIRE n.age IS NOT NULL" + ) + with raises(ConstraintValidationFailed, match=r"must have the property"): + await Human(name="Scarlett").save() + + await adb.cypher_query("DROP CONSTRAINT test_existence_constraint") + + +@mark_async_test +async def test_optional_properties_dont_get_indexed(): + await Human(name="99", age=99).save() + h = await Human.nodes.get(age=99) + assert h + assert h.name == "99" + + await Human(name="98", age=98).save() + h = await Human.nodes.get(age=98) + assert h + assert h.name == "98" + + +@mark_async_test +async def test_escaped_chars(): + _name = "sarah:test" + await Human(name=_name, age=3).save() + r = Human.nodes.filter(name=_name) + first_r = await r[0] + assert first_r.name == _name + + +@mark_async_test +async def test_does_not_exist(): + with raises(Human.DoesNotExist): + await Human.nodes.get(name="XXXX") + + +@mark_async_test +async def test_custom_label_name(): + class Giraffe(AsyncStructuredNode): + __label__ = "Giraffes" + name = StringProperty(unique_index=True) + + jim = await Giraffe(name="timothy").save() + node = await Giraffe.nodes.get(name="timothy") + assert node.name == jim.name + + class SpecialGiraffe(Giraffe): + power = StringProperty() + + # custom labels aren't inherited + assert SpecialGiraffe.__label__ == "SpecialGiraffe" diff --git a/test/async_/test_issue112.py b/test/async_/test_issue112.py new file mode 100644 index 00000000..12940992 --- /dev/null +++ b/test/async_/test_issue112.py @@ -0,0 +1,18 @@ +from test._async_compat import mark_async_test +from neomodel import AsyncRelationshipTo, AsyncStructuredNode + + +class SomeModel(AsyncStructuredNode): + test = AsyncRelationshipTo("SomeModel", "SELF") + + +@mark_async_test +async def test_len_relationship(): + t1 = await SomeModel().save() + t2 = await SomeModel().save() + + await t1.test.connect(t2) + l = len(await t1.test.all()) + + assert l + assert l == 1 diff --git a/test/async_/test_issue283.py b/test/async_/test_issue283.py new file mode 100644 index 00000000..7444f0a4 --- /dev/null +++ b/test/async_/test_issue283.py @@ -0,0 +1,522 @@ +""" +Provides a test case for issue 283 - "Inheritance breaks". + +The issue is outlined here: https://github.com/neo4j-contrib/neomodel/issues/283 +More information about the same issue at: +https://github.com/aanastasiou/neomodelInheritanceTest + +The following example uses a recursive relationship for economy, but the +idea remains the same: "Instantiate the correct type of node at the end of +a relationship as specified by the model" +""" +from test._async_compat import mark_async_test +import random + +import pytest + +from neomodel import ( + AsyncStructuredRel, + AsyncStructuredNode, + DateTimeProperty, + FloatProperty, + StringProperty, + UniqueIdProperty, + AsyncRelationshipTo, + RelationshipClassRedefined, + RelationshipClassNotDefined, +) +from neomodel.async_.core import adb +from neomodel.exceptions import NodeClassNotDefined, NodeClassAlreadyDefined + +try: + basestring +except NameError: + basestring = str + + +# Set up a very simple model for the tests +class PersonalRelationship(AsyncStructuredRel): + """ + A very simple relationship between two basePersons that simply records + the date at which an acquaintance was established. + This relationship should be carried over to anything that inherits from + basePerson without any further effort. + """ + + on_date = DateTimeProperty(default_now=True) + + +class BasePerson(AsyncStructuredNode): + """ + Base class for defining some basic sort of an actor. + """ + + name = StringProperty(required=True, unique_index=True) + friends_with = AsyncRelationshipTo( + "BasePerson", "FRIENDS_WITH", model=PersonalRelationship + ) + + +class TechnicalPerson(BasePerson): + """ + A Technical person specialises BasePerson by adding their expertise. + """ + + expertise = StringProperty(required=True) + + +class PilotPerson(BasePerson): + """ + A pilot person specialises BasePerson by adding the type of airplane they + can operate. + """ + + airplane = StringProperty(required=True) + + +class BaseOtherPerson(AsyncStructuredNode): + """ + An obviously "wrong" class of actor to befriend BasePersons with. + """ + + car_color = StringProperty(required=True) + + +class SomePerson(BaseOtherPerson): + """ + Concrete class that simply derives from BaseOtherPerson. + """ + + pass + + +# Test cases +@mark_async_test +async def test_automatic_result_resolution(): + """ + Node objects at the end of relationships are instantiated to their + corresponding Python object. + """ + + # Create a few entities + A = ( + await TechnicalPerson.get_or_create( + {"name": "Grumpy", "expertise": "Grumpiness"} + ) + )[0] + B = ( + await TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"}) + )[0] + C = ( + await TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"}) + )[0] + + # Add connections + await A.friends_with.connect(B) + await B.friends_with.connect(C) + await C.friends_with.connect(A) + + # If A is friends with B, then A's friends_with objects should be + # TechnicalPerson (!NOT basePerson!) + assert type(await A.friends_with[0]) is TechnicalPerson + + await A.delete() + await B.delete() + await C.delete() + + +@mark_async_test +async def test_recursive_automatic_result_resolution(): + """ + Node objects are instantiated to native Python objects, both at the top + level of returned results and in the case where they are returned within + lists. + """ + + # Create a few entities + A = ( + await TechnicalPerson.get_or_create( + {"name": "Grumpier", "expertise": "Grumpiness"} + ) + )[0] + B = ( + await TechnicalPerson.get_or_create( + {"name": "Happier", "expertise": "Grumpiness"} + ) + )[0] + C = ( + await TechnicalPerson.get_or_create( + {"name": "Sleepier", "expertise": "Pillows"} + ) + )[0] + D = ( + await TechnicalPerson.get_or_create( + {"name": "Sneezier", "expertise": "Pillows"} + ) + )[0] + + # Retrieve mixed results, both at the top level and nested + L, _ = await adb.cypher_query( + "MATCH (a:TechnicalPerson) " + "WHERE a.expertise='Grumpiness' " + "WITH collect(a) as Alpha " + "MATCH (b:TechnicalPerson) " + "WHERE b.expertise='Pillows' " + "WITH Alpha, collect(b) as Beta " + "RETURN [Alpha, [Beta, [Beta, ['Banana', " + "Alpha]]]]", + resolve_objects=True, + ) + + # Assert that a Node returned deep in a nested list structure is of the + # correct type + assert type(L[0][0][0][1][0][0][0][0]) is TechnicalPerson + # Assert that primitive data types remain primitive data types + assert issubclass(type(L[0][0][0][1][0][1][0][1][0][0]), basestring) + + await A.delete() + await B.delete() + await C.delete() + await D.delete() + + +@mark_async_test +async def test_validation_with_inheritance_from_db(): + """ + Objects descending from the specified class of a relationship's end-node are + also perfectly valid to appear as end-node values too + """ + + # Create a few entities + # Technical Persons + A = ( + await TechnicalPerson.get_or_create( + {"name": "Grumpy", "expertise": "Grumpiness"} + ) + )[0] + B = ( + await TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"}) + )[0] + C = ( + await TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"}) + )[0] + + # Pilot Persons + D = ( + await PilotPerson.get_or_create( + {"name": "Porco Rosso", "airplane": "Savoia-Marchetti"} + ) + )[0] + E = ( + await PilotPerson.get_or_create( + {"name": "Jack Dalton", "airplane": "Beechcraft Model 18"} + ) + )[0] + + # TechnicalPersons can befriend PilotPersons and vice-versa and that's fine + + # TechnicalPersons befriend Technical Persons + await A.friends_with.connect(B) + await B.friends_with.connect(C) + await C.friends_with.connect(A) + + # Pilot Persons befriend Pilot Persons + await D.friends_with.connect(E) + + # Technical Persons befriend Pilot Persons + await A.friends_with.connect(D) + await E.friends_with.connect(C) + + # This now means that friends_with of a TechnicalPerson can + # either be TechnicalPerson or Pilot Person (!NOT basePerson!) + + assert (type(await A.friends_with[0]) is TechnicalPerson) or ( + type(await A.friends_with[0]) is PilotPerson + ) + assert (type(await A.friends_with[1]) is TechnicalPerson) or ( + type(await A.friends_with[1]) is PilotPerson + ) + assert type(await D.friends_with[0]) is PilotPerson + + await A.delete() + await B.delete() + await C.delete() + await D.delete() + await E.delete() + + +@mark_async_test +async def test_validation_enforcement_to_db(): + """ + If a connection between wrong types is attempted, raise an exception + """ + + # Create a few entities + # Technical Persons + A = ( + await TechnicalPerson.get_or_create( + {"name": "Grumpy", "expertise": "Grumpiness"} + ) + )[0] + B = ( + await TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"}) + )[0] + C = ( + await TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"}) + )[0] + + # Pilot Persons + D = ( + await PilotPerson.get_or_create( + {"name": "Porco Rosso", "airplane": "Savoia-Marchetti"} + ) + )[0] + E = ( + await PilotPerson.get_or_create( + {"name": "Jack Dalton", "airplane": "Beechcraft Model 18"} + ) + )[0] + + # Some Person + F = await SomePerson(car_color="Blue").save() + + # TechnicalPersons can befriend PilotPersons and vice-versa and that's fine + await A.friends_with.connect(B) + await B.friends_with.connect(C) + await C.friends_with.connect(A) + await D.friends_with.connect(E) + await A.friends_with.connect(D) + await E.friends_with.connect(C) + + # Trying to befriend a Technical Person with Some Person should raise an + # exception + with pytest.raises(ValueError): + await A.friends_with.connect(F) + + await A.delete() + await B.delete() + await C.delete() + await D.delete() + await E.delete() + await F.delete() + + +@mark_async_test +async def test_failed_result_resolution(): + """ + A Neo4j driver node FROM the database contains labels that are unaware to + neomodel's Database class. This condition raises ClassDefinitionNotFound + exception. + """ + + class RandomPerson(BasePerson): + randomness = FloatProperty(default=random.random) + + # A Technical Person... + A = ( + await TechnicalPerson.get_or_create( + {"name": "Grumpy", "expertise": "Grumpiness"} + ) + )[0] + + # A Random Person... + B = (await RandomPerson.get_or_create({"name": "Mad Hatter"}))[0] + + await A.friends_with.connect(B) + + # Simulate the condition where the definition of class RandomPerson is not + # known yet. + del adb._NODE_CLASS_REGISTRY[frozenset(["RandomPerson", "BasePerson"])] + + # Now try to instantiate a RandomPerson + A = ( + await TechnicalPerson.get_or_create( + {"name": "Grumpy", "expertise": "Grumpiness"} + ) + )[0] + with pytest.raises( + NodeClassNotDefined, + match=r"Node with labels .* does not resolve to any of the known objects.*", + ): + friends = await A.friends_with.all() + for some_friend in friends: + print(some_friend.name) + + await A.delete() + await B.delete() + + +@mark_async_test +async def test_node_label_mismatch(): + """ + A Neo4j driver node FROM the database contains a superset of the known + labels. + """ + + class SuperTechnicalPerson(TechnicalPerson): + superness = FloatProperty(default=1.0) + + class UltraTechnicalPerson(SuperTechnicalPerson): + ultraness = FloatProperty(default=3.1415928) + + # Create a TechnicalPerson... + A = ( + await TechnicalPerson.get_or_create( + {"name": "Grumpy", "expertise": "Grumpiness"} + ) + )[0] + # ...that is connected to an UltraTechnicalPerson + F = await UltraTechnicalPerson( + name="Chewbaka", expertise="Aarrr wgh ggwaaah" + ).save() + await A.friends_with.connect(F) + + # Forget about the UltraTechnicalPerson + del adb._NODE_CLASS_REGISTRY[ + frozenset( + [ + "UltraTechnicalPerson", + "SuperTechnicalPerson", + "TechnicalPerson", + "BasePerson", + ] + ) + ] + + # Recall a TechnicalPerson and enumerate its friends. + # One of them is UltraTechnicalPerson which would be returned as a valid + # node to a friends_with query but is currently unknown to the node class registry. + A = ( + await TechnicalPerson.get_or_create( + {"name": "Grumpy", "expertise": "Grumpiness"} + ) + )[0] + with pytest.raises(NodeClassNotDefined): + friends = await A.friends_with.all() + for some_friend in friends: + print(some_friend.name) + + +def test_attempted_class_redefinition(): + """ + A StructuredNode class is attempted to be redefined. + """ + + def redefine_class_locally(): + # Since this test has already set up a class hierarchy in its global scope, we will try to redefine + # SomePerson here. + # The internal structure of the SomePerson entity does not matter at all here. + class SomePerson(BaseOtherPerson): + uid = UniqueIdProperty() + + with pytest.raises( + NodeClassAlreadyDefined, + match=r"Class .* with labels .* already defined:.*", + ): + redefine_class_locally() + + +@mark_async_test +async def test_relationship_result_resolution(): + """ + A query returning a "Relationship" object can now instantiate it to a data model class + """ + # Test specific data + A = await PilotPerson(name="Zantford Granville", airplane="Gee Bee Model R").save() + B = await PilotPerson(name="Thomas Granville", airplane="Gee Bee Model R").save() + C = await PilotPerson(name="Robert Granville", airplane="Gee Bee Model R").save() + D = await PilotPerson(name="Mark Granville", airplane="Gee Bee Model R").save() + E = await PilotPerson(name="Edward Granville", airplane="Gee Bee Model R").save() + + await A.friends_with.connect(B) + await B.friends_with.connect(C) + await C.friends_with.connect(D) + await D.friends_with.connect(E) + + query_data = await adb.cypher_query( + "MATCH (a:PilotPerson)-[r:FRIENDS_WITH]->(b:PilotPerson) " + "WHERE a.airplane='Gee Bee Model R' and b.airplane='Gee Bee Model R' " + "RETURN DISTINCT r", + resolve_objects=True, + ) + + # The relationship here should be properly instantiated to a `PersonalRelationship` object. + assert isinstance(query_data[0][0][0], PersonalRelationship) + + +@mark_async_test +async def test_properly_inherited_relationship(): + """ + A relationship class extends an existing relationship model that must extended the same previously associated + relationship label. + """ + + # Extends an existing relationship by adding the "relationship_strength" attribute. + # `ExtendedPersonalRelationship` will now substitute `PersonalRelationship` EVERYWHERE in the system. + class ExtendedPersonalRelationship(PersonalRelationship): + relationship_strength = FloatProperty(default=random.random) + + # Extends SomePerson, establishes "enriched" relationships with any BaseOtherPerson + class ExtendedSomePerson(SomePerson): + friends_with = AsyncRelationshipTo( + "BaseOtherPerson", + "FRIENDS_WITH", + model=ExtendedPersonalRelationship, + ) + + # Test specific data + A = await ExtendedSomePerson(name="Michael Knight", car_color="Black").save() + B = await ExtendedSomePerson(name="Luke Duke", car_color="Orange").save() + C = await ExtendedSomePerson(name="Michael Schumacher", car_color="Red").save() + + await A.friends_with.connect(B) + await A.friends_with.connect(C) + + query_data = await adb.cypher_query( + "MATCH (:ExtendedSomePerson)-[r:FRIENDS_WITH]->(:ExtendedSomePerson) " + "RETURN DISTINCT r", + resolve_objects=True, + ) + + assert isinstance(query_data[0][0][0], ExtendedPersonalRelationship) + + +def test_improperly_inherited_relationship(): + """ + Attempting to re-define an existing relationship with a completely unrelated class. + :return: + """ + + class NewRelationship(AsyncStructuredRel): + profile_match_factor = FloatProperty() + + with pytest.raises( + RelationshipClassRedefined, + match=r"Relationship of type .* redefined as .*", + ): + + class NewSomePerson(SomePerson): + friends_with = AsyncRelationshipTo( + "BaseOtherPerson", "FRIENDS_WITH", model=NewRelationship + ) + + +@mark_async_test +async def test_resolve_inexistent_relationship(): + """ + Attempting to resolve an inexistent relationship should raise an exception + :return: + """ + + # Forget about the FRIENDS_WITH Relationship. + del adb._NODE_CLASS_REGISTRY[frozenset(["FRIENDS_WITH"])] + + with pytest.raises( + RelationshipClassNotDefined, + match=r"Relationship of type .* does not resolve to any of the known objects.*", + ): + query_data = await adb.cypher_query( + "MATCH (:ExtendedSomePerson)-[r:FRIENDS_WITH]->(:ExtendedSomePerson) " + "RETURN DISTINCT r", + resolve_objects=True, + ) diff --git a/test/async_/test_issue600.py b/test/async_/test_issue600.py new file mode 100644 index 00000000..a35fb8f6 --- /dev/null +++ b/test/async_/test_issue600.py @@ -0,0 +1,86 @@ +""" +Provides a test case for issue 600 - "Pull request #592 cause an error in case of relationship inharitance". + +The issue is outlined here: https://github.com/neo4j-contrib/neomodel/issues/600 +""" + +from test._async_compat import mark_async_test +from neomodel import AsyncStructuredNode, AsyncRelationship, AsyncStructuredRel + +try: + basestring +except NameError: + basestring = str + + +class Class1(AsyncStructuredRel): + pass + + +class SubClass1(Class1): + pass + + +class SubClass2(Class1): + pass + + +class RelationshipDefinerSecondSibling(AsyncStructuredNode): + rel_1 = AsyncRelationship( + "RelationshipDefinerSecondSibling", "SOME_REL_LABEL", model=Class1 + ) + rel_2 = AsyncRelationship( + "RelationshipDefinerSecondSibling", "SOME_REL_LABEL", model=SubClass1 + ) + rel_3 = AsyncRelationship( + "RelationshipDefinerSecondSibling", "SOME_REL_LABEL", model=SubClass2 + ) + + +class RelationshipDefinerParentLast(AsyncStructuredNode): + rel_2 = AsyncRelationship( + "RelationshipDefinerParentLast", "SOME_REL_LABEL", model=SubClass1 + ) + rel_3 = AsyncRelationship( + "RelationshipDefinerParentLast", "SOME_REL_LABEL", model=SubClass2 + ) + rel_1 = AsyncRelationship( + "RelationshipDefinerParentLast", "SOME_REL_LABEL", model=Class1 + ) + + +# Test cases +@mark_async_test +async def test_relationship_definer_second_sibling(): + # Create a few entities + A = (await RelationshipDefinerSecondSibling.get_or_create({}))[0] + B = (await RelationshipDefinerSecondSibling.get_or_create({}))[0] + C = (await RelationshipDefinerSecondSibling.get_or_create({}))[0] + + # Add connections + await A.rel_1.connect(B) + await B.rel_2.connect(C) + await C.rel_3.connect(A) + + # Clean up + await A.delete() + await B.delete() + await C.delete() + + +@mark_async_test +async def test_relationship_definer_parent_last(): + # Create a few entities + A = (await RelationshipDefinerParentLast.get_or_create({}))[0] + B = (await RelationshipDefinerParentLast.get_or_create({}))[0] + C = (await RelationshipDefinerParentLast.get_or_create({}))[0] + + # Add connections + await A.rel_1.connect(B) + await B.rel_2.connect(C) + await C.rel_3.connect(A) + + # Clean up + await A.delete() + await B.delete() + await C.delete() diff --git a/test/sync/test_dbms_awareness.py b/test/sync/test_dbms_awareness.py index d8ef2c9e..b2776af0 100644 --- a/test/sync/test_dbms_awareness.py +++ b/test/sync/test_dbms_awareness.py @@ -1,4 +1,5 @@ from pytest import mark +from test._async_compat import mark_sync_test from neomodel.sync_.core import db from neomodel.util import version_tag_to_integer @@ -17,6 +18,7 @@ def test_version_awareness(): assert not db.version_is_higher_than("5.8") +@mark_sync_test def test_edition_awareness(): if db.database_edition == "enterprise": assert db.edition_is_enterprise() diff --git a/test/sync/test_driver_options.py b/test/sync/test_driver_options.py new file mode 100644 index 00000000..cedb1ae0 --- /dev/null +++ b/test/sync/test_driver_options.py @@ -0,0 +1,55 @@ +import pytest +from test._async_compat import mark_sync_test +from neo4j.exceptions import ClientError +from pytest import raises + +from neomodel.sync_.core import db +from neomodel.exceptions import FeatureNotSupported + + +@mark_sync_test +@pytest.mark.skipif( + not db.edition_is_enterprise(), reason="Skipping test for community edition" +) +def test_impersonate(): + with db.impersonate(user="troygreene"): + results, _ = db.cypher_query("RETURN 'Doo Wacko !'") + assert results[0][0] == "Doo Wacko !" + + +@mark_sync_test +@pytest.mark.skipif( + not db.edition_is_enterprise(), reason="Skipping test for community edition" +) +def test_impersonate_unauthorized(): + with db.impersonate(user="unknownuser"): + with raises(ClientError): + _ = db.cypher_query("RETURN 'Gabagool'") + + +@mark_sync_test +@pytest.mark.skipif( + not db.edition_is_enterprise(), reason="Skipping test for community edition" +) +def test_impersonate_multiple_transactions(): + with db.impersonate(user="troygreene"): + with db.transaction: + results, _ = db.cypher_query("RETURN 'Doo Wacko !'") + assert results[0][0] == "Doo Wacko !" + + with db.transaction: + results, _ = db.cypher_query("SHOW CURRENT USER") + assert results[0][0] == "troygreene" + + results, _ = db.cypher_query("SHOW CURRENT USER") + assert results[0][0] == "neo4j" + + +@mark_sync_test +@pytest.mark.skipif( + db.edition_is_enterprise(), reason="Skipping test for enterprise edition" +) +def test_impersonate_community(): + with raises(FeatureNotSupported): + with db.impersonate(user="troygreene"): + _ = db.cypher_query("RETURN 'Gabagoogoo'") diff --git a/test/sync/test_exceptions.py b/test/sync/test_exceptions.py new file mode 100644 index 00000000..a422db87 --- /dev/null +++ b/test/sync/test_exceptions.py @@ -0,0 +1,31 @@ +import pickle +from test._async_compat import mark_sync_test + +from neomodel import StructuredNode, DoesNotExist, StringProperty + + +class EPerson(StructuredNode): + name = StringProperty(unique_index=True) + + +@mark_sync_test +def test_object_does_not_exist(): + try: + EPerson.nodes.get(name="johnny") + except EPerson.DoesNotExist as e: + pickle_instance = pickle.dumps(e) + assert pickle_instance + assert pickle.loads(pickle_instance) + assert isinstance(pickle.loads(pickle_instance), DoesNotExist) + else: + assert False, "Person.DoesNotExist not raised." + + +def test_pickle_does_not_exist(): + try: + raise EPerson.DoesNotExist("My Test Message") + except EPerson.DoesNotExist as e: + pickle_instance = pickle.dumps(e) + assert pickle_instance + assert pickle.loads(pickle_instance) + assert isinstance(pickle.loads(pickle_instance), DoesNotExist) diff --git a/test/sync/test_hooks.py b/test/sync/test_hooks.py new file mode 100644 index 00000000..b3cbe864 --- /dev/null +++ b/test/sync/test_hooks.py @@ -0,0 +1,34 @@ +from test._async_compat import mark_sync_test +from neomodel import StructuredNode, StringProperty + +HOOKS_CALLED = {} + + +class HookTest(StructuredNode): + name = StringProperty() + + def post_create(self): + HOOKS_CALLED["post_create"] = 1 + + def pre_save(self): + HOOKS_CALLED["pre_save"] = 1 + + def post_save(self): + HOOKS_CALLED["post_save"] = 1 + + def pre_delete(self): + HOOKS_CALLED["pre_delete"] = 1 + + def post_delete(self): + HOOKS_CALLED["post_delete"] = 1 + + +@mark_sync_test +def test_hooks(): + ht = HookTest(name="k").save() + ht.delete() + assert "pre_save" in HOOKS_CALLED + assert "post_save" in HOOKS_CALLED + assert "post_create" in HOOKS_CALLED + assert "pre_delete" in HOOKS_CALLED + assert "post_delete" in HOOKS_CALLED diff --git a/test/test_indexing.py b/test/sync/test_indexing.py similarity index 77% rename from test/test_indexing.py rename to test/sync/test_indexing.py index 88311679..1eda3d21 100644 --- a/test/test_indexing.py +++ b/test/sync/test_indexing.py @@ -1,23 +1,25 @@ import pytest from pytest import raises +from test._async_compat import mark_sync_test from neomodel import ( - AsyncStructuredNode, + StructuredNode, IntegerProperty, StringProperty, UniqueProperty, ) -from neomodel.async_.core import adb +from neomodel.sync_.core import db from neomodel.exceptions import ConstraintValidationFailed -class Human(AsyncStructuredNode): +class Human(StructuredNode): name = StringProperty(unique_index=True) age = IntegerProperty(index=True) +@mark_sync_test def test_unique_error(): - adb.install_labels(Human) + db.install_labels(Human) Human(name="j1m", age=13).save() try: Human(name="j1m", age=14).save() @@ -28,19 +30,21 @@ def test_unique_error(): assert False, "UniqueProperty not raised." +@mark_sync_test @pytest.mark.skipif( - not adb.edition_is_enterprise(), reason="Skipping test for community edition" + not db.edition_is_enterprise(), reason="Skipping test for community edition" ) def test_existence_constraint_error(): - adb.cypher_query( + db.cypher_query( "CREATE CONSTRAINT test_existence_constraint FOR (n:Human) REQUIRE n.age IS NOT NULL" ) with raises(ConstraintValidationFailed, match=r"must have the property"): Human(name="Scarlett").save() - adb.cypher_query("DROP CONSTRAINT test_existence_constraint") + db.cypher_query("DROP CONSTRAINT test_existence_constraint") +@mark_sync_test def test_optional_properties_dont_get_indexed(): Human(name="99", age=99).save() h = Human.nodes.get(age=99) @@ -53,21 +57,24 @@ def test_optional_properties_dont_get_indexed(): assert h.name == "98" +@mark_sync_test def test_escaped_chars(): _name = "sarah:test" Human(name=_name, age=3).save() r = Human.nodes.filter(name=_name) - assert r - assert r[0].name == _name + first_r = r[0] + assert first_r.name == _name +@mark_sync_test def test_does_not_exist(): with raises(Human.DoesNotExist): Human.nodes.get(name="XXXX") +@mark_sync_test def test_custom_label_name(): - class Giraffe(AsyncStructuredNode): + class Giraffe(StructuredNode): __label__ = "Giraffes" name = StringProperty(unique_index=True) diff --git a/test/sync/test_issue112.py b/test/sync/test_issue112.py new file mode 100644 index 00000000..26605018 --- /dev/null +++ b/test/sync/test_issue112.py @@ -0,0 +1,18 @@ +from test._async_compat import mark_sync_test +from neomodel import RelationshipTo, StructuredNode + + +class SomeModel(StructuredNode): + test = RelationshipTo("SomeModel", "SELF") + + +@mark_sync_test +def test_len_relationship(): + t1 = SomeModel().save() + t2 = SomeModel().save() + + t1.test.connect(t2) + l = len(t1.test.all()) + + assert l + assert l == 1 diff --git a/test/test_issue283.py b/test/sync/test_issue283.py similarity index 68% rename from test/test_issue283.py rename to test/sync/test_issue283.py index fb5b5f2a..652a0e69 100644 --- a/test/test_issue283.py +++ b/test/sync/test_issue283.py @@ -9,14 +9,24 @@ idea remains the same: "Instantiate the correct type of node at the end of a relationship as specified by the model" """ - -import datetime -import os +from test._async_compat import mark_sync_test import random import pytest -import neomodel +from neomodel import ( + StructuredRel, + StructuredNode, + DateTimeProperty, + FloatProperty, + StringProperty, + UniqueIdProperty, + RelationshipTo, + RelationshipClassRedefined, + RelationshipClassNotDefined, +) +from neomodel.sync_.core import db +from neomodel.exceptions import NodeClassNotDefined, NodeClassAlreadyDefined try: basestring @@ -25,7 +35,7 @@ # Set up a very simple model for the tests -class PersonalRelationship(neomodel.AsyncStructuredRel): +class PersonalRelationship(StructuredRel): """ A very simple relationship between two basePersons that simply records the date at which an acquaintance was established. @@ -33,16 +43,16 @@ class PersonalRelationship(neomodel.AsyncStructuredRel): basePerson without any further effort. """ - on_date = neomodel.DateTimeProperty(default_now=True) + on_date = DateTimeProperty(default_now=True) -class BasePerson(neomodel.AsyncStructuredNode): +class BasePerson(StructuredNode): """ Base class for defining some basic sort of an actor. """ - name = neomodel.StringProperty(required=True, unique_index=True) - friends_with = neomodel.AsyncRelationshipTo( + name = StringProperty(required=True, unique_index=True) + friends_with = RelationshipTo( "BasePerson", "FRIENDS_WITH", model=PersonalRelationship ) @@ -52,7 +62,7 @@ class TechnicalPerson(BasePerson): A Technical person specialises BasePerson by adding their expertise. """ - expertise = neomodel.StringProperty(required=True) + expertise = StringProperty(required=True) class PilotPerson(BasePerson): @@ -61,15 +71,15 @@ class PilotPerson(BasePerson): can operate. """ - airplane = neomodel.StringProperty(required=True) + airplane = StringProperty(required=True) -class BaseOtherPerson(neomodel.AsyncStructuredNode): +class BaseOtherPerson(StructuredNode): """ An obviously "wrong" class of actor to befriend BasePersons with. """ - car_color = neomodel.StringProperty(required=True) + car_color = StringProperty(required=True) class SomePerson(BaseOtherPerson): @@ -81,6 +91,7 @@ class SomePerson(BaseOtherPerson): # Test cases +@mark_sync_test def test_automatic_result_resolution(): """ Node objects at the end of relationships are instantiated to their @@ -88,9 +99,17 @@ def test_automatic_result_resolution(): """ # Create a few entities - A = TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"})[0] - B = TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"})[0] - C = TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"})[0] + A = ( + TechnicalPerson.get_or_create( + {"name": "Grumpy", "expertise": "Grumpiness"} + ) + )[0] + B = ( + TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"}) + )[0] + C = ( + TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"}) + )[0] # Add connections A.friends_with.connect(B) @@ -106,6 +125,7 @@ def test_automatic_result_resolution(): C.delete() +@mark_sync_test def test_recursive_automatic_result_resolution(): """ Node objects are instantiated to native Python objects, both at the top @@ -114,15 +134,29 @@ def test_recursive_automatic_result_resolution(): """ # Create a few entities - A = TechnicalPerson.get_or_create({"name": "Grumpier", "expertise": "Grumpiness"})[ - 0 - ] - B = TechnicalPerson.get_or_create({"name": "Happier", "expertise": "Grumpiness"})[0] - C = TechnicalPerson.get_or_create({"name": "Sleepier", "expertise": "Pillows"})[0] - D = TechnicalPerson.get_or_create({"name": "Sneezier", "expertise": "Pillows"})[0] + A = ( + TechnicalPerson.get_or_create( + {"name": "Grumpier", "expertise": "Grumpiness"} + ) + )[0] + B = ( + TechnicalPerson.get_or_create( + {"name": "Happier", "expertise": "Grumpiness"} + ) + )[0] + C = ( + TechnicalPerson.get_or_create( + {"name": "Sleepier", "expertise": "Pillows"} + ) + )[0] + D = ( + TechnicalPerson.get_or_create( + {"name": "Sneezier", "expertise": "Pillows"} + ) + )[0] # Retrieve mixed results, both at the top level and nested - L, _ = neomodel.adb.cypher_query( + L, _ = db.cypher_query( "MATCH (a:TechnicalPerson) " "WHERE a.expertise='Grumpiness' " "WITH collect(a) as Alpha " @@ -146,6 +180,7 @@ def test_recursive_automatic_result_resolution(): D.delete() +@mark_sync_test def test_validation_with_inheritance_from_db(): """ Objects descending from the specified class of a relationship's end-node are @@ -154,16 +189,28 @@ def test_validation_with_inheritance_from_db(): # Create a few entities # Technical Persons - A = TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"})[0] - B = TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"})[0] - C = TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"})[0] + A = ( + TechnicalPerson.get_or_create( + {"name": "Grumpy", "expertise": "Grumpiness"} + ) + )[0] + B = ( + TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"}) + )[0] + C = ( + TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"}) + )[0] # Pilot Persons - D = PilotPerson.get_or_create( - {"name": "Porco Rosso", "airplane": "Savoia-Marchetti"} + D = ( + PilotPerson.get_or_create( + {"name": "Porco Rosso", "airplane": "Savoia-Marchetti"} + ) )[0] - E = PilotPerson.get_or_create( - {"name": "Jack Dalton", "airplane": "Beechcraft Model 18"} + E = ( + PilotPerson.get_or_create( + {"name": "Jack Dalton", "airplane": "Beechcraft Model 18"} + ) )[0] # TechnicalPersons can befriend PilotPersons and vice-versa and that's fine @@ -198,6 +245,7 @@ def test_validation_with_inheritance_from_db(): E.delete() +@mark_sync_test def test_validation_enforcement_to_db(): """ If a connection between wrong types is attempted, raise an exception @@ -205,16 +253,28 @@ def test_validation_enforcement_to_db(): # Create a few entities # Technical Persons - A = TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"})[0] - B = TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"})[0] - C = TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"})[0] + A = ( + TechnicalPerson.get_or_create( + {"name": "Grumpy", "expertise": "Grumpiness"} + ) + )[0] + B = ( + TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"}) + )[0] + C = ( + TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"}) + )[0] # Pilot Persons - D = PilotPerson.get_or_create( - {"name": "Porco Rosso", "airplane": "Savoia-Marchetti"} + D = ( + PilotPerson.get_or_create( + {"name": "Porco Rosso", "airplane": "Savoia-Marchetti"} + ) )[0] - E = PilotPerson.get_or_create( - {"name": "Jack Dalton", "airplane": "Beechcraft Model 18"} + E = ( + PilotPerson.get_or_create( + {"name": "Jack Dalton", "airplane": "Beechcraft Model 18"} + ) )[0] # Some Person @@ -241,6 +301,7 @@ def test_validation_enforcement_to_db(): F.delete() +@mark_sync_test def test_failed_result_resolution(): """ A Neo4j driver node FROM the database contains labels that are unaware to @@ -249,33 +310,43 @@ def test_failed_result_resolution(): """ class RandomPerson(BasePerson): - randomness = neomodel.FloatProperty(default=random.random) + randomness = FloatProperty(default=random.random) # A Technical Person... - A = TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"})[0] + A = ( + TechnicalPerson.get_or_create( + {"name": "Grumpy", "expertise": "Grumpiness"} + ) + )[0] # A Random Person... - B = RandomPerson.get_or_create({"name": "Mad Hatter"})[0] + B = (RandomPerson.get_or_create({"name": "Mad Hatter"}))[0] A.friends_with.connect(B) # Simulate the condition where the definition of class RandomPerson is not # known yet. - del neomodel.adb._NODE_CLASS_REGISTRY[frozenset(["RandomPerson", "BasePerson"])] + del db._NODE_CLASS_REGISTRY[frozenset(["RandomPerson", "BasePerson"])] # Now try to instantiate a RandomPerson - A = TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"})[0] + A = ( + TechnicalPerson.get_or_create( + {"name": "Grumpy", "expertise": "Grumpiness"} + ) + )[0] with pytest.raises( - neomodel.exceptions.NodeClassNotDefined, + NodeClassNotDefined, match=r"Node with labels .* does not resolve to any of the known objects.*", ): - for some_friend in A.friends_with: + friends = A.friends_with.all() + for some_friend in friends: print(some_friend.name) A.delete() B.delete() +@mark_sync_test def test_node_label_mismatch(): """ A Neo4j driver node FROM the database contains a superset of the known @@ -283,19 +354,25 @@ def test_node_label_mismatch(): """ class SuperTechnicalPerson(TechnicalPerson): - superness = neomodel.FloatProperty(default=1.0) + superness = FloatProperty(default=1.0) class UltraTechnicalPerson(SuperTechnicalPerson): - ultraness = neomodel.FloatProperty(default=3.1415928) + ultraness = FloatProperty(default=3.1415928) # Create a TechnicalPerson... - A = TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"})[0] + A = ( + TechnicalPerson.get_or_create( + {"name": "Grumpy", "expertise": "Grumpiness"} + ) + )[0] # ...that is connected to an UltraTechnicalPerson - F = UltraTechnicalPerson(name="Chewbaka", expertise="Aarrr wgh ggwaaah").save() + F = UltraTechnicalPerson( + name="Chewbaka", expertise="Aarrr wgh ggwaaah" + ).save() A.friends_with.connect(F) # Forget about the UltraTechnicalPerson - del neomodel.adb._NODE_CLASS_REGISTRY[ + del db._NODE_CLASS_REGISTRY[ frozenset( [ "UltraTechnicalPerson", @@ -309,15 +386,20 @@ class UltraTechnicalPerson(SuperTechnicalPerson): # Recall a TechnicalPerson and enumerate its friends. # One of them is UltraTechnicalPerson which would be returned as a valid # node to a friends_with query but is currently unknown to the node class registry. - A = TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"})[0] - with pytest.raises(neomodel.exceptions.NodeClassNotDefined): - for some_friend in A.friends_with: + A = ( + TechnicalPerson.get_or_create( + {"name": "Grumpy", "expertise": "Grumpiness"} + ) + )[0] + with pytest.raises(NodeClassNotDefined): + friends = A.friends_with.all() + for some_friend in friends: print(some_friend.name) def test_attempted_class_redefinition(): """ - A neomodel.StructuredNode class is attempted to be redefined. + A StructuredNode class is attempted to be redefined. """ def redefine_class_locally(): @@ -325,15 +407,16 @@ def redefine_class_locally(): # SomePerson here. # The internal structure of the SomePerson entity does not matter at all here. class SomePerson(BaseOtherPerson): - uid = neomodel.UniqueIdProperty() + uid = UniqueIdProperty() with pytest.raises( - neomodel.exceptions.NodeClassAlreadyDefined, + NodeClassAlreadyDefined, match=r"Class .* with labels .* already defined:.*", ): redefine_class_locally() +@mark_sync_test def test_relationship_result_resolution(): """ A query returning a "Relationship" object can now instantiate it to a data model class @@ -350,7 +433,7 @@ def test_relationship_result_resolution(): C.friends_with.connect(D) D.friends_with.connect(E) - query_data = neomodel.adb.cypher_query( + query_data = db.cypher_query( "MATCH (a:PilotPerson)-[r:FRIENDS_WITH]->(b:PilotPerson) " "WHERE a.airplane='Gee Bee Model R' and b.airplane='Gee Bee Model R' " "RETURN DISTINCT r", @@ -361,6 +444,7 @@ def test_relationship_result_resolution(): assert isinstance(query_data[0][0][0], PersonalRelationship) +@mark_sync_test def test_properly_inherited_relationship(): """ A relationship class extends an existing relationship model that must extended the same previously associated @@ -370,11 +454,11 @@ def test_properly_inherited_relationship(): # Extends an existing relationship by adding the "relationship_strength" attribute. # `ExtendedPersonalRelationship` will now substitute `PersonalRelationship` EVERYWHERE in the system. class ExtendedPersonalRelationship(PersonalRelationship): - relationship_strength = neomodel.FloatProperty(default=random.random) + relationship_strength = FloatProperty(default=random.random) # Extends SomePerson, establishes "enriched" relationships with any BaseOtherPerson class ExtendedSomePerson(SomePerson): - friends_with = neomodel.AsyncRelationshipTo( + friends_with = RelationshipTo( "BaseOtherPerson", "FRIENDS_WITH", model=ExtendedPersonalRelationship, @@ -388,7 +472,7 @@ class ExtendedSomePerson(SomePerson): A.friends_with.connect(B) A.friends_with.connect(C) - query_data = neomodel.adb.cypher_query( + query_data = db.cypher_query( "MATCH (:ExtendedSomePerson)-[r:FRIENDS_WITH]->(:ExtendedSomePerson) " "RETURN DISTINCT r", resolve_objects=True, @@ -403,20 +487,21 @@ def test_improperly_inherited_relationship(): :return: """ - class NewRelationship(neomodel.AsyncStructuredRel): - profile_match_factor = neomodel.FloatProperty() + class NewRelationship(StructuredRel): + profile_match_factor = FloatProperty() with pytest.raises( - neomodel.RelationshipClassRedefined, + RelationshipClassRedefined, match=r"Relationship of type .* redefined as .*", ): class NewSomePerson(SomePerson): - friends_with = neomodel.AsyncRelationshipTo( + friends_with = RelationshipTo( "BaseOtherPerson", "FRIENDS_WITH", model=NewRelationship ) +@mark_sync_test def test_resolve_inexistent_relationship(): """ Attempting to resolve an inexistent relationship should raise an exception @@ -424,13 +509,13 @@ def test_resolve_inexistent_relationship(): """ # Forget about the FRIENDS_WITH Relationship. - del neomodel.adb._NODE_CLASS_REGISTRY[frozenset(["FRIENDS_WITH"])] + del db._NODE_CLASS_REGISTRY[frozenset(["FRIENDS_WITH"])] with pytest.raises( - neomodel.RelationshipClassNotDefined, + RelationshipClassNotDefined, match=r"Relationship of type .* does not resolve to any of the known objects.*", ): - query_data = neomodel.adb.cypher_query( + query_data = db.cypher_query( "MATCH (:ExtendedSomePerson)-[r:FRIENDS_WITH]->(:ExtendedSomePerson) " "RETURN DISTINCT r", resolve_objects=True, diff --git a/test/test_issue600.py b/test/sync/test_issue600.py similarity index 60% rename from test/test_issue600.py rename to test/sync/test_issue600.py index 377dc700..d88d8eb0 100644 --- a/test/test_issue600.py +++ b/test/sync/test_issue600.py @@ -4,13 +4,8 @@ The issue is outlined here: https://github.com/neo4j-contrib/neomodel/issues/600 """ -import datetime -import os -import random - -import pytest - -import neomodel +from test._async_compat import mark_sync_test +from neomodel import StructuredNode, Relationship, StructuredRel try: basestring @@ -18,7 +13,7 @@ basestring = str -class Class1(neomodel.AsyncStructuredRel): +class Class1(StructuredRel): pass @@ -30,36 +25,37 @@ class SubClass2(Class1): pass -class RelationshipDefinerSecondSibling(neomodel.AsyncStructuredNode): - rel_1 = neomodel.AsyncRelationship( +class RelationshipDefinerSecondSibling(StructuredNode): + rel_1 = Relationship( "RelationshipDefinerSecondSibling", "SOME_REL_LABEL", model=Class1 ) - rel_2 = neomodel.AsyncRelationship( + rel_2 = Relationship( "RelationshipDefinerSecondSibling", "SOME_REL_LABEL", model=SubClass1 ) - rel_3 = neomodel.AsyncRelationship( + rel_3 = Relationship( "RelationshipDefinerSecondSibling", "SOME_REL_LABEL", model=SubClass2 ) -class RelationshipDefinerParentLast(neomodel.AsyncStructuredNode): - rel_2 = neomodel.AsyncRelationship( +class RelationshipDefinerParentLast(StructuredNode): + rel_2 = Relationship( "RelationshipDefinerParentLast", "SOME_REL_LABEL", model=SubClass1 ) - rel_3 = neomodel.AsyncRelationship( + rel_3 = Relationship( "RelationshipDefinerParentLast", "SOME_REL_LABEL", model=SubClass2 ) - rel_1 = neomodel.AsyncRelationship( + rel_1 = Relationship( "RelationshipDefinerParentLast", "SOME_REL_LABEL", model=Class1 ) # Test cases +@mark_sync_test def test_relationship_definer_second_sibling(): # Create a few entities - A = RelationshipDefinerSecondSibling.get_or_create({})[0] - B = RelationshipDefinerSecondSibling.get_or_create({})[0] - C = RelationshipDefinerSecondSibling.get_or_create({})[0] + A = (RelationshipDefinerSecondSibling.get_or_create({}))[0] + B = (RelationshipDefinerSecondSibling.get_or_create({}))[0] + C = (RelationshipDefinerSecondSibling.get_or_create({}))[0] # Add connections A.rel_1.connect(B) @@ -72,11 +68,12 @@ def test_relationship_definer_second_sibling(): C.delete() +@mark_sync_test def test_relationship_definer_parent_last(): # Create a few entities - A = RelationshipDefinerParentLast.get_or_create({})[0] - B = RelationshipDefinerParentLast.get_or_create({})[0] - C = RelationshipDefinerParentLast.get_or_create({})[0] + A = (RelationshipDefinerParentLast.get_or_create({}))[0] + B = (RelationshipDefinerParentLast.get_or_create({}))[0] + C = (RelationshipDefinerParentLast.get_or_create({}))[0] # Add connections A.rel_1.connect(B) diff --git a/test/test_issue112.py b/test/test_issue112.py deleted file mode 100644 index d20b53ac..00000000 --- a/test/test_issue112.py +++ /dev/null @@ -1,16 +0,0 @@ -from neomodel import AsyncRelationshipTo, AsyncStructuredNode - - -class SomeModel(AsyncStructuredNode): - test = AsyncRelationshipTo("SomeModel", "SELF") - - -def test_len_relationship(): - t1 = SomeModel().save() - t2 = SomeModel().save() - - t1.test.connect(t2) - l = len(t1.test.all()) - - assert l - assert l == 1 diff --git a/test/test_label_drop.py b/test/test_label_drop.py deleted file mode 100644 index 1e8f7112..00000000 --- a/test/test_label_drop.py +++ /dev/null @@ -1,46 +0,0 @@ -from neo4j.exceptions import ClientError - -from neomodel import AsyncStructuredNode, StringProperty, config -from neomodel.async_.core import adb - -config.AUTO_INSTALL_LABELS = True - - -class ConstraintAndIndex(AsyncStructuredNode): - name = StringProperty(unique_index=True) - last_name = StringProperty(index=True) - - -def test_drop_labels(): - constraints_before = adb.list_constraints() - indexes_before = adb.list_indexes(exclude_token_lookup=True) - - assert len(constraints_before) > 0 - assert len(indexes_before) > 0 - - adb.remove_all_labels() - - constraints = adb.list_constraints() - indexes = adb.list_indexes(exclude_token_lookup=True) - - assert len(constraints) == 0 - assert len(indexes) == 0 - - # Recreating all old constraints and indexes - for constraint in constraints_before: - constraint_type_clause = "UNIQUE" - if constraint["type"] == "NODE_PROPERTY_EXISTENCE": - constraint_type_clause = "NOT NULL" - elif constraint["type"] == "NODE_KEY": - constraint_type_clause = "NODE KEY" - - adb.cypher_query( - f'CREATE CONSTRAINT {constraint["name"]} FOR (n:{constraint["labelsOrTypes"][0]}) REQUIRE n.{constraint["properties"][0]} IS {constraint_type_clause}' - ) - for index in indexes_before: - try: - adb.cypher_query( - f'CREATE INDEX {index["name"]} FOR (n:{index["labelsOrTypes"][0]}) ON (n.{index["properties"][0]})' - ) - except ClientError: - pass From 8bba824d574c3b80458116d27c9e22fc61d09606 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Tue, 2 Jan 2024 12:53:00 +0100 Subject: [PATCH 25/73] More tests --- neomodel/async_/core.py | 5 +- neomodel/sync_/core.py | 1 + test/async_/test_dbms_awareness.py | 10 +- test/async_/test_label_drop.py | 47 ++++++ test/{ => async_}/test_label_install.py | 91 ++++++------ test/sync/test_label_drop.py | 47 ++++++ test/sync/test_label_install.py | 184 ++++++++++++++++++++++++ 7 files changed, 328 insertions(+), 57 deletions(-) create mode 100644 test/async_/test_label_drop.py rename test/{ => async_}/test_label_install.py (66%) create mode 100644 test/sync/test_label_drop.py create mode 100644 test/sync/test_label_install.py diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py index 42890261..39bddc0a 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/core.py @@ -541,7 +541,8 @@ async def list_constraints(self) -> Sequence[dict]: return constraints_as_dict - def version_is_higher_than(self, version_tag: str) -> bool: + @ensure_connection + async def version_is_higher_than(self, version_tag: str) -> bool: """Returns true if the database version is higher or equal to a given tag Args: @@ -750,7 +751,7 @@ async def _create_relationship_index( async def _create_relationship_constraint( self, relationship_type: str, property_name: str, stdout ): - if self.version_is_higher_than("5.7"): + if await self.version_is_higher_than("5.7"): try: await self.cypher_query( f"""CREATE CONSTRAINT constraint_unique_{relationship_type}_{property_name} diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index 8b0e1a7e..42dd0c38 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -539,6 +539,7 @@ def list_constraints(self) -> Sequence[dict]: return constraints_as_dict + @ensure_connection def version_is_higher_than(self, version_tag: str) -> bool: """Returns true if the database version is higher or equal to a given tag diff --git a/test/async_/test_dbms_awareness.py b/test/async_/test_dbms_awareness.py index 7399666a..f9f7b7b2 100644 --- a/test/async_/test_dbms_awareness.py +++ b/test/async_/test_dbms_awareness.py @@ -10,12 +10,12 @@ ) async def test_version_awareness(): assert adb.database_version == "5.7.0" - assert adb.version_is_higher_than("5.7") - assert adb.version_is_higher_than("5.6.0") - assert adb.version_is_higher_than("5") - assert adb.version_is_higher_than("4") + assert await adb.version_is_higher_than("5.7") + assert await adb.version_is_higher_than("5.6.0") + assert await adb.version_is_higher_than("5") + assert await adb.version_is_higher_than("4") - assert not adb.version_is_higher_than("5.8") + assert not await adb.version_is_higher_than("5.8") @mark_async_test diff --git a/test/async_/test_label_drop.py b/test/async_/test_label_drop.py new file mode 100644 index 00000000..834f47e3 --- /dev/null +++ b/test/async_/test_label_drop.py @@ -0,0 +1,47 @@ +from neo4j.exceptions import ClientError +from test._async_compat import mark_async_test + +from neomodel import AsyncStructuredNode, StringProperty +from neomodel.async_.core import adb + + +class ConstraintAndIndex(AsyncStructuredNode): + name = StringProperty(unique_index=True) + last_name = StringProperty(index=True) + + +@mark_async_test +async def test_drop_labels(): + await adb.install_labels(ConstraintAndIndex) + constraints_before = await adb.list_constraints() + indexes_before = await adb.list_indexes(exclude_token_lookup=True) + + assert len(constraints_before) > 0 + assert len(indexes_before) > 0 + + await adb.remove_all_labels() + + constraints = await adb.list_constraints() + indexes = await adb.list_indexes(exclude_token_lookup=True) + + assert len(constraints) == 0 + assert len(indexes) == 0 + + # Recreating all old constraints and indexes + for constraint in constraints_before: + constraint_type_clause = "UNIQUE" + if constraint["type"] == "NODE_PROPERTY_EXISTENCE": + constraint_type_clause = "NOT NULL" + elif constraint["type"] == "NODE_KEY": + constraint_type_clause = "NODE KEY" + + await adb.cypher_query( + f'CREATE CONSTRAINT {constraint["name"]} FOR (n:{constraint["labelsOrTypes"][0]}) REQUIRE n.{constraint["properties"][0]} IS {constraint_type_clause}' + ) + for index in indexes_before: + try: + await adb.cypher_query( + f'CREATE INDEX {index["name"]} FOR (n:{index["labelsOrTypes"][0]}) ON (n.{index["properties"][0]})' + ) + except ClientError: + pass diff --git a/test/test_label_install.py b/test/async_/test_label_install.py similarity index 66% rename from test/test_label_install.py rename to test/async_/test_label_install.py index 0061ed71..09520b70 100644 --- a/test/test_label_install.py +++ b/test/async_/test_label_install.py @@ -1,4 +1,5 @@ import pytest +from test._async_compat import mark_async_test from neomodel import ( AsyncRelationshipTo, @@ -6,13 +7,10 @@ AsyncStructuredRel, StringProperty, UniqueIdProperty, - config, ) from neomodel.async_.core import adb from neomodel.exceptions import ConstraintValidationFailed, FeatureNotSupported -config.AUTO_INSTALL_LABELS = False - class NodeWithIndex(AsyncStructuredNode): name = StringProperty(index=True) @@ -45,61 +43,50 @@ class SomeNotUniqueNode(AsyncStructuredNode): id_ = UniqueIdProperty(db_property="id") -config.AUTO_INSTALL_LABELS = True - - -def test_labels_were_not_installed(): - bob = NodeWithConstraint(name="bob").save() - bob2 = NodeWithConstraint(name="bob").save() - bob3 = NodeWithConstraint(name="bob").save() - assert bob.element_id != bob3.element_id - - for n in NodeWithConstraint.nodes.all(): - n.delete() - - -def test_install_all(): - adb.drop_constraints() - adb.install_labels(AbstractNode) +@mark_async_test +async def test_install_all(): + await adb.drop_constraints() + await adb.install_labels(AbstractNode) # run install all labels - adb.install_all_labels() + await adb.install_all_labels() - indexes = adb.list_indexes() + indexes = await adb.list_indexes() index_names = [index["name"] for index in indexes] assert "index_INDEXED_REL_indexed_rel_prop" in index_names - constraints = adb.list_constraints() + constraints = await adb.list_constraints() constraint_names = [constraint["name"] for constraint in constraints] assert "constraint_unique_NodeWithConstraint_name" in constraint_names assert "constraint_unique_SomeNotUniqueNode_id" in constraint_names # remove constraint for above test - _drop_constraints_for_label_and_property("NoConstraintsSetup", "name") + await _drop_constraints_for_label_and_property("NoConstraintsSetup", "name") -def test_install_label_twice(capsys): +@mark_async_test +async def test_install_label_twice(capsys): expected_std_out = ( "{code: Neo.ClientError.Schema.EquivalentSchemaRuleAlreadyExists}" ) - adb.install_labels(AbstractNode) - adb.install_labels(AbstractNode) + await adb.install_labels(AbstractNode) + await adb.install_labels(AbstractNode) - adb.install_labels(NodeWithIndex) - adb.install_labels(NodeWithIndex, quiet=False) + await adb.install_labels(NodeWithIndex) + await adb.install_labels(NodeWithIndex, quiet=False) captured = capsys.readouterr() assert expected_std_out in captured.out - adb.install_labels(NodeWithConstraint) - adb.install_labels(NodeWithConstraint, quiet=False) + await adb.install_labels(NodeWithConstraint) + await adb.install_labels(NodeWithConstraint, quiet=False) captured = capsys.readouterr() assert expected_std_out in captured.out - adb.install_labels(OtherNodeWithRelationship) - adb.install_labels(OtherNodeWithRelationship, quiet=False) + await adb.install_labels(OtherNodeWithRelationship) + await adb.install_labels(OtherNodeWithRelationship, quiet=False) captured = capsys.readouterr() assert expected_std_out in captured.out - if adb.version_is_higher_than("5.7"): + if await adb.version_is_higher_than("5.7"): class UniqueIndexRelationship(AsyncStructuredRel): unique_index_rel_prop = StringProperty(unique_index=True) @@ -109,24 +96,25 @@ class OtherNodeWithUniqueIndexRelationship(AsyncStructuredNode): NodeWithRelationship, "UNIQUE_INDEX_REL", model=UniqueIndexRelationship ) - adb.install_labels(OtherNodeWithUniqueIndexRelationship) - adb.install_labels(OtherNodeWithUniqueIndexRelationship, quiet=False) + await adb.install_labels(OtherNodeWithUniqueIndexRelationship) + await adb.install_labels(OtherNodeWithUniqueIndexRelationship, quiet=False) captured = capsys.readouterr() assert expected_std_out in captured.out -def test_install_labels_db_property(capsys): - adb.drop_constraints() - adb.install_labels(SomeNotUniqueNode, quiet=False) +@mark_async_test +async def test_install_labels_db_property(capsys): + await adb.drop_constraints() + await adb.install_labels(SomeNotUniqueNode, quiet=False) captured = capsys.readouterr() assert "id" in captured.out # make sure that the id_ constraint doesn't exist - constraint_names = _drop_constraints_for_label_and_property( + constraint_names = await _drop_constraints_for_label_and_property( "SomeNotUniqueNode", "id_" ) assert constraint_names == [] # make sure the id constraint exists and can be removed - _drop_constraints_for_label_and_property("SomeNotUniqueNode", "id") + await _drop_constraints_for_label_and_property("SomeNotUniqueNode", "id") @pytest.mark.skipif( @@ -151,8 +139,9 @@ class NodeWithUniqueIndexRelationship(AsyncStructuredNode): ) +@mark_async_test @pytest.mark.skipif(not adb.version_is_higher_than("5.7"), reason="Supported from 5.7") -def test_relationship_unique_index(): +async def test_relationship_unique_index(): class UniqueIndexRelationshipBis(AsyncStructuredRel): name = StringProperty(unique_index=True) @@ -166,21 +155,23 @@ class NodeWithUniqueIndexRelationship(AsyncStructuredNode): model=UniqueIndexRelationshipBis, ) - adb.install_labels(UniqueIndexRelationshipBis) - node1 = NodeWithUniqueIndexRelationship().save() - node2 = TargetNodeForUniqueIndexRelationship().save() - node3 = TargetNodeForUniqueIndexRelationship().save() - rel1 = node1.has_rel.connect(node2, {"name": "rel1"}) + await adb.install_labels(NodeWithUniqueIndexRelationship) + node1 = await NodeWithUniqueIndexRelationship().save() + node2 = await TargetNodeForUniqueIndexRelationship().save() + node3 = await TargetNodeForUniqueIndexRelationship().save() + rel1 = await node1.has_rel.connect(node2, {"name": "rel1"}) with pytest.raises( ConstraintValidationFailed, match=r".*already exists with type `UNIQUE_INDEX_REL_BIS` and property `name`.*", ): - rel2 = node1.has_rel.connect(node3, {"name": "rel1"}) + rel2 = await node1.has_rel.connect(node3, {"name": "rel1"}) -def _drop_constraints_for_label_and_property(label: str = None, property: str = None): - results, meta = adb.cypher_query("SHOW CONSTRAINTS") +async def _drop_constraints_for_label_and_property( + label: str = None, property: str = None +): + results, meta = await adb.cypher_query("SHOW CONSTRAINTS") results_as_dict = [dict(zip(meta, row)) for row in results] constraint_names = [ constraint @@ -188,6 +179,6 @@ def _drop_constraints_for_label_and_property(label: str = None, property: str = if constraint["labelsOrTypes"] == label and constraint["properties"] == property ] for constraint_name in constraint_names: - adb.cypher_query(f"DROP CONSTRAINT {constraint_name}") + await adb.cypher_query(f"DROP CONSTRAINT {constraint_name}") return constraint_names diff --git a/test/sync/test_label_drop.py b/test/sync/test_label_drop.py new file mode 100644 index 00000000..016f72c1 --- /dev/null +++ b/test/sync/test_label_drop.py @@ -0,0 +1,47 @@ +from neo4j.exceptions import ClientError +from test._async_compat import mark_sync_test + +from neomodel import StructuredNode, StringProperty +from neomodel.sync_.core import db + + +class ConstraintAndIndex(StructuredNode): + name = StringProperty(unique_index=True) + last_name = StringProperty(index=True) + + +@mark_sync_test +def test_drop_labels(): + db.install_labels(ConstraintAndIndex) + constraints_before = db.list_constraints() + indexes_before = db.list_indexes(exclude_token_lookup=True) + + assert len(constraints_before) > 0 + assert len(indexes_before) > 0 + + db.remove_all_labels() + + constraints = db.list_constraints() + indexes = db.list_indexes(exclude_token_lookup=True) + + assert len(constraints) == 0 + assert len(indexes) == 0 + + # Recreating all old constraints and indexes + for constraint in constraints_before: + constraint_type_clause = "UNIQUE" + if constraint["type"] == "NODE_PROPERTY_EXISTENCE": + constraint_type_clause = "NOT NULL" + elif constraint["type"] == "NODE_KEY": + constraint_type_clause = "NODE KEY" + + db.cypher_query( + f'CREATE CONSTRAINT {constraint["name"]} FOR (n:{constraint["labelsOrTypes"][0]}) REQUIRE n.{constraint["properties"][0]} IS {constraint_type_clause}' + ) + for index in indexes_before: + try: + db.cypher_query( + f'CREATE INDEX {index["name"]} FOR (n:{index["labelsOrTypes"][0]}) ON (n.{index["properties"][0]})' + ) + except ClientError: + pass diff --git a/test/sync/test_label_install.py b/test/sync/test_label_install.py new file mode 100644 index 00000000..22d70348 --- /dev/null +++ b/test/sync/test_label_install.py @@ -0,0 +1,184 @@ +import pytest +from test._async_compat import mark_sync_test + +from neomodel import ( + RelationshipTo, + StructuredNode, + StructuredRel, + StringProperty, + UniqueIdProperty, +) +from neomodel.sync_.core import db +from neomodel.exceptions import ConstraintValidationFailed, FeatureNotSupported + + +class NodeWithIndex(StructuredNode): + name = StringProperty(index=True) + + +class NodeWithConstraint(StructuredNode): + name = StringProperty(unique_index=True) + + +class NodeWithRelationship(StructuredNode): + ... + + +class IndexedRelationship(StructuredRel): + indexed_rel_prop = StringProperty(index=True) + + +class OtherNodeWithRelationship(StructuredNode): + has_rel = RelationshipTo( + NodeWithRelationship, "INDEXED_REL", model=IndexedRelationship + ) + + +class AbstractNode(StructuredNode): + __abstract_node__ = True + name = StringProperty(unique_index=True) + + +class SomeNotUniqueNode(StructuredNode): + id_ = UniqueIdProperty(db_property="id") + + +@mark_sync_test +def test_install_all(): + db.drop_constraints() + db.install_labels(AbstractNode) + # run install all labels + db.install_all_labels() + + indexes = db.list_indexes() + index_names = [index["name"] for index in indexes] + assert "index_INDEXED_REL_indexed_rel_prop" in index_names + + constraints = db.list_constraints() + constraint_names = [constraint["name"] for constraint in constraints] + assert "constraint_unique_NodeWithConstraint_name" in constraint_names + assert "constraint_unique_SomeNotUniqueNode_id" in constraint_names + + # remove constraint for above test + _drop_constraints_for_label_and_property("NoConstraintsSetup", "name") + + +@mark_sync_test +def test_install_label_twice(capsys): + expected_std_out = ( + "{code: Neo.ClientError.Schema.EquivalentSchemaRuleAlreadyExists}" + ) + db.install_labels(AbstractNode) + db.install_labels(AbstractNode) + + db.install_labels(NodeWithIndex) + db.install_labels(NodeWithIndex, quiet=False) + captured = capsys.readouterr() + assert expected_std_out in captured.out + + db.install_labels(NodeWithConstraint) + db.install_labels(NodeWithConstraint, quiet=False) + captured = capsys.readouterr() + assert expected_std_out in captured.out + + db.install_labels(OtherNodeWithRelationship) + db.install_labels(OtherNodeWithRelationship, quiet=False) + captured = capsys.readouterr() + assert expected_std_out in captured.out + + if db.version_is_higher_than("5.7"): + + class UniqueIndexRelationship(StructuredRel): + unique_index_rel_prop = StringProperty(unique_index=True) + + class OtherNodeWithUniqueIndexRelationship(StructuredNode): + has_rel = RelationshipTo( + NodeWithRelationship, "UNIQUE_INDEX_REL", model=UniqueIndexRelationship + ) + + db.install_labels(OtherNodeWithUniqueIndexRelationship) + db.install_labels(OtherNodeWithUniqueIndexRelationship, quiet=False) + captured = capsys.readouterr() + assert expected_std_out in captured.out + + +@mark_sync_test +def test_install_labels_db_property(capsys): + db.drop_constraints() + db.install_labels(SomeNotUniqueNode, quiet=False) + captured = capsys.readouterr() + assert "id" in captured.out + # make sure that the id_ constraint doesn't exist + constraint_names = _drop_constraints_for_label_and_property( + "SomeNotUniqueNode", "id_" + ) + assert constraint_names == [] + # make sure the id constraint exists and can be removed + _drop_constraints_for_label_and_property("SomeNotUniqueNode", "id") + + +@pytest.mark.skipif( + db.version_is_higher_than("5.7"), reason="Not supported before 5.7" +) +def test_relationship_unique_index_not_supported(): + class UniqueIndexRelationship(StructuredRel): + name = StringProperty(unique_index=True) + + class TargetNodeForUniqueIndexRelationship(StructuredNode): + pass + + with pytest.raises( + FeatureNotSupported, match=r".*Please upgrade to Neo4j 5.7 or higher" + ): + + class NodeWithUniqueIndexRelationship(StructuredNode): + has_rel = RelationshipTo( + TargetNodeForUniqueIndexRelationship, + "UNIQUE_INDEX_REL", + model=UniqueIndexRelationship, + ) + + +@mark_sync_test +@pytest.mark.skipif(not db.version_is_higher_than("5.7"), reason="Supported from 5.7") +def test_relationship_unique_index(): + class UniqueIndexRelationshipBis(StructuredRel): + name = StringProperty(unique_index=True) + + class TargetNodeForUniqueIndexRelationship(StructuredNode): + pass + + class NodeWithUniqueIndexRelationship(StructuredNode): + has_rel = RelationshipTo( + TargetNodeForUniqueIndexRelationship, + "UNIQUE_INDEX_REL_BIS", + model=UniqueIndexRelationshipBis, + ) + + db.install_labels(NodeWithUniqueIndexRelationship) + node1 = NodeWithUniqueIndexRelationship().save() + node2 = TargetNodeForUniqueIndexRelationship().save() + node3 = TargetNodeForUniqueIndexRelationship().save() + rel1 = node1.has_rel.connect(node2, {"name": "rel1"}) + + with pytest.raises( + ConstraintValidationFailed, + match=r".*already exists with type `UNIQUE_INDEX_REL_BIS` and property `name`.*", + ): + rel2 = node1.has_rel.connect(node3, {"name": "rel1"}) + + +def _drop_constraints_for_label_and_property( + label: str = None, property: str = None +): + results, meta = db.cypher_query("SHOW CONSTRAINTS") + results_as_dict = [dict(zip(meta, row)) for row in results] + constraint_names = [ + constraint + for constraint in results_as_dict + if constraint["labelsOrTypes"] == label and constraint["properties"] == property + ] + for constraint_name in constraint_names: + db.cypher_query(f"DROP CONSTRAINT {constraint_name}") + + return constraint_names From bd7c2981117e852260fc4217b34f05f9a8b95d4e Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Wed, 3 Jan 2024 11:48:26 +0100 Subject: [PATCH 26/73] More tests and fixes --- neomodel/async_/core.py | 12 +- neomodel/async_/match.py | 4 +- neomodel/async_/relationship.py | 9 +- neomodel/async_/relationship_manager.py | 6 +- neomodel/sync_/core.py | 10 +- neomodel/sync_/relationship.py | 9 +- neomodel/sync_/relationship_manager.py | 6 +- test/async_/__init__.py | 0 test/async_/test_match_api.py | 543 ++++++++++++++++++ test/{ => async_}/test_migration_neo4j_5.py | 14 +- test/async_/test_models.py | 360 ++++++++++++ test/async_/test_multiprocessing.py | 25 + test/async_/test_paths.py | 95 +++ test/async_/test_properties.py | 455 +++++++++++++++ test/async_/test_relationship_models.py | 166 ++++++ test/async_/test_relationships.py | 209 +++++++ .../test_relative_relationships.py | 12 +- test/sync/__init__.py | 0 test/{ => sync}/test_match_api.py | 138 +++-- test/sync/test_migration_neo4j_5.py | 78 +++ test/{ => sync}/test_models.py | 75 ++- test/{ => sync}/test_multiprocessing.py | 12 +- test/{ => sync}/test_paths.py | 32 +- test/{ => sync}/test_properties.py | 56 +- test/{ => sync}/test_relationship_models.py | 33 +- test/{ => sync}/test_relationships.py | 54 +- test/sync/test_relative_relationships.py | 23 + test/test_scripts.py | 40 +- test/test_transactions.py | 182 ------ 29 files changed, 2262 insertions(+), 396 deletions(-) create mode 100644 test/async_/__init__.py create mode 100644 test/async_/test_match_api.py rename test/{ => async_}/test_migration_neo4j_5.py (84%) create mode 100644 test/async_/test_models.py create mode 100644 test/async_/test_multiprocessing.py create mode 100644 test/async_/test_paths.py create mode 100644 test/async_/test_properties.py create mode 100644 test/async_/test_relationship_models.py create mode 100644 test/async_/test_relationships.py rename test/{ => async_}/test_relative_relationships.py (63%) create mode 100644 test/sync/__init__.py rename test/{ => sync}/test_match_api.py (82%) create mode 100644 test/sync/test_migration_neo4j_5.py rename test/{ => sync}/test_models.py (84%) rename test/{ => sync}/test_multiprocessing.py (57%) rename test/{ => sync}/test_paths.py (75%) rename test/{ => sync}/test_properties.py (89%) rename test/{ => sync}/test_relationship_models.py (86%) rename test/{ => sync}/test_relationships.py (79%) create mode 100644 test/sync/test_relative_relationships.py delete mode 100644 test/test_transactions.py diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py index 39bddc0a..3deee79f 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/core.py @@ -248,7 +248,7 @@ async def begin(self, access_mode=None, **parameters): and self._active_transaction is not None ): raise SystemError("Transaction in progress") - self._session: AsyncSession = await self.driver.session( + self._session: AsyncSession = self.driver.session( default_access_mode=access_mode, database=self._database_name, impersonated_user=self.impersonated_user, @@ -1435,9 +1435,10 @@ async def labels(self): :rtype: list """ self._pre_action_check("labels") - return await self.cypher( + result = await self.cypher( f"MATCH (n) WHERE {adb.get_id_method()}(n)=$self " "RETURN labels(n)" - )[0][0][0] + ) + return result[0][0][0] def _pre_action_check(self, action): if hasattr(self, "deleted") and self.deleted: @@ -1455,9 +1456,10 @@ async def refresh(self): """ self._pre_action_check("refresh") if hasattr(self, "element_id"): - request = await self.cypher( + results = await self.cypher( f"MATCH (n) WHERE {adb.get_id_method()}(n)=$self RETURN n" - )[0] + ) + request = results[0] if not request or not request[0]: raise self.__class__.DoesNotExist("Can't refresh non existent node") node = self.inflate(request[0][0]) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 2c8388ac..9b5769df 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -681,7 +681,7 @@ async def _count(self): results, _ = await adb.cypher_query(query, self._query_params) return int(results[0][0]) - def _contains(self, node_element_id): + async def _contains(self, node_element_id): # inject id = into ast if not self._ast.return_clause: print(self._ast.additional_return) @@ -690,7 +690,7 @@ def _contains(self, node_element_id): place_holder = self._register_place_holder(ident + "_contains") self._ast.where.append(f"{adb.get_id_method()}({ident}) = ${place_holder}") self._query_params[place_holder] = node_element_id - return self._count() >= 1 + return await self._count() >= 1 async def _execute(self, lazy=False): if lazy: diff --git a/neomodel/async_/relationship.py b/neomodel/async_/relationship.py index eab91249..65c51627 100644 --- a/neomodel/async_/relationship.py +++ b/neomodel/async_/relationship.py @@ -123,7 +123,7 @@ async def start_node(self): :return: StructuredNode """ - test = await adb.cypher_query( + results = await adb.cypher_query( f""" MATCH (aNode) WHERE {adb.get_id_method()}(aNode)=$start_node_element_id @@ -132,7 +132,7 @@ async def start_node(self): {"start_node_element_id": self._start_node_element_id}, resolve_objects=True, ) - return test[0][0][0] + return results[0][0][0] async def end_node(self): """ @@ -140,7 +140,7 @@ async def end_node(self): :return: StructuredNode """ - return await adb.cypher_query( + results = await adb.cypher_query( f""" MATCH (aNode) WHERE {adb.get_id_method()}(aNode)=$end_node_element_id @@ -148,7 +148,8 @@ async def end_node(self): """, {"end_node_element_id": self._end_node_element_id}, resolve_objects=True, - )[0][0][0] + ) + return results[0][0][0] @classmethod def inflate(cls, rel): diff --git a/neomodel/async_/relationship_manager.py b/neomodel/async_/relationship_manager.py index 35d378ba..5bfe8dc0 100644 --- a/neomodel/async_/relationship_manager.py +++ b/neomodel/async_/relationship_manager.py @@ -170,7 +170,8 @@ async def relationship(self, node): + my_rel + f" WHERE {adb.get_id_method()}(them)=$them and {adb.get_id_method()}(us)=$self RETURN r LIMIT 1" ) - rels = await self.source.cypher(q, {"them": node.element_id})[0] + results = await self.source.cypher(q, {"them": node.element_id}) + rels = results[0] if not rels: return @@ -190,7 +191,8 @@ async def all_relationships(self, node): my_rel = _rel_helper(lhs="us", rhs="them", ident="r", **self.definition) q = f"MATCH {my_rel} WHERE {adb.get_id_method()}(them)=$them and {adb.get_id_method()}(us)=$self RETURN r " - rels = await self.source.cypher(q, {"them": node.element_id})[0] + results = await self.source.cypher(q, {"them": node.element_id}) + rels = results[0] if not rels: return [] diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index 42dd0c38..9be231f7 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -1431,9 +1431,10 @@ def labels(self): :rtype: list """ self._pre_action_check("labels") - return self.cypher( + result = self.cypher( f"MATCH (n) WHERE {db.get_id_method()}(n)=$self " "RETURN labels(n)" - )[0][0][0] + ) + return result[0][0][0] def _pre_action_check(self, action): if hasattr(self, "deleted") and self.deleted: @@ -1451,9 +1452,10 @@ def refresh(self): """ self._pre_action_check("refresh") if hasattr(self, "element_id"): - request = self.cypher( + results = self.cypher( f"MATCH (n) WHERE {db.get_id_method()}(n)=$self RETURN n" - )[0] + ) + request = results[0] if not request or not request[0]: raise self.__class__.DoesNotExist("Can't refresh non existent node") node = self.inflate(request[0][0]) diff --git a/neomodel/sync_/relationship.py b/neomodel/sync_/relationship.py index c246df38..3c6aa523 100644 --- a/neomodel/sync_/relationship.py +++ b/neomodel/sync_/relationship.py @@ -123,7 +123,7 @@ def start_node(self): :return: StructuredNode """ - test = db.cypher_query( + results = db.cypher_query( f""" MATCH (aNode) WHERE {db.get_id_method()}(aNode)=$start_node_element_id @@ -132,7 +132,7 @@ def start_node(self): {"start_node_element_id": self._start_node_element_id}, resolve_objects=True, ) - return test[0][0][0] + return results[0][0][0] def end_node(self): """ @@ -140,7 +140,7 @@ def end_node(self): :return: StructuredNode """ - return db.cypher_query( + results = db.cypher_query( f""" MATCH (aNode) WHERE {db.get_id_method()}(aNode)=$end_node_element_id @@ -148,7 +148,8 @@ def end_node(self): """, {"end_node_element_id": self._end_node_element_id}, resolve_objects=True, - )[0][0][0] + ) + return results[0][0][0] @classmethod def inflate(cls, rel): diff --git a/neomodel/sync_/relationship_manager.py b/neomodel/sync_/relationship_manager.py index 54dcb2be..1d31c2ca 100644 --- a/neomodel/sync_/relationship_manager.py +++ b/neomodel/sync_/relationship_manager.py @@ -165,7 +165,8 @@ def relationship(self, node): + my_rel + f" WHERE {db.get_id_method()}(them)=$them and {db.get_id_method()}(us)=$self RETURN r LIMIT 1" ) - rels = self.source.cypher(q, {"them": node.element_id})[0] + results = self.source.cypher(q, {"them": node.element_id}) + rels = results[0] if not rels: return @@ -185,7 +186,8 @@ def all_relationships(self, node): my_rel = _rel_helper(lhs="us", rhs="them", ident="r", **self.definition) q = f"MATCH {my_rel} WHERE {db.get_id_method()}(them)=$them and {db.get_id_method()}(us)=$self RETURN r " - rels = self.source.cypher(q, {"them": node.element_id})[0] + results = self.source.cypher(q, {"them": node.element_id}) + rels = results[0] if not rels: return [] diff --git a/test/async_/__init__.py b/test/async_/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py new file mode 100644 index 00000000..4b96875c --- /dev/null +++ b/test/async_/test_match_api.py @@ -0,0 +1,543 @@ +from datetime import datetime +from test._async_compat import mark_async_test + +from pytest import raises + +from neomodel import ( + INCOMING, + AsyncRelationshipFrom, + AsyncRelationshipTo, + AsyncStructuredNode, + AsyncStructuredRel, + DateTimeProperty, + IntegerProperty, + Q, + StringProperty, +) +from neomodel.async_.match import ( + AsyncNodeSet, + AsyncQueryBuilder, + AsyncTraversal, + Optional, +) +from neomodel.exceptions import MultipleNodesReturned + + +class SupplierRel(AsyncStructuredRel): + since = DateTimeProperty(default=datetime.now) + courier = StringProperty() + + +class Supplier(AsyncStructuredNode): + name = StringProperty() + delivery_cost = IntegerProperty() + coffees = AsyncRelationshipTo("Coffee", "COFFEE SUPPLIERS") + + +class Species(AsyncStructuredNode): + name = StringProperty() + coffees = AsyncRelationshipFrom( + "Coffee", "COFFEE SPECIES", model=AsyncStructuredRel + ) + + +class Coffee(AsyncStructuredNode): + name = StringProperty(unique_index=True) + price = IntegerProperty() + suppliers = AsyncRelationshipFrom(Supplier, "COFFEE SUPPLIERS", model=SupplierRel) + species = AsyncRelationshipTo(Species, "COFFEE SPECIES", model=AsyncStructuredRel) + id_ = IntegerProperty() + + +class Extension(AsyncStructuredNode): + extension = AsyncRelationshipTo("Extension", "extension") + + +# TODO : Maybe split these tests into separate async and sync (not transpiled) +# That would allow to test "Coffee.nodes" for sync instead of Coffee.nodes.all() + + +@mark_async_test +async def test_filter_exclude_via_labels(): + await Coffee(name="Java", price=99).save() + + node_set = AsyncNodeSet(Coffee) + qb = AsyncQueryBuilder(node_set).build_ast() + + results = await qb._execute() + + assert "(coffee:Coffee)" in qb._ast.match + assert qb._ast.result_class + assert len(results) == 1 + assert isinstance(results[0], Coffee) + assert results[0].name == "Java" + + # with filter and exclude + await Coffee(name="Kenco", price=3).save() + node_set = node_set.filter(price__gt=2).exclude(price__gt=6, name="Java") + qb = AsyncQueryBuilder(node_set).build_ast() + + results = await qb._execute() + assert "(coffee:Coffee)" in qb._ast.match + assert "NOT" in qb._ast.where[0] + assert len(results) == 1 + assert results[0].name == "Kenco" + + +@mark_async_test +async def test_simple_has_via_label(): + nescafe = await Coffee(name="Nescafe", price=99).save() + tesco = await Supplier(name="Tesco", delivery_cost=2).save() + await nescafe.suppliers.connect(tesco) + + ns = AsyncNodeSet(Coffee).has(suppliers=True) + qb = AsyncQueryBuilder(ns).build_ast() + results = await qb._execute() + assert "COFFEE SUPPLIERS" in qb._ast.where[0] + assert len(results) == 1 + assert results[0].name == "Nescafe" + + await Coffee(name="nespresso", price=99).save() + ns = AsyncNodeSet(Coffee).has(suppliers=False) + qb = AsyncQueryBuilder(ns).build_ast() + results = await qb._execute() + assert len(results) > 0 + assert "NOT" in qb._ast.where[0] + + +@mark_async_test +async def test_get(): + await Coffee(name="1", price=3).save() + assert await Coffee.nodes.get(name="1") + + with raises(Coffee.DoesNotExist): + await Coffee.nodes.get(name="2") + + await Coffee(name="2", price=3).save() + + with raises(MultipleNodesReturned): + await Coffee.nodes.get(price=3) + + +@mark_async_test +async def test_simple_traverse_with_filter(): + nescafe = await Coffee(name="Nescafe2", price=99).save() + tesco = await Supplier(name="Sainsburys", delivery_cost=2).save() + await nescafe.suppliers.connect(tesco) + + qb = AsyncQueryBuilder( + AsyncNodeSet(source=nescafe).suppliers.match(since__lt=datetime.now()) + ) + + results = await qb.build_ast()._execute() + + assert qb._ast.lookup + assert qb._ast.match + assert qb._ast.return_clause.startswith("suppliers") + assert len(results) == 1 + assert results[0].name == "Sainsburys" + + +@mark_async_test +async def test_double_traverse(): + nescafe = await Coffee(name="Nescafe plus", price=99).save() + tesco = await Supplier(name="Asda", delivery_cost=2).save() + await nescafe.suppliers.connect(tesco) + await tesco.coffees.connect(await Coffee(name="Decafe", price=2).save()) + + ns = AsyncNodeSet(AsyncNodeSet(source=nescafe).suppliers.match()).coffees.match() + qb = AsyncQueryBuilder(ns).build_ast() + + results = await qb._execute() + assert len(results) == 2 + assert results[0].name == "Decafe" + assert results[1].name == "Nescafe plus" + + +@mark_async_test +async def test_count(): + await Coffee(name="Nescafe Gold", price=99).save() + count = await AsyncQueryBuilder(AsyncNodeSet(source=Coffee)).build_ast()._count() + assert count > 0 + + await Coffee(name="Kawa", price=27).save() + node_set = AsyncNodeSet(source=Coffee) + node_set.skip = 1 + node_set.limit = 1 + count = await AsyncQueryBuilder(node_set).build_ast()._count() + assert count == 1 + + +@mark_async_test +async def test_len_and_iter_and_bool(): + iterations = 0 + + await Coffee(name="Icelands finest").save() + + for c in await Coffee.nodes.all(): + iterations += 1 + await c.delete() + + assert iterations > 0 + + assert len(await Coffee.nodes.all()) == 0 + + +@mark_async_test +async def test_slice(): + for c in await Coffee.nodes.all(): + await c.delete() + + await Coffee(name="Icelands finest").save() + await Coffee(name="Britains finest").save() + await Coffee(name="Japans finest").save() + + # TODO : Make slice work with async + # Doing await (Coffee.nodes.all())[1:] fetches without slicing + assert len(list(Coffee.nodes.all()[1:])) == 2 + assert len(list(Coffee.nodes.all()[:1])) == 1 + assert isinstance(Coffee.nodes[1], Coffee) + assert isinstance(Coffee.nodes[0], Coffee) + assert len(list(Coffee.nodes.all()[1:2])) == 1 + + +@mark_async_test +async def test_issue_208(): + # calls to match persist across queries. + + b = await Coffee(name="basics").save() + l = await Supplier(name="lidl").save() + a = await Supplier(name="aldi").save() + + await b.suppliers.connect(l, {"courier": "fedex"}) + await b.suppliers.connect(a, {"courier": "dhl"}) + + assert len(await b.suppliers.match(courier="fedex").all()) + assert len(await b.suppliers.match(courier="dhl").all()) + + +@mark_async_test +async def test_issue_589(): + node1 = await Extension().save() + node2 = await Extension().save() + await node1.extension.connect(node2) + assert node2 in await node1.extension.all() + + +# TODO : Fix the ValueError not raised +@mark_async_test +async def test_contains(): + expensive = await Coffee(price=1000, name="Pricey").save() + asda = await Coffee(name="Asda", price=1).save() + + assert expensive in await Coffee.nodes.filter(price__gt=999).all() + assert asda not in await Coffee.nodes.filter(price__gt=999).all() + + # bad value raises + with raises(ValueError): + 2 in Coffee.nodes + + # unsaved + with raises(ValueError): + Coffee() in Coffee.nodes + + +@mark_async_test +async def test_order_by(): + for c in await Coffee.nodes.all(): + await c.delete() + + c1 = await Coffee(name="Icelands finest", price=5).save() + c2 = await Coffee(name="Britains finest", price=10).save() + c3 = await Coffee(name="Japans finest", price=35).save() + + assert Coffee.nodes.order_by("price")[0].price == 5 + assert Coffee.nodes.order_by("-price")[0].price == 35 + + ns = await Coffee.nodes.order_by("-price") + qb = AsyncQueryBuilder(ns).build_ast() + assert qb._ast.order_by + ns = ns.order_by(None) + qb = AsyncQueryBuilder(ns).build_ast() + assert not qb._ast.order_by + ns = ns.order_by("?") + qb = AsyncQueryBuilder(ns).build_ast() + assert qb._ast.with_clause == "coffee, rand() as r" + assert qb._ast.order_by == "r" + + with raises( + ValueError, + match=r".*Neo4j internals like id or element_id are not allowed for use in this operation.", + ): + Coffee.nodes.order_by("id") + + # Test order by on a relationship + l = await Supplier(name="lidl2").save() + await l.coffees.connect(c1) + await l.coffees.connect(c2) + await l.coffees.connect(c3) + + ordered_n = [n for n in await l.coffees.order_by("name").all()] + assert ordered_n[0] == c2 + assert ordered_n[1] == c1 + assert ordered_n[2] == c3 + + +@mark_async_test +async def test_extra_filters(): + for c in await Coffee.nodes.all(): + await c.delete() + + c1 = await Coffee(name="Icelands finest", price=5, id_=1).save() + c2 = await Coffee(name="Britains finest", price=10, id_=2).save() + c3 = await Coffee(name="Japans finest", price=35, id_=3).save() + c4 = await Coffee(name="US extra-fine", price=None, id_=4).save() + + coffees_5_10 = await Coffee.nodes.filter(price__in=[10, 5]).all() + assert len(coffees_5_10) == 2, "unexpected number of results" + assert c1 in coffees_5_10, "doesnt contain 5 price coffee" + assert c2 in coffees_5_10, "doesnt contain 10 price coffee" + + finest_coffees = await Coffee.nodes.filter(name__iendswith=" Finest").all() + assert len(finest_coffees) == 3, "unexpected number of results" + assert c1 in finest_coffees, "doesnt contain 1st finest coffee" + assert c2 in finest_coffees, "doesnt contain 2nd finest coffee" + assert c3 in finest_coffees, "doesnt contain 3rd finest coffee" + + unpriced_coffees = await Coffee.nodes.filter(price__isnull=True).all() + assert len(unpriced_coffees) == 1, "unexpected number of results" + assert c4 in unpriced_coffees, "doesnt contain unpriced coffee" + + coffees_with_id_gte_3 = await Coffee.nodes.filter(id___gte=3).all() + assert len(coffees_with_id_gte_3) == 2, "unexpected number of results" + assert c3 in coffees_with_id_gte_3 + assert c4 in coffees_with_id_gte_3 + + with raises( + ValueError, + match=r".*Neo4j internals like id or element_id are not allowed for use in this operation.", + ): + await Coffee.nodes.filter(elementId="4:xxx:111").all() + + +def test_traversal_definition_keys_are_valid(): + muckefuck = Coffee(name="Mukkefuck", price=1) + + with raises(ValueError): + AsyncTraversal( + muckefuck, + "a_name", + { + "node_class": Supplier, + "direction": INCOMING, + "relationship_type": "KNOWS", + "model": None, + }, + ) + + AsyncTraversal( + muckefuck, + "a_name", + { + "node_class": Supplier, + "direction": INCOMING, + "relation_type": "KNOWS", + "model": None, + }, + ) + + +@mark_async_test +async def test_empty_filters(): + """Test this case: + ``` + SomeModel.nodes.filter().filter(Q(arg1=val1)).all() + SomeModel.nodes.exclude().exclude(Q(arg1=val1)).all() + SomeModel.nodes.filter().filter(arg1=val1).all() + ``` + In django_rest_framework filter uses such as lazy function and + ``get_queryset`` function in ``GenericAPIView`` should returns + ``NodeSet`` object. + """ + + for c in await Coffee.nodes.all(): + await c.delete() + + c1 = await Coffee(name="Super", price=5, id_=1).save() + c2 = await Coffee(name="Puper", price=10, id_=2).save() + + empty_filter = Coffee.nodes.filter() + + all_coffees = await empty_filter.all() + assert len(all_coffees) == 2, "unexpected number of results" + + filter_empty_filter = empty_filter.filter(price=5) + assert len(await filter_empty_filter.all()) == 1, "unexpected number of results" + assert ( + c1 in await filter_empty_filter.all() + ), "doesnt contain c1 in ``filter_empty_filter``" + + filter_q_empty_filter = empty_filter.filter(Q(price=5)) + assert len(await filter_empty_filter.all()) == 1, "unexpected number of results" + assert ( + c1 in await filter_empty_filter.all() + ), "doesnt contain c1 in ``filter_empty_filter``" + + +@mark_async_test +async def test_q_filters(): + # Test where no children and self.connector != conn ? + for c in await Coffee.nodes.all(): + await c.delete() + + c1 = await Coffee(name="Icelands finest", price=5, id_=1).save() + c2 = await Coffee(name="Britains finest", price=10, id_=2).save() + c3 = await Coffee(name="Japans finest", price=35, id_=3).save() + c4 = await Coffee(name="US extra-fine", price=None, id_=4).save() + c5 = await Coffee(name="Latte", price=35, id_=5).save() + c6 = await Coffee(name="Cappuccino", price=35, id_=6).save() + + coffees_5_10 = await Coffee.nodes.filter(Q(price=10) | Q(price=5)).all() + assert len(coffees_5_10) == 2, "unexpected number of results" + assert c1 in coffees_5_10, "doesnt contain 5 price coffee" + assert c2 in coffees_5_10, "doesnt contain 10 price coffee" + + coffees_5_6 = ( + await Coffee.nodes.filter(Q(name="Latte") | Q(name="Cappuccino")) + .filter(price=35) + .all() + ) + assert len(coffees_5_6) == 2, "unexpected number of results" + assert c5 in coffees_5_6, "doesnt contain 5 coffee" + assert c6 in coffees_5_6, "doesnt contain 6 coffee" + + coffees_5_6 = ( + await Coffee.nodes.filter(price=35) + .filter(Q(name="Latte") | Q(name="Cappuccino")) + .all() + ) + assert len(coffees_5_6) == 2, "unexpected number of results" + assert c5 in coffees_5_6, "doesnt contain 5 coffee" + assert c6 in coffees_5_6, "doesnt contain 6 coffee" + + finest_coffees = await Coffee.nodes.filter(name__iendswith=" Finest").all() + assert len(finest_coffees) == 3, "unexpected number of results" + assert c1 in finest_coffees, "doesnt contain 1st finest coffee" + assert c2 in finest_coffees, "doesnt contain 2nd finest coffee" + assert c3 in finest_coffees, "doesnt contain 3rd finest coffee" + + unpriced_coffees = await Coffee.nodes.filter(Q(price__isnull=True)).all() + assert len(unpriced_coffees) == 1, "unexpected number of results" + assert c4 in unpriced_coffees, "doesnt contain unpriced coffee" + + coffees_with_id_gte_3 = await Coffee.nodes.filter(Q(id___gte=3)).all() + assert len(coffees_with_id_gte_3) == 4, "unexpected number of results" + assert c3 in coffees_with_id_gte_3 + assert c4 in coffees_with_id_gte_3 + assert c5 in coffees_with_id_gte_3 + assert c6 in coffees_with_id_gte_3 + + coffees_5_not_japans = await Coffee.nodes.filter( + Q(price__gt=5) & ~Q(name="Japans finest") + ).all() + assert c3 not in coffees_5_not_japans + + empty_Q_condition = await Coffee.nodes.filter(Q(price=5) | Q()).all() + assert ( + len(empty_Q_condition) == 1 + ), "undefined Q leading to unexpected number of results" + assert c1 in empty_Q_condition + + combined_coffees = await Coffee.nodes.filter( + Q(price=35), Q(name="Latte") | Q(name="Cappuccino") + ).all() + assert len(combined_coffees) == 2 + assert c5 in combined_coffees + assert c6 in combined_coffees + assert c3 not in combined_coffees + + class QQ: + pass + + with raises(TypeError): + wrong_Q = await Coffee.nodes.filter(Q(price=5) | QQ()).all() + + +def test_qbase(): + test_print_out = str(Q(price=5) | Q(price=10)) + test_repr = repr(Q(price=5) | Q(price=10)) + assert test_print_out == "(OR: ('price', 5), ('price', 10))" + assert test_repr == "" + + assert ("price", 5) in (Q(price=5) | Q(price=10)) + + test_hash = set([Q(price_lt=30) | ~Q(price=5), Q(price_lt=30) | ~Q(price=5)]) + assert len(test_hash) == 1 + + +@mark_async_test +async def test_traversal_filter_left_hand_statement(): + nescafe = await Coffee(name="Nescafe2", price=99).save() + nescafe_gold = await Coffee(name="Nescafe gold", price=11).save() + + tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() + biedronka = await Supplier(name="Biedronka", delivery_cost=5).save() + lidl = await Supplier(name="Lidl", delivery_cost=3).save() + + await nescafe.suppliers.connect(tesco) + await nescafe_gold.suppliers.connect(biedronka) + await nescafe_gold.suppliers.connect(lidl) + + lidl_supplier = ( + await AsyncNodeSet(Coffee.nodes.filter(price=11).suppliers) + .filter(delivery_cost=3) + .all() + ) + + assert lidl in lidl_supplier + + +@mark_async_test +async def test_fetch_relations(): + arabica = await Species(name="Arabica").save() + robusta = await Species(name="Robusta").save() + nescafe = await Coffee(name="Nescafe 1000", price=99).save() + nescafe_gold = await Coffee(name="Nescafe 1001", price=11).save() + + tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() + await nescafe.suppliers.connect(tesco) + await nescafe_gold.suppliers.connect(tesco) + await nescafe.species.connect(arabica) + + result = ( + await Supplier.nodes.filter(name="Sainsburys") + .fetch_relations("coffees__species") + .all() + ) + assert arabica in result[0] + assert robusta not in result[0] + assert tesco in result[0] + assert nescafe in result[0] + assert nescafe_gold not in result[0] + + result = ( + await Species.nodes.filter(name="Robusta") + .fetch_relations(Optional("coffees__suppliers")) + .all() + ) + assert result[0][0] is None + + # len() should only consider Suppliers + count = len( + await Supplier.nodes.filter(name="Sainsburys") + .fetch_relations("coffees__species") + .all() + ) + assert count == 1 + + assert ( + tesco + in await Supplier.nodes.fetch_relations("coffees__species") + .filter(name="Sainsburys") + .all() + ) diff --git a/test/test_migration_neo4j_5.py b/test/async_/test_migration_neo4j_5.py similarity index 84% rename from test/test_migration_neo4j_5.py rename to test/async_/test_migration_neo4j_5.py index c61dd312..a836b969 100644 --- a/test/test_migration_neo4j_5.py +++ b/test/async_/test_migration_neo4j_5.py @@ -1,4 +1,5 @@ import pytest +from test._async_compat import mark_async_test from neomodel import ( AsyncRelationshipTo, @@ -23,13 +24,14 @@ class Band(AsyncStructuredNode): released = AsyncRelationshipTo(Album, relation_type="RELEASED", model=Released) -def test_read_elements_id(): - the_hives = Band(name="The Hives").save() - lex_hives = Album(name="Lex Hives").save() - released_rel = the_hives.released.connect(lex_hives) +@mark_async_test +async def test_read_elements_id(): + the_hives = await Band(name="The Hives").save() + lex_hives = await Album(name="Lex Hives").save() + released_rel = await the_hives.released.connect(lex_hives) # Validate element_id properties - assert lex_hives.element_id == the_hives.released.single().element_id + assert lex_hives.element_id == (await the_hives.released.single()).element_id assert released_rel._start_node_element_id == the_hives.element_id assert released_rel._end_node_element_id == lex_hives.element_id @@ -38,7 +40,7 @@ def test_read_elements_id(): if adb.database_version.startswith("4"): # Nodes' ids assert lex_hives.id == int(lex_hives.element_id) - assert lex_hives.id == the_hives.released.single().id + assert lex_hives.id == (await the_hives.released.single()).id # Relationships' ids assert isinstance(released_rel.element_id, int) assert released_rel.element_id == released_rel.id diff --git a/test/async_/test_models.py b/test/async_/test_models.py new file mode 100644 index 00000000..4a83a311 --- /dev/null +++ b/test/async_/test_models.py @@ -0,0 +1,360 @@ +from __future__ import print_function + +from datetime import datetime + +from test._async_compat import mark_async_test +from pytest import raises + +from neomodel import ( + AsyncStructuredNode, + AsyncStructuredRel, + DateProperty, + IntegerProperty, + StringProperty, +) +from neomodel.async_.core import adb +from neomodel.exceptions import RequiredProperty, UniqueProperty + + +class User(AsyncStructuredNode): + email = StringProperty(unique_index=True, required=True) + age = IntegerProperty(index=True) + + @property + def email_alias(self): + return self.email + + @email_alias.setter # noqa + def email_alias(self, value): + self.email = value + + +class NodeWithoutProperty(AsyncStructuredNode): + pass + + +@mark_async_test +async def test_issue_233(): + class BaseIssue233(AsyncStructuredNode): + __abstract_node__ = True + + def __getitem__(self, item): + return self.__dict__[item] + + class Issue233(BaseIssue233): + uid = StringProperty(unique_index=True, required=True) + + i = await Issue233(uid="testgetitem").save() + assert i["uid"] == "testgetitem" + + +def test_issue_72(): + user = User(email="foo@bar.com") + assert user.age is None + + +@mark_async_test +async def test_required(): + with raises(RequiredProperty): + await User(age=3).save() + + +def test_repr_and_str(): + u = User(email="robin@test.com", age=3) + print(repr(u)) + print(str(u)) + assert True + + +@mark_async_test +async def test_get_and_get_or_none(): + u = User(email="robin@test.com", age=3) + assert await u.save() + rob = await User.nodes.get(email="robin@test.com") + assert rob.email == "robin@test.com" + assert rob.age == 3 + + rob = await User.nodes.get_or_none(email="robin@test.com") + assert rob.email == "robin@test.com" + + n = await User.nodes.get_or_none(email="robin@nothere.com") + assert n is None + + +@mark_async_test +async def test_first_and_first_or_none(): + u = User(email="matt@test.com", age=24) + assert await u.save() + u2 = User(email="tbrady@test.com", age=40) + assert await u2.save() + tbrady = await User.nodes.order_by("-age").first() + assert tbrady.email == "tbrady@test.com" + assert tbrady.age == 40 + + tbrady = await User.nodes.order_by("-age").first_or_none() + assert tbrady.email == "tbrady@test.com" + + n = await User.nodes.first_or_none(email="matt@nothere.com") + assert n is None + + +def test_bare_init_without_save(): + """ + If a node model is initialised without being saved, accessing its `element_id` should + return None. + """ + assert User().element_id is None + + +@mark_async_test +async def test_save_to_model(): + u = User(email="jim@test.com", age=3) + assert await u.save() + assert u.element_id is not None + assert u.email == "jim@test.com" + assert u.age == 3 + + +@mark_async_test +async def test_save_node_without_properties(): + n = NodeWithoutProperty() + assert await n.save() + assert n.element_id is not None + + +@mark_async_test +async def test_unique(): + await adb.install_labels(User) + await User(email="jim1@test.com", age=3).save() + with raises(UniqueProperty): + await User(email="jim1@test.com", age=3).save() + + +@mark_async_test +async def test_update_unique(): + u = await User(email="jimxx@test.com", age=3).save() + await u.save() # this shouldn't fail + + +@mark_async_test +async def test_update(): + user = await User(email="jim2@test.com", age=3).save() + assert user + user.email = "jim2000@test.com" + await user.save() + jim = await User.nodes.get(email="jim2000@test.com") + assert jim + assert jim.email == "jim2000@test.com" + + +@mark_async_test +async def test_save_through_magic_property(): + user = await User(email_alias="blah@test.com", age=8).save() + assert user.email_alias == "blah@test.com" + user = await User.nodes.get(email="blah@test.com") + assert user.email == "blah@test.com" + assert user.email_alias == "blah@test.com" + + user1 = await User(email="blah1@test.com", age=8).save() + assert user1.email_alias == "blah1@test.com" + user1.email_alias = "blah2@test.com" + assert await user1.save() + user2 = await User.nodes.get(email="blah2@test.com") + assert user2 + + +class Customer2(AsyncStructuredNode): + __label__ = "customers" + email = StringProperty(unique_index=True, required=True) + age = IntegerProperty(index=True) + + +@mark_async_test +async def test_not_updated_on_unique_error(): + await adb.install_labels(Customer2) + await Customer2(email="jim@bob.com", age=7).save() + test = await Customer2(email="jim1@bob.com", age=2).save() + test.email = "jim@bob.com" + with raises(UniqueProperty): + await test.save() + customers = await Customer2.nodes.all() + assert customers[0].email != customers[1].email + assert (await Customer2.nodes.get(email="jim@bob.com")).age == 7 + assert (await Customer2.nodes.get(email="jim1@bob.com")).age == 2 + + +@mark_async_test +async def test_label_not_inherited(): + class Customer3(Customer2): + address = StringProperty() + + assert Customer3.__label__ == "Customer3" + c = await Customer3(email="test@test.com").save() + assert "customers" in await c.labels() + assert "Customer3" in await c.labels() + + c = await Customer2.nodes.get(email="test@test.com") + assert isinstance(c, Customer2) + assert "customers" in await c.labels() + assert "Customer3" in await c.labels() + + +@mark_async_test +async def test_refresh(): + c = await Customer2(email="my@email.com", age=16).save() + c.my_custom_prop = "value" + copy = await Customer2.nodes.get(email="my@email.com") + copy.age = 20 + await copy.save() + + assert c.age == 16 + + await c.refresh() + assert c.age == 20 + assert c.my_custom_prop == "value" + + c = Customer2.inflate(c.element_id) + c.age = 30 + await c.refresh() + + assert c.age == 20 + + if adb.database_version.startswith("4"): + c = Customer2.inflate(999) + else: + c = Customer2.inflate("4:xxxxxx:999") + with raises(Customer2.DoesNotExist): + await c.refresh() + + +@mark_async_test +async def test_setting_value_to_none(): + c = await Customer2(email="alice@bob.com", age=42).save() + assert c.age is not None + + c.age = None + await c.save() + + copy = await Customer2.nodes.get(email="alice@bob.com") + assert copy.age is None + + +@mark_async_test +async def test_inheritance(): + class User(AsyncStructuredNode): + __abstract_node__ = True + name = StringProperty(unique_index=True) + + class Shopper(User): + balance = IntegerProperty(index=True) + + async def credit_account(self, amount): + self.balance = self.balance + int(amount) + await self.save() + + jim = await Shopper(name="jimmy", balance=300).save() + await jim.credit_account(50) + + assert Shopper.__label__ == "Shopper" + assert jim.balance == 350 + assert len(jim.inherited_labels()) == 1 + assert len(await jim.labels()) == 1 + assert (await jim.labels())[0] == "Shopper" + + +@mark_async_test +async def test_inherited_optional_labels(): + class BaseOptional(AsyncStructuredNode): + __optional_labels__ = ["Alive"] + name = StringProperty(unique_index=True) + + class ExtendedOptional(BaseOptional): + __optional_labels__ = ["RewardsMember"] + balance = IntegerProperty(index=True) + + async def credit_account(self, amount): + self.balance = self.balance + int(amount) + await self.save() + + henry = await ExtendedOptional(name="henry", balance=300).save() + await henry.credit_account(50) + + assert ExtendedOptional.__label__ == "ExtendedOptional" + assert henry.balance == 350 + assert len(henry.inherited_labels()) == 2 + assert len(await henry.labels()) == 2 + + assert set(henry.inherited_optional_labels()) == {"Alive", "RewardsMember"} + + +@mark_async_test +async def test_mixins(): + class UserMixin: + name = StringProperty(unique_index=True) + password = StringProperty() + + class CreditMixin: + balance = IntegerProperty(index=True) + + async def credit_account(self, amount): + self.balance = self.balance + int(amount) + await self.save() + + class Shopper2(AsyncStructuredNode, UserMixin, CreditMixin): + pass + + jim = await Shopper2(name="jimmy", balance=300).save() + await jim.credit_account(50) + + assert Shopper2.__label__ == "Shopper2" + assert jim.balance == 350 + assert len(jim.inherited_labels()) == 1 + assert len(await jim.labels()) == 1 + assert (await jim.labels())[0] == "Shopper2" + + +@mark_async_test +async def test_date_property(): + class DateTest(AsyncStructuredNode): + birthdate = DateProperty() + + user = await DateTest(birthdate=datetime.now()).save() + + +def test_reserved_property_keys(): + error_match = r".*is not allowed as it conflicts with neomodel internals.*" + with raises(ValueError, match=error_match): + + class ReservedPropertiesDeletedNode(AsyncStructuredNode): + deleted = StringProperty() + + with raises(ValueError, match=error_match): + + class ReservedPropertiesIdNode(AsyncStructuredNode): + id = StringProperty() + + with raises(ValueError, match=error_match): + + class ReservedPropertiesElementIdNode(AsyncStructuredNode): + element_id = StringProperty() + + with raises(ValueError, match=error_match): + + class ReservedPropertiesIdRel(AsyncStructuredRel): + id = StringProperty() + + with raises(ValueError, match=error_match): + + class ReservedPropertiesElementIdRel(AsyncStructuredRel): + element_id = StringProperty() + + error_match = r"Property names 'source' and 'target' are not allowed as they conflict with neomodel internals." + with raises(ValueError, match=error_match): + + class ReservedPropertiesSourceRel(AsyncStructuredRel): + source = StringProperty() + + with raises(ValueError, match=error_match): + + class ReservedPropertiesTargetRel(AsyncStructuredRel): + target = StringProperty() diff --git a/test/async_/test_multiprocessing.py b/test/async_/test_multiprocessing.py new file mode 100644 index 00000000..101126e7 --- /dev/null +++ b/test/async_/test_multiprocessing.py @@ -0,0 +1,25 @@ +from multiprocessing.pool import ThreadPool as Pool +from test._async_compat import mark_async_test + +from neomodel import AsyncStructuredNode, StringProperty +from neomodel.async_.core import adb + + +class ThingyMaBob(AsyncStructuredNode): + name = StringProperty(unique_index=True, required=True) + + +async def thing_create(name): + name = str(name) + (thing,) = await ThingyMaBob.get_or_create({"name": name}) + return thing.name, name + + +@mark_async_test +async def test_concurrency(): + with Pool(5) as p: + results = p.map(thing_create, range(50)) + for to_unpack in results: + returned, sent = await to_unpack + assert returned == sent + await adb.close_connection() diff --git a/test/async_/test_paths.py b/test/async_/test_paths.py new file mode 100644 index 00000000..ba664203 --- /dev/null +++ b/test/async_/test_paths.py @@ -0,0 +1,95 @@ +from test._async_compat import mark_async_test +from neomodel import ( + AsyncNeomodelPath, + AsyncRelationshipTo, + AsyncStructuredNode, + AsyncStructuredRel, + IntegerProperty, + StringProperty, + UniqueIdProperty, +) +from neomodel.async_.core import adb + + +class PersonLivesInCity(AsyncStructuredRel): + """ + Relationship with data that will be instantiated as "stand-alone" + """ + + some_num = IntegerProperty(index=True, default=12) + + +class CountryOfOrigin(AsyncStructuredNode): + code = StringProperty(unique_index=True, required=True) + + +class CityOfResidence(AsyncStructuredNode): + name = StringProperty(required=True) + country = AsyncRelationshipTo(CountryOfOrigin, "FROM_COUNTRY") + + +class PersonOfInterest(AsyncStructuredNode): + uid = UniqueIdProperty() + name = StringProperty(unique_index=True) + age = IntegerProperty(index=True, default=0) + + country = AsyncRelationshipTo(CountryOfOrigin, "IS_FROM") + city = AsyncRelationshipTo(CityOfResidence, "LIVES_IN", model=PersonLivesInCity) + + +@mark_async_test +async def test_path_instantiation(): + """ + Neo4j driver paths should be instantiated as neomodel paths, with all of + their nodes and relationships resolved to their Python objects wherever + such a mapping is available. + """ + + c1 = await CountryOfOrigin(code="GR").save() + c2 = await CountryOfOrigin(code="FR").save() + + ct1 = await CityOfResidence(name="Athens", country=c1).save() + ct2 = await CityOfResidence(name="Paris", country=c2).save() + + p1 = await PersonOfInterest(name="Bill", age=22).save() + await p1.country.connect(c1) + await p1.city.connect(ct1) + + p2 = await PersonOfInterest(name="Jean", age=28).save() + await p2.country.connect(c2) + await p2.city.connect(ct2) + + p3 = await PersonOfInterest(name="Bo", age=32).save() + await p3.country.connect(c1) + await p3.city.connect(ct2) + + p4 = await PersonOfInterest(name="Drop", age=16).save() + await p4.country.connect(c1) + await p4.city.connect(ct2) + + # Retrieve a single path + q = await adb.cypher_query( + "MATCH p=(:CityOfResidence)<-[:LIVES_IN]-(:PersonOfInterest)-[:IS_FROM]->(:CountryOfOrigin) RETURN p LIMIT 1", + resolve_objects=True, + ) + + path_object = q[0][0][0] + path_nodes = path_object.nodes + path_rels = path_object.relationships + + assert type(path_object) is AsyncNeomodelPath + assert type(path_nodes[0]) is CityOfResidence + assert type(path_nodes[1]) is PersonOfInterest + assert type(path_nodes[2]) is CountryOfOrigin + + assert type(path_rels[0]) is PersonLivesInCity + assert type(path_rels[1]) is AsyncStructuredRel + + await c1.delete() + await c2.delete() + await ct1.delete() + await ct2.delete() + await p1.delete() + await p2.delete() + await p3.delete() + await p4.delete() diff --git a/test/async_/test_properties.py b/test/async_/test_properties.py new file mode 100644 index 00000000..65239fac --- /dev/null +++ b/test/async_/test_properties.py @@ -0,0 +1,455 @@ +from datetime import date, datetime + +from test._async_compat import mark_async_test +from pytest import mark, raises +from pytz import timezone + +from neomodel import AsyncStructuredNode +from neomodel.async_.core import adb +from neomodel.exceptions import ( + DeflateError, + InflateError, + RequiredProperty, + UniqueProperty, +) +from neomodel.properties import ( + ArrayProperty, + DateProperty, + DateTimeFormatProperty, + DateTimeProperty, + EmailProperty, + IntegerProperty, + JSONProperty, + NormalizedProperty, + RegexProperty, + StringProperty, + UniqueIdProperty, +) +from neomodel.util import _get_node_properties + + +class FooBar: + pass + + +def test_string_property_exceeds_max_length(): + """ + StringProperty is defined by two properties: `max_length` and `choices` that are mutually exclusive. Furthermore, + max_length must be a positive non-zero number. + """ + # Try to define a property that has both choices and max_length + with raises(ValueError): + some_string_property = StringProperty( + choices={"One": "1", "Two": "2"}, max_length=22 + ) + + # Try to define a string property that has a negative zero length + with raises(ValueError): + another_string_property = StringProperty(max_length=-35) + + # Try to validate a long string + a_string_property = StringProperty(required=True, max_length=5) + with raises(ValueError): + a_string_property.normalize("The quick brown fox jumps over the lazy dog") + + # Try to validate a "valid" string, as per the max_length setting. + valid_string = "Owen" + normalised_string = a_string_property.normalize(valid_string) + assert ( + valid_string == normalised_string + ), "StringProperty max_length test passed but values do not match." + + +@mark_async_test +async def test_string_property_w_choice(): + class TestChoices(AsyncStructuredNode): + SEXES = {"F": "Female", "M": "Male", "O": "Other"} + sex = StringProperty(required=True, choices=SEXES) + + try: + await TestChoices(sex="Z").save() + except DeflateError as e: + assert "choice" in str(e) + else: + assert False, "DeflateError not raised." + + node = await TestChoices(sex="M").save() + assert node.get_sex_display() == "Male" + + +def test_deflate_inflate(): + prop = IntegerProperty(required=True) + prop.name = "age" + prop.owner = FooBar + + try: + prop.inflate("six") + except InflateError as e: + assert True + assert "inflate property" in str(e) + else: + assert False, "DeflateError not raised." + + try: + prop.deflate("six") + except DeflateError as e: + assert "deflate property" in str(e) + else: + assert False, "DeflateError not raised." + + +def test_datetimes_timezones(): + prop = DateTimeProperty() + prop.name = "foo" + prop.owner = FooBar + t = datetime.utcnow() + gr = timezone("Europe/Athens") + gb = timezone("Europe/London") + dt1 = gr.localize(t) + dt2 = gb.localize(t) + time1 = prop.inflate(prop.deflate(dt1)) + time2 = prop.inflate(prop.deflate(dt2)) + assert time1.utctimetuple() == dt1.utctimetuple() + assert time1.utctimetuple() < time2.utctimetuple() + assert time1.tzname() == "UTC" + + +def test_date(): + prop = DateProperty() + prop.name = "foo" + prop.owner = FooBar + somedate = date(2012, 12, 15) + assert prop.deflate(somedate) == "2012-12-15" + assert prop.inflate("2012-12-15") == somedate + + +def test_datetime_format(): + some_format = "%Y-%m-%d %H:%M:%S" + prop = DateTimeFormatProperty(format=some_format) + prop.name = "foo" + prop.owner = FooBar + some_datetime = datetime(2019, 3, 19, 15, 36, 25) + assert prop.deflate(some_datetime) == "2019-03-19 15:36:25" + assert prop.inflate("2019-03-19 15:36:25") == some_datetime + + +def test_datetime_exceptions(): + prop = DateTimeProperty() + prop.name = "created" + prop.owner = FooBar + faulty = "dgdsg" + + try: + prop.inflate(faulty) + except InflateError as e: + assert "inflate property" in str(e) + else: + assert False, "InflateError not raised." + + try: + prop.deflate(faulty) + except DeflateError as e: + assert "deflate property" in str(e) + else: + assert False, "DeflateError not raised." + + +def test_date_exceptions(): + prop = DateProperty() + prop.name = "date" + prop.owner = FooBar + faulty = "2012-14-13" + + try: + prop.inflate(faulty) + except InflateError as e: + assert "inflate property" in str(e) + else: + assert False, "InflateError not raised." + + try: + prop.deflate(faulty) + except DeflateError as e: + assert "deflate property" in str(e) + else: + assert False, "DeflateError not raised." + + +def test_json(): + prop = JSONProperty() + prop.name = "json" + prop.owner = FooBar + + value = {"test": [1, 2, 3]} + + assert prop.deflate(value) == '{"test": [1, 2, 3]}' + assert prop.inflate('{"test": [1, 2, 3]}') == value + + +@mark_async_test +async def test_default_value(): + class DefaultTestValue(AsyncStructuredNode): + name_xx = StringProperty(default="jim", index=True) + + a = DefaultTestValue() + assert a.name_xx == "jim" + await a.save() + + +@mark_async_test +async def test_default_value_callable(): + def uid_generator(): + return "xx" + + class DefaultTestValueTwo(AsyncStructuredNode): + uid = StringProperty(default=uid_generator, index=True) + + a = await DefaultTestValueTwo().save() + assert a.uid == "xx" + + +@mark_async_test +async def test_default_value_callable_type(): + # check our object gets converted to str without serializing and reload + def factory(): + class Foo: + def __str__(self): + return "123" + + return Foo() + + class DefaultTestValueThree(AsyncStructuredNode): + uid = StringProperty(default=factory, index=True) + + x = DefaultTestValueThree() + assert x.uid == "123" + await x.save() + assert x.uid == "123" + await x.refresh() + assert x.uid == "123" + + +@mark_async_test +async def test_independent_property_name(): + class TestDBNamePropertyNode(AsyncStructuredNode): + name_ = StringProperty(db_property="name") + + x = TestDBNamePropertyNode() + x.name_ = "jim" + await x.save() + + # check database property name on low level + results, meta = await adb.cypher_query("MATCH (n:TestDBNamePropertyNode) RETURN n") + node_properties = _get_node_properties(results[0][0]) + assert node_properties["name"] == "jim" + + node_properties = _get_node_properties(results[0][0]) + assert not "name_" in node_properties + assert not hasattr(x, "name") + assert hasattr(x, "name_") + assert (await TestDBNamePropertyNode.nodes.filter(name_="jim").all())[ + 0 + ].name_ == x.name_ + assert (await TestDBNamePropertyNode.nodes.get(name_="jim")).name_ == x.name_ + + await x.delete() + + +@mark_async_test +async def test_independent_property_name_get_or_create(): + class TestNode(AsyncStructuredNode): + uid = UniqueIdProperty() + name_ = StringProperty(db_property="name", required=True) + + # create the node + await TestNode.get_or_create({"uid": 123, "name_": "jim"}) + # test that the node is retrieved correctly + x = (await TestNode.get_or_create({"uid": 123, "name_": "jim"}))[0] + + # check database property name on low level + results, _ = await adb.cypher_query("MATCH (n:TestNode) RETURN n") + node_properties = _get_node_properties(results[0][0]) + assert node_properties["name"] == "jim" + assert "name_" not in node_properties + + # delete node afterwards + await x.delete() + + +@mark.parametrize("normalized_class", (NormalizedProperty,)) +def test_normalized_property(normalized_class): + class TestProperty(normalized_class): + def normalize(self, value): + self._called_with = value + self._called = True + return value + "bar" + + inflate = TestProperty() + inflate_res = inflate.inflate("foo") + assert getattr(inflate, "_called", False) + assert getattr(inflate, "_called_with", None) == "foo" + assert inflate_res == "foobar" + + deflate = TestProperty() + deflate_res = deflate.deflate("bar") + assert getattr(deflate, "_called", False) + assert getattr(deflate, "_called_with", None) == "bar" + assert deflate_res == "barbar" + + default = TestProperty(default="qux") + default_res = default.default_value() + assert getattr(default, "_called", False) + assert getattr(default, "_called_with", None) == "qux" + assert default_res == "quxbar" + + +def test_regex_property(): + class MissingExpression(RegexProperty): + pass + + with raises(ValueError): + MissingExpression() + + class TestProperty(RegexProperty): + name = "test" + owner = object() + expression = r"\w+ \w+$" + + def normalize(self, value): + self._called = True + return super().normalize(value) + + prop = TestProperty() + result = prop.inflate("foo bar") + assert getattr(prop, "_called", False) + assert result == "foo bar" + + with raises(DeflateError): + prop.deflate("qux") + + +def test_email_property(): + prop = EmailProperty() + prop.name = "email" + prop.owner = object() + result = prop.inflate("foo@example.com") + assert result == "foo@example.com" + + with raises(DeflateError): + prop.deflate("foo@example") + + +@mark_async_test +async def test_uid_property(): + prop = UniqueIdProperty() + prop.name = "uid" + prop.owner = object() + myuid = prop.default_value() + assert len(myuid) + + class CheckMyId(AsyncStructuredNode): + uid = UniqueIdProperty() + + cmid = await CheckMyId().save() + assert len(cmid.uid) + + +class ArrayProps(AsyncStructuredNode): + uid = StringProperty(unique_index=True) + untyped_arr = ArrayProperty() + typed_arr = ArrayProperty(IntegerProperty()) + + +@mark_async_test +async def test_array_properties(): + # untyped + ap1 = await ArrayProps(uid="1", untyped_arr=["Tim", "Bob"]).save() + assert "Tim" in ap1.untyped_arr + ap1 = await ArrayProps.nodes.get(uid="1") + assert "Tim" in ap1.untyped_arr + + # typed + try: + await ArrayProps(uid="2", typed_arr=["a", "b"]).save() + except DeflateError as e: + assert "unsaved node" in str(e) + else: + assert False, "DeflateError not raised." + + ap2 = await ArrayProps(uid="2", typed_arr=[1, 2]).save() + assert 1 in ap2.typed_arr + ap2 = await ArrayProps.nodes.get(uid="2") + assert 2 in ap2.typed_arr + + +def test_illegal_array_base_prop_raises(): + with raises(ValueError): + ArrayProperty(StringProperty(index=True)) + + +@mark_async_test +async def test_indexed_array(): + class IndexArray(AsyncStructuredNode): + ai = ArrayProperty(unique_index=True) + + b = await IndexArray(ai=[1, 2]).save() + c = await IndexArray.nodes.get(ai=[1, 2]) + assert b.element_id == c.element_id + + +@mark_async_test +async def test_unique_index_prop_not_required(): + class ConstrainedTestNode(AsyncStructuredNode): + required_property = StringProperty(required=True) + unique_property = StringProperty(unique_index=True) + unique_required_property = StringProperty(unique_index=True, required=True) + unconstrained_property = StringProperty() + + # Create a node with a missing required property + with raises(RequiredProperty): + x = ConstrainedTestNode(required_property="required", unique_property="unique") + await x.save() + + # Create a node with a missing unique (but not required) property. + x = ConstrainedTestNode() + x.required_property = "required" + x.unique_required_property = "unique and required" + x.unconstrained_property = "no contraints" + await x.save() + + # check database property name on low level + results, meta = await adb.cypher_query("MATCH (n:ConstrainedTestNode) RETURN n") + node_properties = _get_node_properties(results[0][0]) + assert node_properties["unique_required_property"] == "unique and required" + + # delete node afterwards + await x.delete() + + +@mark_async_test +async def test_unique_index_prop_enforced(): + class UniqueNullableNameNode(AsyncStructuredNode): + name = StringProperty(unique_index=True) + + await adb.install_labels(UniqueNullableNameNode) + # Nameless + x = UniqueNullableNameNode() + await x.save() + y = UniqueNullableNameNode() + await y.save() + + # Named + z = UniqueNullableNameNode(name="named") + await z.save() + with raises(UniqueProperty): + a = UniqueNullableNameNode(name="named") + await a.save() + + # Check nodes are in database + results, _ = await adb.cypher_query("MATCH (n:UniqueNullableNameNode) RETURN n") + assert len(results) == 3 + + # Delete nodes afterwards + await x.delete() + await y.delete() + await z.delete() diff --git a/test/async_/test_relationship_models.py b/test/async_/test_relationship_models.py new file mode 100644 index 00000000..b360a881 --- /dev/null +++ b/test/async_/test_relationship_models.py @@ -0,0 +1,166 @@ +from datetime import datetime +from test._async_compat import mark_async_test + +import pytz +from pytest import raises + +from neomodel import ( + AsyncRelationship, + AsyncRelationshipTo, + AsyncStructuredNode, + AsyncStructuredRel, + DateTimeProperty, + DeflateError, + StringProperty, +) + +HOOKS_CALLED = {"pre_save": 0, "post_save": 0} + + +class FriendRel(AsyncStructuredRel): + since = DateTimeProperty(default=lambda: datetime.now(pytz.utc)) + + +class HatesRel(FriendRel): + reason = StringProperty() + + def pre_save(self): + HOOKS_CALLED["pre_save"] += 1 + + def post_save(self): + HOOKS_CALLED["post_save"] += 1 + + +class Badger(AsyncStructuredNode): + name = StringProperty(unique_index=True) + friend = AsyncRelationship("Badger", "FRIEND", model=FriendRel) + hates = AsyncRelationshipTo("Stoat", "HATES", model=HatesRel) + + +class Stoat(AsyncStructuredNode): + name = StringProperty(unique_index=True) + hates = AsyncRelationshipTo("Badger", "HATES", model=HatesRel) + + +@mark_async_test +async def test_either_connect_with_rel_model(): + paul = await Badger(name="Paul").save() + tom = await Badger(name="Tom").save() + + # creating rels + new_rel = await tom.friend.disconnect(paul) + new_rel = await tom.friend.connect(paul) + assert isinstance(new_rel, FriendRel) + assert isinstance(new_rel.since, datetime) + + # updating properties + new_rel.since = datetime.now(pytz.utc) + assert isinstance(await new_rel.save(), FriendRel) + + # start and end nodes are the opposite of what you'd expect when using either.. + # I've tried everything possible to correct this to no avail + paul = await new_rel.start_node() + tom = await new_rel.end_node() + assert paul.name == "Tom" + assert tom.name == "Paul" + + +@mark_async_test +async def test_direction_connect_with_rel_model(): + paul = await Badger(name="Paul the badger").save() + ian = await Stoat(name="Ian the stoat").save() + + rel = await ian.hates.connect( + paul, {"reason": "thinks paul should bath more often"} + ) + assert isinstance(rel.since, datetime) + assert isinstance(rel, FriendRel) + assert rel.reason.startswith("thinks") + rel.reason = "he smells" + await rel.save() + + ian = await rel.start_node() + assert isinstance(ian, Stoat) + paul = await rel.end_node() + assert isinstance(paul, Badger) + + assert ian.name.startswith("Ian") + assert paul.name.startswith("Paul") + + rel = await ian.hates.relationship(paul) + assert isinstance(rel, HatesRel) + assert isinstance(rel.since, datetime) + await rel.save() + + # test deflate checking + rel.since = "2:30pm" + with raises(DeflateError): + await rel.save() + + # check deflate check via connect + with raises(DeflateError): + await paul.hates.connect( + ian, + { + "reason": "thinks paul should bath more often", + "since": "2:30pm", + }, + ) + + +@mark_async_test +async def test_traversal_where_clause(): + phill = await Badger(name="Phill the badger").save() + tim = await Badger(name="Tim the badger").save() + bob = await Badger(name="Bob the badger").save() + rel = await tim.friend.connect(bob) + now = datetime.now(pytz.utc) + assert rel.since < now + rel2 = await tim.friend.connect(phill) + assert rel2.since > now + friends = tim.friend.match(since__gt=now) + assert len(await friends.all()) == 1 + + +@mark_async_test +async def test_multiple_rels_exist_issue_223(): + # check a badger can dislike a stoat for multiple reasons + phill = await Badger(name="Phill").save() + ian = await Stoat(name="Stoat").save() + + rel_a = await phill.hates.connect(ian, {"reason": "a"}) + rel_b = await phill.hates.connect(ian, {"reason": "b"}) + assert rel_a.element_id != rel_b.element_id + + ian_a = await phill.hates.match(reason="a")[0] + ian_b = await phill.hates.match(reason="b")[0] + assert ian_a.element_id == ian_b.element_id + + +@mark_async_test +async def test_retrieve_all_rels(): + tom = await Badger(name="tom").save() + ian = await Stoat(name="ian").save() + + rel_a = await tom.hates.connect(ian, {"reason": "a"}) + rel_b = await tom.hates.connect(ian, {"reason": "b"}) + + rels = await tom.hates.all_relationships(ian) + assert len(rels) == 2 + assert rels[0].element_id in [rel_a.element_id, rel_b.element_id] + assert rels[1].element_id in [rel_a.element_id, rel_b.element_id] + + +@mark_async_test +async def test_save_hook_on_rel_model(): + HOOKS_CALLED["pre_save"] = 0 + HOOKS_CALLED["post_save"] = 0 + + paul = await Badger(name="PaulB").save() + ian = await Stoat(name="IanS").save() + + rel = await ian.hates.connect(paul, {"reason": "yadda yadda"}) + await rel.save() + + assert HOOKS_CALLED["pre_save"] == 2 + assert HOOKS_CALLED["post_save"] == 2 diff --git a/test/async_/test_relationships.py b/test/async_/test_relationships.py new file mode 100644 index 00000000..40b8649c --- /dev/null +++ b/test/async_/test_relationships.py @@ -0,0 +1,209 @@ +from pytest import raises +from test._async_compat import mark_async_test + +from neomodel import ( + AsyncOne, + AsyncRelationship, + AsyncRelationshipFrom, + AsyncRelationshipTo, + AsyncStructuredNode, + AsyncStructuredRel, + IntegerProperty, + Q, + StringProperty, +) +from neomodel.async_.core import adb + + +class PersonWithRels(AsyncStructuredNode): + name = StringProperty(unique_index=True) + age = IntegerProperty(index=True) + is_from = AsyncRelationshipTo("Country", "IS_FROM") + knows = AsyncRelationship("PersonWithRels", "KNOWS") + + @property + def special_name(self): + return self.name + + def special_power(self): + return "I have no powers" + + +class Country(AsyncStructuredNode): + code = StringProperty(unique_index=True) + inhabitant = AsyncRelationshipFrom(PersonWithRels, "IS_FROM") + president = AsyncRelationshipTo(PersonWithRels, "PRESIDENT", cardinality=AsyncOne) + + +class SuperHero(PersonWithRels): + power = StringProperty(index=True) + + def special_power(self): + return "I have powers" + + +@mark_async_test +async def test_actions_on_deleted_node(): + u = await PersonWithRels(name="Jim2", age=3).save() + await u.delete() + with raises(ValueError): + await u.is_from.connect(None) + + with raises(ValueError): + await u.is_from.get() + + with raises(ValueError): + await u.save() + + +@mark_async_test +async def test_bidirectional_relationships(): + u = await PersonWithRels(name="Jim", age=3).save() + assert u + + de = await Country(code="DE").save() + assert de + + assert not await u.is_from.all() + + assert u.is_from.__class__.__name__ == "AsyncZeroOrMore" + await u.is_from.connect(de) + + assert len(await u.is_from.all()) == 1 + + assert await u.is_from.is_connected(de) + + b = (await u.is_from.all())[0] + assert b.__class__.__name__ == "Country" + assert b.code == "DE" + + s = (await b.inhabitant.all())[0] + assert s.name == "Jim" + + await u.is_from.disconnect(b) + assert not await u.is_from.is_connected(b) + + +@mark_async_test +async def test_either_direction_connect(): + rey = await PersonWithRels(name="Rey", age=3).save() + sakis = await PersonWithRels(name="Sakis", age=3).save() + + await rey.knows.connect(sakis) + assert await rey.knows.is_connected(sakis) + assert await sakis.knows.is_connected(rey) + await sakis.knows.connect(rey) + + result, _ = await sakis.cypher( + f"""MATCH (us), (them) + WHERE {adb.get_id_method()}(us)=$self and {adb.get_id_method()}(them)=$them + MATCH (us)-[r:KNOWS]-(them) RETURN COUNT(r)""", + {"them": rey.element_id}, + ) + assert int(result[0][0]) == 1 + + rel = await rey.knows.relationship(sakis) + assert isinstance(rel, AsyncStructuredRel) + + rels = await rey.knows.all_relationships(sakis) + assert isinstance(rels[0], AsyncStructuredRel) + + +# TODO : Make async-independent test to test .filter and not .filter.all() ? +@mark_async_test +async def test_search_and_filter_and_exclude(): + fred = await PersonWithRels(name="Fred", age=13).save() + zz = await Country(code="ZZ").save() + zx = await Country(code="ZX").save() + zt = await Country(code="ZY").save() + await fred.is_from.connect(zz) + await fred.is_from.connect(zx) + await fred.is_from.connect(zt) + result = fred.is_from.filter(code="ZX") + assert result[0].code == "ZX" + + result = fred.is_from.filter(code="ZY") + assert result[0].code == "ZY" + + result = fred.is_from.exclude(code="ZZ").exclude(code="ZY") + assert result[0].code == "ZX" and len(result) == 1 + + result = fred.is_from.exclude(Q(code__contains="Y")) + assert len(result) == 2 + + result = fred.is_from.filter(Q(code__contains="Z")) + assert len(result) == 3 + + +@mark_async_test +async def test_custom_methods(): + u = await PersonWithRels(name="Joe90", age=13).save() + assert u.special_power() == "I have no powers" + u = await SuperHero(name="Joe91", age=13, power="xxx").save() + assert u.special_power() == "I have powers" + assert u.special_name == "Joe91" + + +@mark_async_test +async def test_valid_reconnection(): + p = await PersonWithRels(name="ElPresidente", age=93).save() + assert p + + pp = await PersonWithRels(name="TheAdversary", age=33).save() + assert pp + + c = await Country(code="CU").save() + assert c + + await c.president.connect(p) + assert await c.president.is_connected(p) + + # the coup d'etat + await c.president.reconnect(p, pp) + assert await c.president.is_connected(pp) + + # reelection time + await c.president.reconnect(pp, pp) + assert await c.president.is_connected(pp) + + +@mark_async_test +async def test_valid_replace(): + brady = await PersonWithRels(name="Tom Brady", age=40).save() + assert brady + + gronk = await PersonWithRels(name="Rob Gronkowski", age=28).save() + assert gronk + + colbert = await PersonWithRels(name="Stephen Colbert", age=53).save() + assert colbert + + hanks = await PersonWithRels(name="Tom Hanks", age=61).save() + assert hanks + + await brady.knows.connect(gronk) + await brady.knows.connect(colbert) + assert len(await brady.knows.all()) == 2 + assert await brady.knows.is_connected(gronk) + assert await brady.knows.is_connected(colbert) + + await brady.knows.replace(hanks) + assert len(await brady.knows.all()) == 1 + assert await brady.knows.is_connected(hanks) + assert not await brady.knows.is_connected(gronk) + assert not await brady.knows.is_connected(colbert) + + +@mark_async_test +async def test_props_relationship(): + u = await PersonWithRels(name="Mar", age=20).save() + assert u + + c = await Country(code="AT").save() + assert c + + c2 = await Country(code="LA").save() + assert c2 + + with raises(NotImplementedError): + await c.inhabitant.connect(u, properties={"city": "Thessaloniki"}) diff --git a/test/test_relative_relationships.py b/test/async_/test_relative_relationships.py similarity index 63% rename from test/test_relative_relationships.py rename to test/async_/test_relative_relationships.py index 619d9d9e..5b8dbefa 100644 --- a/test/test_relative_relationships.py +++ b/test/async_/test_relative_relationships.py @@ -1,5 +1,6 @@ from neomodel import AsyncRelationshipTo, AsyncStructuredNode, StringProperty -from neomodel.test_relationships import Country +from test.async_.test_relationships import Country +from test._async_compat import mark_async_test class Cat(AsyncStructuredNode): @@ -8,14 +9,15 @@ class Cat(AsyncStructuredNode): is_from = AsyncRelationshipTo(".test_relationships.Country", "IS_FROM") -def test_relative_relationship(): - a = Cat(name="snufkin").save() +@mark_async_test +async def test_relative_relationship(): + a = await Cat(name="snufkin").save() assert a - c = Country(code="MG").save() + c = await Country(code="MG").save() assert c # connecting an instance of the class defined above # the next statement will fail if there's a type mismatch - a.is_from.connect(c) + await a.is_from.connect(c) assert a.is_from.is_connected(c) diff --git a/test/sync/__init__.py b/test/sync/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/test_match_api.py b/test/sync/test_match_api.py similarity index 82% rename from test/test_match_api.py rename to test/sync/test_match_api.py index 618528cf..8fd02eb6 100644 --- a/test/test_match_api.py +++ b/test/sync/test_match_api.py @@ -1,62 +1,68 @@ from datetime import datetime +from test._async_compat import mark_sync_test from pytest import raises from neomodel import ( INCOMING, - AsyncRelationshipFrom, - AsyncRelationshipTo, - AsyncStructuredNode, - AsyncStructuredRel, + RelationshipFrom, + RelationshipTo, + StructuredNode, + StructuredRel, DateTimeProperty, IntegerProperty, Q, StringProperty, ) -from neomodel.async_.match import ( - AsyncNodeSet, - AsyncQueryBuilder, - AsyncTraversal, +from neomodel.sync_.match import ( + NodeSet, + QueryBuilder, + Traversal, Optional, ) from neomodel.exceptions import MultipleNodesReturned -class SupplierRel(AsyncStructuredRel): +class SupplierRel(StructuredRel): since = DateTimeProperty(default=datetime.now) courier = StringProperty() -class Supplier(AsyncStructuredNode): +class Supplier(StructuredNode): name = StringProperty() delivery_cost = IntegerProperty() - coffees = AsyncRelationshipTo("Coffee", "COFFEE SUPPLIERS") + coffees = RelationshipTo("Coffee", "COFFEE SUPPLIERS") -class Species(AsyncStructuredNode): +class Species(StructuredNode): name = StringProperty() - coffees = AsyncRelationshipFrom( - "Coffee", "COFFEE SPECIES", model=AsyncStructuredRel + coffees = RelationshipFrom( + "Coffee", "COFFEE SPECIES", model=StructuredRel ) -class Coffee(AsyncStructuredNode): +class Coffee(StructuredNode): name = StringProperty(unique_index=True) price = IntegerProperty() - suppliers = AsyncRelationshipFrom(Supplier, "COFFEE SUPPLIERS", model=SupplierRel) - species = AsyncRelationshipTo(Species, "COFFEE SPECIES", model=AsyncStructuredRel) + suppliers = RelationshipFrom(Supplier, "COFFEE SUPPLIERS", model=SupplierRel) + species = RelationshipTo(Species, "COFFEE SPECIES", model=StructuredRel) id_ = IntegerProperty() -class Extension(AsyncStructuredNode): - extension = AsyncRelationshipTo("Extension", "extension") +class Extension(StructuredNode): + extension = RelationshipTo("Extension", "extension") +# TODO : Maybe split these tests into separate async and sync (not transpiled) +# That would allow to test "Coffee.nodes" for sync instead of Coffee.nodes.all() + + +@mark_sync_test def test_filter_exclude_via_labels(): Coffee(name="Java", price=99).save() - node_set = AsyncNodeSet(Coffee) - qb = AsyncQueryBuilder(node_set).build_ast() + node_set = NodeSet(Coffee) + qb = QueryBuilder(node_set).build_ast() results = qb._execute() @@ -69,7 +75,7 @@ def test_filter_exclude_via_labels(): # with filter and exclude Coffee(name="Kenco", price=3).save() node_set = node_set.filter(price__gt=2).exclude(price__gt=6, name="Java") - qb = AsyncQueryBuilder(node_set).build_ast() + qb = QueryBuilder(node_set).build_ast() results = qb._execute() assert "(coffee:Coffee)" in qb._ast.match @@ -78,26 +84,28 @@ def test_filter_exclude_via_labels(): assert results[0].name == "Kenco" +@mark_sync_test def test_simple_has_via_label(): nescafe = Coffee(name="Nescafe", price=99).save() tesco = Supplier(name="Tesco", delivery_cost=2).save() nescafe.suppliers.connect(tesco) - ns = AsyncNodeSet(Coffee).has(suppliers=True) - qb = AsyncQueryBuilder(ns).build_ast() + ns = NodeSet(Coffee).has(suppliers=True) + qb = QueryBuilder(ns).build_ast() results = qb._execute() assert "COFFEE SUPPLIERS" in qb._ast.where[0] assert len(results) == 1 assert results[0].name == "Nescafe" Coffee(name="nespresso", price=99).save() - ns = AsyncNodeSet(Coffee).has(suppliers=False) - qb = AsyncQueryBuilder(ns).build_ast() + ns = NodeSet(Coffee).has(suppliers=False) + qb = QueryBuilder(ns).build_ast() results = qb._execute() assert len(results) > 0 assert "NOT" in qb._ast.where[0] +@mark_sync_test def test_get(): Coffee(name="1", price=3).save() assert Coffee.nodes.get(name="1") @@ -111,13 +119,14 @@ def test_get(): Coffee.nodes.get(price=3) +@mark_sync_test def test_simple_traverse_with_filter(): nescafe = Coffee(name="Nescafe2", price=99).save() tesco = Supplier(name="Sainsburys", delivery_cost=2).save() nescafe.suppliers.connect(tesco) - qb = AsyncQueryBuilder( - AsyncNodeSet(source=nescafe).suppliers.match(since__lt=datetime.now()) + qb = QueryBuilder( + NodeSet(source=nescafe).suppliers.match(since__lt=datetime.now()) ) results = qb.build_ast()._execute() @@ -129,14 +138,15 @@ def test_simple_traverse_with_filter(): assert results[0].name == "Sainsburys" +@mark_sync_test def test_double_traverse(): nescafe = Coffee(name="Nescafe plus", price=99).save() tesco = Supplier(name="Asda", delivery_cost=2).save() nescafe.suppliers.connect(tesco) tesco.coffees.connect(Coffee(name="Decafe", price=2).save()) - ns = AsyncNodeSet(AsyncNodeSet(source=nescafe).suppliers.match()).coffees.match() - qb = AsyncQueryBuilder(ns).build_ast() + ns = NodeSet(NodeSet(source=nescafe).suppliers.match()).coffees.match() + qb = QueryBuilder(ns).build_ast() results = qb._execute() assert len(results) == 2 @@ -144,41 +154,46 @@ def test_double_traverse(): assert results[1].name == "Nescafe plus" +@mark_sync_test def test_count(): Coffee(name="Nescafe Gold", price=99).save() - count = AsyncQueryBuilder(AsyncNodeSet(source=Coffee)).build_ast()._count() + count = QueryBuilder(NodeSet(source=Coffee)).build_ast()._count() assert count > 0 Coffee(name="Kawa", price=27).save() - node_set = AsyncNodeSet(source=Coffee) + node_set = NodeSet(source=Coffee) node_set.skip = 1 node_set.limit = 1 - count = AsyncQueryBuilder(node_set).build_ast()._count() + count = QueryBuilder(node_set).build_ast()._count() assert count == 1 +@mark_sync_test def test_len_and_iter_and_bool(): iterations = 0 Coffee(name="Icelands finest").save() - for c in Coffee.nodes: + for c in Coffee.nodes.all(): iterations += 1 c.delete() assert iterations > 0 - assert len(Coffee.nodes) == 0 + assert len(Coffee.nodes.all()) == 0 +@mark_sync_test def test_slice(): - for c in Coffee.nodes: + for c in Coffee.nodes.all(): c.delete() Coffee(name="Icelands finest").save() Coffee(name="Britains finest").save() Coffee(name="Japans finest").save() + # TODO : Make slice work with async + # Doing await (Coffee.nodes.all())[1:] fetches without slicing assert len(list(Coffee.nodes.all()[1:])) == 2 assert len(list(Coffee.nodes.all()[:1])) == 1 assert isinstance(Coffee.nodes[1], Coffee) @@ -186,6 +201,7 @@ def test_slice(): assert len(list(Coffee.nodes.all()[1:2])) == 1 +@mark_sync_test def test_issue_208(): # calls to match persist across queries. @@ -196,10 +212,11 @@ def test_issue_208(): b.suppliers.connect(l, {"courier": "fedex"}) b.suppliers.connect(a, {"courier": "dhl"}) - assert len(b.suppliers.match(courier="fedex")) - assert len(b.suppliers.match(courier="dhl")) + assert len(b.suppliers.match(courier="fedex").all()) + assert len(b.suppliers.match(courier="dhl").all()) +@mark_sync_test def test_issue_589(): node1 = Extension().save() node2 = Extension().save() @@ -207,12 +224,14 @@ def test_issue_589(): assert node2 in node1.extension.all() +# TODO : Fix the ValueError not raised +@mark_sync_test def test_contains(): expensive = Coffee(price=1000, name="Pricey").save() asda = Coffee(name="Asda", price=1).save() - assert expensive in Coffee.nodes.filter(price__gt=999) - assert asda not in Coffee.nodes.filter(price__gt=999) + assert expensive in Coffee.nodes.filter(price__gt=999).all() + assert asda not in Coffee.nodes.filter(price__gt=999).all() # bad value raises with raises(ValueError): @@ -223,25 +242,26 @@ def test_contains(): Coffee() in Coffee.nodes +@mark_sync_test def test_order_by(): - for c in Coffee.nodes: + for c in Coffee.nodes.all(): c.delete() c1 = Coffee(name="Icelands finest", price=5).save() c2 = Coffee(name="Britains finest", price=10).save() c3 = Coffee(name="Japans finest", price=35).save() - assert Coffee.nodes.order_by("price").all()[0].price == 5 - assert Coffee.nodes.order_by("-price").all()[0].price == 35 + assert Coffee.nodes.order_by("price")[0].price == 5 + assert Coffee.nodes.order_by("-price")[0].price == 35 ns = Coffee.nodes.order_by("-price") - qb = AsyncQueryBuilder(ns).build_ast() + qb = QueryBuilder(ns).build_ast() assert qb._ast.order_by ns = ns.order_by(None) - qb = AsyncQueryBuilder(ns).build_ast() + qb = QueryBuilder(ns).build_ast() assert not qb._ast.order_by ns = ns.order_by("?") - qb = AsyncQueryBuilder(ns).build_ast() + qb = QueryBuilder(ns).build_ast() assert qb._ast.with_clause == "coffee, rand() as r" assert qb._ast.order_by == "r" @@ -263,8 +283,9 @@ def test_order_by(): assert ordered_n[2] == c3 +@mark_sync_test def test_extra_filters(): - for c in Coffee.nodes: + for c in Coffee.nodes.all(): c.delete() c1 = Coffee(name="Icelands finest", price=5, id_=1).save() @@ -303,7 +324,7 @@ def test_traversal_definition_keys_are_valid(): muckefuck = Coffee(name="Mukkefuck", price=1) with raises(ValueError): - AsyncTraversal( + Traversal( muckefuck, "a_name", { @@ -314,7 +335,7 @@ def test_traversal_definition_keys_are_valid(): }, ) - AsyncTraversal( + Traversal( muckefuck, "a_name", { @@ -326,6 +347,7 @@ def test_traversal_definition_keys_are_valid(): ) +@mark_sync_test def test_empty_filters(): """Test this case: ``` @@ -338,7 +360,7 @@ def test_empty_filters(): ``NodeSet`` object. """ - for c in Coffee.nodes: + for c in Coffee.nodes.all(): c.delete() c1 = Coffee(name="Super", price=5, id_=1).save() @@ -362,9 +384,10 @@ def test_empty_filters(): ), "doesnt contain c1 in ``filter_empty_filter``" +@mark_sync_test def test_q_filters(): # Test where no children and self.connector != conn ? - for c in Coffee.nodes: + for c in Coffee.nodes.all(): c.delete() c1 = Coffee(name="Icelands finest", price=5, id_=1).save() @@ -427,7 +450,7 @@ def test_q_filters(): combined_coffees = Coffee.nodes.filter( Q(price=35), Q(name="Latte") | Q(name="Cappuccino") - ) + ).all() assert len(combined_coffees) == 2 assert c5 in combined_coffees assert c6 in combined_coffees @@ -452,6 +475,7 @@ def test_qbase(): assert len(test_hash) == 1 +@mark_sync_test def test_traversal_filter_left_hand_statement(): nescafe = Coffee(name="Nescafe2", price=99).save() nescafe_gold = Coffee(name="Nescafe gold", price=11).save() @@ -465,7 +489,7 @@ def test_traversal_filter_left_hand_statement(): nescafe_gold.suppliers.connect(lidl) lidl_supplier = ( - AsyncNodeSet(Coffee.nodes.filter(price=11).suppliers) + NodeSet(Coffee.nodes.filter(price=11).suppliers) .filter(delivery_cost=3) .all() ) @@ -473,6 +497,7 @@ def test_traversal_filter_left_hand_statement(): assert lidl in lidl_supplier +@mark_sync_test def test_fetch_relations(): arabica = Species(name="Arabica").save() robusta = Species(name="Robusta").save() @@ -510,6 +535,9 @@ def test_fetch_relations(): ) assert count == 1 - assert tesco in Supplier.nodes.fetch_relations("coffees__species").filter( - name="Sainsburys" + assert ( + tesco + in Supplier.nodes.fetch_relations("coffees__species") + .filter(name="Sainsburys") + .all() ) diff --git a/test/sync/test_migration_neo4j_5.py b/test/sync/test_migration_neo4j_5.py new file mode 100644 index 00000000..3198abf8 --- /dev/null +++ b/test/sync/test_migration_neo4j_5.py @@ -0,0 +1,78 @@ +import pytest +from test._async_compat import mark_sync_test + +from neomodel import ( + RelationshipTo, + StructuredNode, + StructuredRel, + IntegerProperty, + StringProperty, +) +from neomodel.sync_.core import db + + +class Album(StructuredNode): + name = StringProperty() + + +class Released(StructuredRel): + year = IntegerProperty() + + +class Band(StructuredNode): + name = StringProperty() + released = RelationshipTo(Album, relation_type="RELEASED", model=Released) + + +@mark_sync_test +def test_read_elements_id(): + the_hives = Band(name="The Hives").save() + lex_hives = Album(name="Lex Hives").save() + released_rel = the_hives.released.connect(lex_hives) + + # Validate element_id properties + assert lex_hives.element_id == (the_hives.released.single()).element_id + assert released_rel._start_node_element_id == the_hives.element_id + assert released_rel._end_node_element_id == lex_hives.element_id + + # Validate id properties + # Behaviour is dependent on Neo4j version + if db.database_version.startswith("4"): + # Nodes' ids + assert lex_hives.id == int(lex_hives.element_id) + assert lex_hives.id == (the_hives.released.single()).id + # Relationships' ids + assert isinstance(released_rel.element_id, int) + assert released_rel.element_id == released_rel.id + assert released_rel._start_node_id == int(the_hives.element_id) + assert released_rel._end_node_id == int(lex_hives.element_id) + else: + # Nodes' ids + expected_error_type = ValueError + expected_error_message = "id is deprecated in Neo4j version 5, please migrate to element_id\. If you use the id in a Cypher query, replace id\(\) by elementId\(\)\." + assert isinstance(lex_hives.element_id, str) + with pytest.raises( + expected_error_type, + match=expected_error_message, + ): + lex_hives.id + + # Relationships' ids + assert isinstance(released_rel.element_id, str) + assert isinstance(released_rel._start_node_element_id, str) + assert isinstance(released_rel._end_node_element_id, str) + with pytest.raises( + expected_error_type, + match=expected_error_message, + ): + released_rel.id + with pytest.raises( + expected_error_type, + match=expected_error_message, + ): + released_rel._start_node_id + with pytest.raises( + expected_error_type, + match=expected_error_message, + ): + released_rel._end_node_id diff --git a/test/test_models.py b/test/sync/test_models.py similarity index 84% rename from test/test_models.py rename to test/sync/test_models.py index b781bb92..13db07fe 100644 --- a/test/test_models.py +++ b/test/sync/test_models.py @@ -2,20 +2,21 @@ from datetime import datetime +from test._async_compat import mark_sync_test from pytest import raises from neomodel import ( - AsyncStructuredNode, - AsyncStructuredRel, + StructuredNode, + StructuredRel, DateProperty, IntegerProperty, StringProperty, ) -from neomodel.async_.core import adb +from neomodel.sync_.core import db from neomodel.exceptions import RequiredProperty, UniqueProperty -class User(AsyncStructuredNode): +class User(StructuredNode): email = StringProperty(unique_index=True, required=True) age = IntegerProperty(index=True) @@ -28,12 +29,13 @@ def email_alias(self, value): self.email = value -class NodeWithoutProperty(AsyncStructuredNode): +class NodeWithoutProperty(StructuredNode): pass +@mark_sync_test def test_issue_233(): - class BaseIssue233(AsyncStructuredNode): + class BaseIssue233(StructuredNode): __abstract_node__ = True def __getitem__(self, item): @@ -51,13 +53,10 @@ def test_issue_72(): assert user.age is None +@mark_sync_test def test_required(): - try: + with raises(RequiredProperty): User(age=3).save() - except RequiredProperty: - assert True - else: - assert False def test_repr_and_str(): @@ -67,6 +66,7 @@ def test_repr_and_str(): assert True +@mark_sync_test def test_get_and_get_or_none(): u = User(email="robin@test.com", age=3) assert u.save() @@ -81,6 +81,7 @@ def test_get_and_get_or_none(): assert n is None +@mark_sync_test def test_first_and_first_or_none(): u = User(email="matt@test.com", age=24) assert u.save() @@ -105,6 +106,7 @@ def test_bare_init_without_save(): assert User().element_id is None +@mark_sync_test def test_save_to_model(): u = User(email="jim@test.com", age=3) assert u.save() @@ -113,24 +115,28 @@ def test_save_to_model(): assert u.age == 3 +@mark_sync_test def test_save_node_without_properties(): n = NodeWithoutProperty() assert n.save() assert n.element_id is not None +@mark_sync_test def test_unique(): - adb.install_labels(User) + db.install_labels(User) User(email="jim1@test.com", age=3).save() with raises(UniqueProperty): User(email="jim1@test.com", age=3).save() +@mark_sync_test def test_update_unique(): u = User(email="jimxx@test.com", age=3).save() u.save() # this shouldn't fail +@mark_sync_test def test_update(): user = User(email="jim2@test.com", age=3).save() assert user @@ -141,6 +147,7 @@ def test_update(): assert jim.email == "jim2000@test.com" +@mark_sync_test def test_save_through_magic_property(): user = User(email_alias="blah@test.com", age=8).save() assert user.email_alias == "blah@test.com" @@ -156,14 +163,15 @@ def test_save_through_magic_property(): assert user2 -class Customer2(AsyncStructuredNode): +class Customer2(StructuredNode): __label__ = "customers" email = StringProperty(unique_index=True, required=True) age = IntegerProperty(index=True) +@mark_sync_test def test_not_updated_on_unique_error(): - adb.install_labels(Customer2) + db.install_labels(Customer2) Customer2(email="jim@bob.com", age=7).save() test = Customer2(email="jim1@bob.com", age=2).save() test.email = "jim@bob.com" @@ -171,10 +179,11 @@ def test_not_updated_on_unique_error(): test.save() customers = Customer2.nodes.all() assert customers[0].email != customers[1].email - assert Customer2.nodes.get(email="jim@bob.com").age == 7 - assert Customer2.nodes.get(email="jim1@bob.com").age == 2 + assert (Customer2.nodes.get(email="jim@bob.com")).age == 7 + assert (Customer2.nodes.get(email="jim1@bob.com")).age == 2 +@mark_sync_test def test_label_not_inherited(): class Customer3(Customer2): address = StringProperty() @@ -190,6 +199,7 @@ class Customer3(Customer2): assert "Customer3" in c.labels() +@mark_sync_test def test_refresh(): c = Customer2(email="my@email.com", age=16).save() c.my_custom_prop = "value" @@ -209,7 +219,7 @@ def test_refresh(): assert c.age == 20 - if adb.database_version.startswith("4"): + if db.database_version.startswith("4"): c = Customer2.inflate(999) else: c = Customer2.inflate("4:xxxxxx:999") @@ -217,6 +227,7 @@ def test_refresh(): c.refresh() +@mark_sync_test def test_setting_value_to_none(): c = Customer2(email="alice@bob.com", age=42).save() assert c.age is not None @@ -228,8 +239,9 @@ def test_setting_value_to_none(): assert copy.age is None +@mark_sync_test def test_inheritance(): - class User(AsyncStructuredNode): + class User(StructuredNode): __abstract_node__ = True name = StringProperty(unique_index=True) @@ -247,11 +259,12 @@ def credit_account(self, amount): assert jim.balance == 350 assert len(jim.inherited_labels()) == 1 assert len(jim.labels()) == 1 - assert jim.labels()[0] == "Shopper" + assert (jim.labels())[0] == "Shopper" +@mark_sync_test def test_inherited_optional_labels(): - class BaseOptional(AsyncStructuredNode): + class BaseOptional(StructuredNode): __optional_labels__ = ["Alive"] name = StringProperty(unique_index=True) @@ -274,6 +287,7 @@ def credit_account(self, amount): assert set(henry.inherited_optional_labels()) == {"Alive", "RewardsMember"} +@mark_sync_test def test_mixins(): class UserMixin: name = StringProperty(unique_index=True) @@ -286,7 +300,7 @@ def credit_account(self, amount): self.balance = self.balance + int(amount) self.save() - class Shopper2(AsyncStructuredNode, UserMixin, CreditMixin): + class Shopper2(StructuredNode, UserMixin, CreditMixin): pass jim = Shopper2(name="jimmy", balance=300).save() @@ -296,11 +310,12 @@ class Shopper2(AsyncStructuredNode, UserMixin, CreditMixin): assert jim.balance == 350 assert len(jim.inherited_labels()) == 1 assert len(jim.labels()) == 1 - assert jim.labels()[0] == "Shopper2" + assert (jim.labels())[0] == "Shopper2" +@mark_sync_test def test_date_property(): - class DateTest(AsyncStructuredNode): + class DateTest(StructuredNode): birthdate = DateProperty() user = DateTest(birthdate=datetime.now()).save() @@ -310,36 +325,36 @@ def test_reserved_property_keys(): error_match = r".*is not allowed as it conflicts with neomodel internals.*" with raises(ValueError, match=error_match): - class ReservedPropertiesDeletedNode(AsyncStructuredNode): + class ReservedPropertiesDeletedNode(StructuredNode): deleted = StringProperty() with raises(ValueError, match=error_match): - class ReservedPropertiesIdNode(AsyncStructuredNode): + class ReservedPropertiesIdNode(StructuredNode): id = StringProperty() with raises(ValueError, match=error_match): - class ReservedPropertiesElementIdNode(AsyncStructuredNode): + class ReservedPropertiesElementIdNode(StructuredNode): element_id = StringProperty() with raises(ValueError, match=error_match): - class ReservedPropertiesIdRel(AsyncStructuredRel): + class ReservedPropertiesIdRel(StructuredRel): id = StringProperty() with raises(ValueError, match=error_match): - class ReservedPropertiesElementIdRel(AsyncStructuredRel): + class ReservedPropertiesElementIdRel(StructuredRel): element_id = StringProperty() error_match = r"Property names 'source' and 'target' are not allowed as they conflict with neomodel internals." with raises(ValueError, match=error_match): - class ReservedPropertiesSourceRel(AsyncStructuredRel): + class ReservedPropertiesSourceRel(StructuredRel): source = StringProperty() with raises(ValueError, match=error_match): - class ReservedPropertiesTargetRel(AsyncStructuredRel): + class ReservedPropertiesTargetRel(StructuredRel): target = StringProperty() diff --git a/test/test_multiprocessing.py b/test/sync/test_multiprocessing.py similarity index 57% rename from test/test_multiprocessing.py rename to test/sync/test_multiprocessing.py index 28f1422c..830c5d6a 100644 --- a/test/test_multiprocessing.py +++ b/test/sync/test_multiprocessing.py @@ -1,9 +1,11 @@ from multiprocessing.pool import ThreadPool as Pool +from test._async_compat import mark_sync_test -from neomodel import AsyncStructuredNode, StringProperty, adb +from neomodel import StructuredNode, StringProperty +from neomodel.sync_.core import db -class ThingyMaBob(AsyncStructuredNode): +class ThingyMaBob(StructuredNode): name = StringProperty(unique_index=True, required=True) @@ -13,9 +15,11 @@ def thing_create(name): return thing.name, name +@mark_sync_test def test_concurrency(): with Pool(5) as p: results = p.map(thing_create, range(50)) - for returned, sent in results: + for to_unpack in results: + returned, sent = to_unpack assert returned == sent - adb.close_connection() + db.close_connection() diff --git a/test/test_paths.py b/test/sync/test_paths.py similarity index 75% rename from test/test_paths.py rename to test/sync/test_paths.py index 9c0000f8..13201439 100644 --- a/test/test_paths.py +++ b/test/sync/test_paths.py @@ -1,16 +1,17 @@ +from test._async_compat import mark_sync_test from neomodel import ( - AsyncNeomodelPath, - AsyncRelationshipTo, - AsyncStructuredNode, - AsyncStructuredRel, + NeomodelPath, + RelationshipTo, + StructuredNode, + StructuredRel, IntegerProperty, StringProperty, UniqueIdProperty, - adb, ) +from neomodel.sync_.core import db -class PersonLivesInCity(AsyncStructuredRel): +class PersonLivesInCity(StructuredRel): """ Relationship with data that will be instantiated as "stand-alone" """ @@ -18,24 +19,25 @@ class PersonLivesInCity(AsyncStructuredRel): some_num = IntegerProperty(index=True, default=12) -class CountryOfOrigin(AsyncStructuredNode): +class CountryOfOrigin(StructuredNode): code = StringProperty(unique_index=True, required=True) -class CityOfResidence(AsyncStructuredNode): +class CityOfResidence(StructuredNode): name = StringProperty(required=True) - country = AsyncRelationshipTo(CountryOfOrigin, "FROM_COUNTRY") + country = RelationshipTo(CountryOfOrigin, "FROM_COUNTRY") -class PersonOfInterest(AsyncStructuredNode): +class PersonOfInterest(StructuredNode): uid = UniqueIdProperty() name = StringProperty(unique_index=True) age = IntegerProperty(index=True, default=0) - country = AsyncRelationshipTo(CountryOfOrigin, "IS_FROM") - city = AsyncRelationshipTo(CityOfResidence, "LIVES_IN", model=PersonLivesInCity) + country = RelationshipTo(CountryOfOrigin, "IS_FROM") + city = RelationshipTo(CityOfResidence, "LIVES_IN", model=PersonLivesInCity) +@mark_sync_test def test_path_instantiation(): """ Neo4j driver paths should be instantiated as neomodel paths, with all of @@ -66,7 +68,7 @@ def test_path_instantiation(): p4.city.connect(ct2) # Retrieve a single path - q = adb.cypher_query( + q = db.cypher_query( "MATCH p=(:CityOfResidence)<-[:LIVES_IN]-(:PersonOfInterest)-[:IS_FROM]->(:CountryOfOrigin) RETURN p LIMIT 1", resolve_objects=True, ) @@ -75,13 +77,13 @@ def test_path_instantiation(): path_nodes = path_object.nodes path_rels = path_object.relationships - assert type(path_object) is AsyncNeomodelPath + assert type(path_object) is NeomodelPath assert type(path_nodes[0]) is CityOfResidence assert type(path_nodes[1]) is PersonOfInterest assert type(path_nodes[2]) is CountryOfOrigin assert type(path_rels[0]) is PersonLivesInCity - assert type(path_rels[1]) is AsyncStructuredRel + assert type(path_rels[1]) is StructuredRel c1.delete() c2.delete() diff --git a/test/test_properties.py b/test/sync/test_properties.py similarity index 89% rename from test/test_properties.py rename to test/sync/test_properties.py index d03f3284..2d2b3873 100644 --- a/test/test_properties.py +++ b/test/sync/test_properties.py @@ -1,9 +1,11 @@ from datetime import date, datetime +from test._async_compat import mark_sync_test from pytest import mark, raises from pytz import timezone -from neomodel import AsyncStructuredNode, adb, config +from neomodel import StructuredNode +from neomodel.sync_.core import db from neomodel.exceptions import ( DeflateError, InflateError, @@ -25,8 +27,6 @@ ) from neomodel.util import _get_node_properties -config.AUTO_INSTALL_LABELS = True - class FooBar: pass @@ -60,8 +60,9 @@ def test_string_property_exceeds_max_length(): ), "StringProperty max_length test passed but values do not match." +@mark_sync_test def test_string_property_w_choice(): - class TestChoices(AsyncStructuredNode): + class TestChoices(StructuredNode): SEXES = {"F": "Female", "M": "Male", "O": "Other"} sex = StringProperty(required=True, choices=SEXES) @@ -185,8 +186,9 @@ def test_json(): assert prop.inflate('{"test": [1, 2, 3]}') == value +@mark_sync_test def test_default_value(): - class DefaultTestValue(AsyncStructuredNode): + class DefaultTestValue(StructuredNode): name_xx = StringProperty(default="jim", index=True) a = DefaultTestValue() @@ -194,17 +196,19 @@ class DefaultTestValue(AsyncStructuredNode): a.save() +@mark_sync_test def test_default_value_callable(): def uid_generator(): return "xx" - class DefaultTestValueTwo(AsyncStructuredNode): + class DefaultTestValueTwo(StructuredNode): uid = StringProperty(default=uid_generator, index=True) a = DefaultTestValueTwo().save() assert a.uid == "xx" +@mark_sync_test def test_default_value_callable_type(): # check our object gets converted to str without serializing and reload def factory(): @@ -214,7 +218,7 @@ def __str__(self): return Foo() - class DefaultTestValueThree(AsyncStructuredNode): + class DefaultTestValueThree(StructuredNode): uid = StringProperty(default=factory, index=True) x = DefaultTestValueThree() @@ -225,8 +229,9 @@ class DefaultTestValueThree(AsyncStructuredNode): assert x.uid == "123" +@mark_sync_test def test_independent_property_name(): - class TestDBNamePropertyNode(AsyncStructuredNode): + class TestDBNamePropertyNode(StructuredNode): name_ = StringProperty(db_property="name") x = TestDBNamePropertyNode() @@ -234,7 +239,7 @@ class TestDBNamePropertyNode(AsyncStructuredNode): x.save() # check database property name on low level - results, meta = adb.cypher_query("MATCH (n:TestDBNamePropertyNode) RETURN n") + results, meta = db.cypher_query("MATCH (n:TestDBNamePropertyNode) RETURN n") node_properties = _get_node_properties(results[0][0]) assert node_properties["name"] == "jim" @@ -242,24 +247,27 @@ class TestDBNamePropertyNode(AsyncStructuredNode): assert not "name_" in node_properties assert not hasattr(x, "name") assert hasattr(x, "name_") - assert TestDBNamePropertyNode.nodes.filter(name_="jim").all()[0].name_ == x.name_ - assert TestDBNamePropertyNode.nodes.get(name_="jim").name_ == x.name_ + assert (TestDBNamePropertyNode.nodes.filter(name_="jim").all())[ + 0 + ].name_ == x.name_ + assert (TestDBNamePropertyNode.nodes.get(name_="jim")).name_ == x.name_ x.delete() +@mark_sync_test def test_independent_property_name_get_or_create(): - class TestNode(AsyncStructuredNode): + class TestNode(StructuredNode): uid = UniqueIdProperty() name_ = StringProperty(db_property="name", required=True) # create the node TestNode.get_or_create({"uid": 123, "name_": "jim"}) # test that the node is retrieved correctly - x = TestNode.get_or_create({"uid": 123, "name_": "jim"})[0] + x = (TestNode.get_or_create({"uid": 123, "name_": "jim"}))[0] # check database property name on low level - results, meta = adb.cypher_query("MATCH (n:TestNode) RETURN n") + results, _ = db.cypher_query("MATCH (n:TestNode) RETURN n") node_properties = _get_node_properties(results[0][0]) assert node_properties["name"] == "jim" assert "name_" not in node_properties @@ -331,6 +339,7 @@ def test_email_property(): prop.deflate("foo@example") +@mark_sync_test def test_uid_property(): prop = UniqueIdProperty() prop.name = "uid" @@ -338,19 +347,20 @@ def test_uid_property(): myuid = prop.default_value() assert len(myuid) - class CheckMyId(AsyncStructuredNode): + class CheckMyId(StructuredNode): uid = UniqueIdProperty() cmid = CheckMyId().save() assert len(cmid.uid) -class ArrayProps(AsyncStructuredNode): +class ArrayProps(StructuredNode): uid = StringProperty(unique_index=True) untyped_arr = ArrayProperty() typed_arr = ArrayProperty(IntegerProperty()) +@mark_sync_test def test_array_properties(): # untyped ap1 = ArrayProps(uid="1", untyped_arr=["Tim", "Bob"]).save() @@ -377,8 +387,9 @@ def test_illegal_array_base_prop_raises(): ArrayProperty(StringProperty(index=True)) +@mark_sync_test def test_indexed_array(): - class IndexArray(AsyncStructuredNode): + class IndexArray(StructuredNode): ai = ArrayProperty(unique_index=True) b = IndexArray(ai=[1, 2]).save() @@ -386,8 +397,9 @@ class IndexArray(AsyncStructuredNode): assert b.element_id == c.element_id +@mark_sync_test def test_unique_index_prop_not_required(): - class ConstrainedTestNode(AsyncStructuredNode): + class ConstrainedTestNode(StructuredNode): required_property = StringProperty(required=True) unique_property = StringProperty(unique_index=True) unique_required_property = StringProperty(unique_index=True, required=True) @@ -406,7 +418,7 @@ class ConstrainedTestNode(AsyncStructuredNode): x.save() # check database property name on low level - results, meta = adb.cypher_query("MATCH (n:ConstrainedTestNode) RETURN n") + results, meta = db.cypher_query("MATCH (n:ConstrainedTestNode) RETURN n") node_properties = _get_node_properties(results[0][0]) assert node_properties["unique_required_property"] == "unique and required" @@ -414,10 +426,12 @@ class ConstrainedTestNode(AsyncStructuredNode): x.delete() +@mark_sync_test def test_unique_index_prop_enforced(): - class UniqueNullableNameNode(AsyncStructuredNode): + class UniqueNullableNameNode(StructuredNode): name = StringProperty(unique_index=True) + db.install_labels(UniqueNullableNameNode) # Nameless x = UniqueNullableNameNode() x.save() @@ -432,7 +446,7 @@ class UniqueNullableNameNode(AsyncStructuredNode): a.save() # Check nodes are in database - results, meta = adb.cypher_query("MATCH (n:UniqueNullableNameNode) RETURN n") + results, _ = db.cypher_query("MATCH (n:UniqueNullableNameNode) RETURN n") assert len(results) == 3 # Delete nodes afterwards diff --git a/test/test_relationship_models.py b/test/sync/test_relationship_models.py similarity index 86% rename from test/test_relationship_models.py rename to test/sync/test_relationship_models.py index 89b50c53..f1ecbe7e 100644 --- a/test/test_relationship_models.py +++ b/test/sync/test_relationship_models.py @@ -1,13 +1,14 @@ from datetime import datetime +from test._async_compat import mark_sync_test import pytz from pytest import raises from neomodel import ( - AsyncRelationship, - AsyncRelationshipTo, - AsyncStructuredNode, - AsyncStructuredRel, + Relationship, + RelationshipTo, + StructuredNode, + StructuredRel, DateTimeProperty, DeflateError, StringProperty, @@ -16,7 +17,7 @@ HOOKS_CALLED = {"pre_save": 0, "post_save": 0} -class FriendRel(AsyncStructuredRel): +class FriendRel(StructuredRel): since = DateTimeProperty(default=lambda: datetime.now(pytz.utc)) @@ -30,17 +31,18 @@ def post_save(self): HOOKS_CALLED["post_save"] += 1 -class Badger(AsyncStructuredNode): +class Badger(StructuredNode): name = StringProperty(unique_index=True) - friend = AsyncRelationship("Badger", "FRIEND", model=FriendRel) - hates = AsyncRelationshipTo("Stoat", "HATES", model=HatesRel) + friend = Relationship("Badger", "FRIEND", model=FriendRel) + hates = RelationshipTo("Stoat", "HATES", model=HatesRel) -class Stoat(AsyncStructuredNode): +class Stoat(StructuredNode): name = StringProperty(unique_index=True) - hates = AsyncRelationshipTo("Badger", "HATES", model=HatesRel) + hates = RelationshipTo("Badger", "HATES", model=HatesRel) +@mark_sync_test def test_either_connect_with_rel_model(): paul = Badger(name="Paul").save() tom = Badger(name="Tom").save() @@ -63,11 +65,14 @@ def test_either_connect_with_rel_model(): assert tom.name == "Paul" +@mark_sync_test def test_direction_connect_with_rel_model(): paul = Badger(name="Paul the badger").save() ian = Stoat(name="Ian the stoat").save() - rel = ian.hates.connect(paul, {"reason": "thinks paul should bath more often"}) + rel = ian.hates.connect( + paul, {"reason": "thinks paul should bath more often"} + ) assert isinstance(rel.since, datetime) assert isinstance(rel, FriendRel) assert rel.reason.startswith("thinks") @@ -103,6 +108,7 @@ def test_direction_connect_with_rel_model(): ) +@mark_sync_test def test_traversal_where_clause(): phill = Badger(name="Phill the badger").save() tim = Badger(name="Tim the badger").save() @@ -113,9 +119,10 @@ def test_traversal_where_clause(): rel2 = tim.friend.connect(phill) assert rel2.since > now friends = tim.friend.match(since__gt=now) - assert len(friends) == 1 + assert len(friends.all()) == 1 +@mark_sync_test def test_multiple_rels_exist_issue_223(): # check a badger can dislike a stoat for multiple reasons phill = Badger(name="Phill").save() @@ -130,6 +137,7 @@ def test_multiple_rels_exist_issue_223(): assert ian_a.element_id == ian_b.element_id +@mark_sync_test def test_retrieve_all_rels(): tom = Badger(name="tom").save() ian = Stoat(name="ian").save() @@ -143,6 +151,7 @@ def test_retrieve_all_rels(): assert rels[1].element_id in [rel_a.element_id, rel_b.element_id] +@mark_sync_test def test_save_hook_on_rel_model(): HOOKS_CALLED["pre_save"] = 0 HOOKS_CALLED["post_save"] = 0 diff --git a/test/test_relationships.py b/test/sync/test_relationships.py similarity index 79% rename from test/test_relationships.py rename to test/sync/test_relationships.py index fa6ff01d..462c966e 100644 --- a/test/test_relationships.py +++ b/test/sync/test_relationships.py @@ -1,24 +1,25 @@ from pytest import raises +from test._async_compat import mark_sync_test from neomodel import ( - AsyncOne, - AsyncRelationship, - AsyncRelationshipFrom, - AsyncRelationshipTo, - AsyncStructuredNode, - AsyncStructuredRel, + One, + Relationship, + RelationshipFrom, + RelationshipTo, + StructuredNode, + StructuredRel, IntegerProperty, Q, StringProperty, ) -from neomodel.async_.core import adb +from neomodel.sync_.core import db -class PersonWithRels(AsyncStructuredNode): +class PersonWithRels(StructuredNode): name = StringProperty(unique_index=True) age = IntegerProperty(index=True) - is_from = AsyncRelationshipTo("Country", "IS_FROM") - knows = AsyncRelationship("PersonWithRels", "KNOWS") + is_from = RelationshipTo("Country", "IS_FROM") + knows = Relationship("PersonWithRels", "KNOWS") @property def special_name(self): @@ -28,10 +29,10 @@ def special_power(self): return "I have no powers" -class Country(AsyncStructuredNode): +class Country(StructuredNode): code = StringProperty(unique_index=True) - inhabitant = AsyncRelationshipFrom(PersonWithRels, "IS_FROM") - president = AsyncRelationshipTo(PersonWithRels, "PRESIDENT", cardinality=AsyncOne) + inhabitant = RelationshipFrom(PersonWithRels, "IS_FROM") + president = RelationshipTo(PersonWithRels, "PRESIDENT", cardinality=One) class SuperHero(PersonWithRels): @@ -41,6 +42,7 @@ def special_power(self): return "I have powers" +@mark_sync_test def test_actions_on_deleted_node(): u = PersonWithRels(name="Jim2", age=3).save() u.delete() @@ -54,6 +56,7 @@ def test_actions_on_deleted_node(): u.save() +@mark_sync_test def test_bidirectional_relationships(): u = PersonWithRels(name="Jim", age=3).save() assert u @@ -61,26 +64,27 @@ def test_bidirectional_relationships(): de = Country(code="DE").save() assert de - assert not u.is_from + assert not u.is_from.all() assert u.is_from.__class__.__name__ == "ZeroOrMore" u.is_from.connect(de) - assert len(u.is_from) == 1 + assert len(u.is_from.all()) == 1 assert u.is_from.is_connected(de) - b = u.is_from.all()[0] + b = (u.is_from.all())[0] assert b.__class__.__name__ == "Country" assert b.code == "DE" - s = b.inhabitant.all()[0] + s = (b.inhabitant.all())[0] assert s.name == "Jim" u.is_from.disconnect(b) assert not u.is_from.is_connected(b) +@mark_sync_test def test_either_direction_connect(): rey = PersonWithRels(name="Rey", age=3).save() sakis = PersonWithRels(name="Sakis", age=3).save() @@ -92,19 +96,21 @@ def test_either_direction_connect(): result, _ = sakis.cypher( f"""MATCH (us), (them) - WHERE {adb.get_id_method()}(us)=$self and {adb.get_id_method()}(them)=$them + WHERE {db.get_id_method()}(us)=$self and {db.get_id_method()}(them)=$them MATCH (us)-[r:KNOWS]-(them) RETURN COUNT(r)""", {"them": rey.element_id}, ) assert int(result[0][0]) == 1 rel = rey.knows.relationship(sakis) - assert isinstance(rel, AsyncStructuredRel) + assert isinstance(rel, StructuredRel) rels = rey.knows.all_relationships(sakis) - assert isinstance(rels[0], AsyncStructuredRel) + assert isinstance(rels[0], StructuredRel) +# TODO : Make async-independent test to test .filter and not .filter.all() ? +@mark_sync_test def test_search_and_filter_and_exclude(): fred = PersonWithRels(name="Fred", age=13).save() zz = Country(code="ZZ").save() @@ -129,6 +135,7 @@ def test_search_and_filter_and_exclude(): assert len(result) == 3 +@mark_sync_test def test_custom_methods(): u = PersonWithRels(name="Joe90", age=13).save() assert u.special_power() == "I have no powers" @@ -137,6 +144,7 @@ def test_custom_methods(): assert u.special_name == "Joe91" +@mark_sync_test def test_valid_reconnection(): p = PersonWithRels(name="ElPresidente", age=93).save() assert p @@ -159,6 +167,7 @@ def test_valid_reconnection(): assert c.president.is_connected(pp) +@mark_sync_test def test_valid_replace(): brady = PersonWithRels(name="Tom Brady", age=40).save() assert brady @@ -174,17 +183,18 @@ def test_valid_replace(): brady.knows.connect(gronk) brady.knows.connect(colbert) - assert len(brady.knows) == 2 + assert len(brady.knows.all()) == 2 assert brady.knows.is_connected(gronk) assert brady.knows.is_connected(colbert) brady.knows.replace(hanks) - assert len(brady.knows) == 1 + assert len(brady.knows.all()) == 1 assert brady.knows.is_connected(hanks) assert not brady.knows.is_connected(gronk) assert not brady.knows.is_connected(colbert) +@mark_sync_test def test_props_relationship(): u = PersonWithRels(name="Mar", age=20).save() assert u diff --git a/test/sync/test_relative_relationships.py b/test/sync/test_relative_relationships.py new file mode 100644 index 00000000..07a2e6ca --- /dev/null +++ b/test/sync/test_relative_relationships.py @@ -0,0 +1,23 @@ +from neomodel import RelationshipTo, StructuredNode, StringProperty +from test.sync_.test_relationships import Country +from test._async_compat import mark_sync_test + + +class Cat(StructuredNode): + name = StringProperty() + # Relationship is defined using a relative class path + is_from = RelationshipTo(".test_relationships.Country", "IS_FROM") + + +@mark_sync_test +def test_relative_relationship(): + a = Cat(name="snufkin").save() + assert a + + c = Country(code="MG").save() + assert c + + # connecting an instance of the class defined above + # the next statement will fail if there's a type mismatch + a.is_from.connect(c) + assert a.is_from.is_connected(c) diff --git a/test/test_scripts.py b/test/test_scripts.py index 77dee66e..9f0690d9 100644 --- a/test/test_scripts.py +++ b/test/test_scripts.py @@ -3,26 +3,24 @@ import pytest from neomodel import ( - AsyncRelationshipTo, - AsyncStructuredNode, - AsyncStructuredRel, + RelationshipTo, + StructuredNode, + StructuredRel, StringProperty, config, ) -from neomodel.async_.core import adb +from neomodel.sync_.core import db -class ScriptsTestRel(AsyncStructuredRel): - some_unique_property = StringProperty( - unique_index=adb.version_is_higher_than("5.7") - ) +class ScriptsTestRel(StructuredRel): + some_unique_property = StringProperty(unique_index=db.version_is_higher_than("5.7")) some_index_property = StringProperty(index=True) -class ScriptsTestNode(AsyncStructuredNode): +class ScriptsTestNode(StructuredNode): personal_id = StringProperty(unique_index=True) name = StringProperty(index=True) - rel = AsyncRelationshipTo("ScriptsTestNode", "REL", model=ScriptsTestRel) + rel = RelationshipTo("ScriptsTestNode", "REL", model=ScriptsTestRel) def test_neomodel_install_labels(): @@ -36,26 +34,26 @@ def test_neomodel_install_labels(): assert result.returncode == 0 result = subprocess.run( - ["neomodel_install_labels", "test/test_scripts.py", "--db", adb.url], + ["neomodel_install_labels", "test/test_scripts.py", "--db", db.url], capture_output=True, text=True, check=False, ) assert result.returncode == 0 assert "Setting up indexes and constraints" in result.stdout - constraints = adb.list_constraints() + constraints = db.list_constraints() parsed_constraints = [ (element["type"], element["labelsOrTypes"], element["properties"]) for element in constraints ] assert ("UNIQUENESS", ["ScriptsTestNode"], ["personal_id"]) in parsed_constraints - if adb.version_is_higher_than("5.7"): + if db.version_is_higher_than("5.7"): assert ( "RELATIONSHIP_UNIQUENESS", ["REL"], ["some_unique_property"], ) in parsed_constraints - indexes = adb.list_indexes() + indexes = db.list_indexes() parsed_indexes = [ (element["labelsOrTypes"], element["properties"]) for element in indexes ] @@ -83,8 +81,8 @@ def test_neomodel_remove_labels(): "Dropping unique constraint and index on label ScriptsTestNode" in result.stdout ) assert result.returncode == 0 - constraints = adb.list_constraints() - indexes = adb.list_indexes(exclude_token_lookup=True) + constraints = db.list_constraints() + indexes = db.list_indexes(exclude_token_lookup=True) assert len(constraints) == 0 assert len(indexes) == 0 @@ -108,9 +106,9 @@ def test_neomodel_inspect_database(script_flavour): assert "usage: neomodel_inspect_database" in result.stdout assert result.returncode == 0 - adb.clear_neo4j_database() - adb.install_labels(ScriptsTestNode) - adb.install_labels(ScriptsTestRel) + db.clear_neo4j_database() + db.install_labels(ScriptsTestNode) + db.install_labels(ScriptsTestRel) # Create a few nodes and a rel, with indexes and constraints node1 = ScriptsTestNode(personal_id="1", name="test").save() @@ -119,7 +117,7 @@ def test_neomodel_inspect_database(script_flavour): # Create a node with all the parsable property types # Also create a node with no properties - adb.cypher_query( + db.cypher_query( """ CREATE (:EveryPropertyTypeNode { string_property: "Hello World", @@ -155,7 +153,7 @@ def test_neomodel_inspect_database(script_flavour): # Check that all the expected lines are here file_path = ( f"test/data/neomodel_inspect_database_output{script_flavour}.txt" - if adb.version_is_higher_than("5.7") + if db.version_is_higher_than("5.7") else f"test/data/neomodel_inspect_database_output_pre_5_7{script_flavour}.txt" ) with open(file_path, "r") as f: diff --git a/test/test_transactions.py b/test/test_transactions.py deleted file mode 100644 index 83623821..00000000 --- a/test/test_transactions.py +++ /dev/null @@ -1,182 +0,0 @@ -import pytest -from neo4j.api import Bookmarks -from neo4j.exceptions import ClientError, TransactionError -from pytest import raises - -from neomodel import AsyncStructuredNode, StringProperty, UniqueProperty -from neomodel.async_.core import adb - - -class APerson(AsyncStructuredNode): - name = StringProperty(unique_index=True) - - -def test_rollback_and_commit_transaction(): - for p in APerson.nodes: - p.delete() - - APerson(name="Roger").save() - - adb.begin() - APerson(name="Terry S").save() - adb.rollback() - - assert len(APerson.nodes) == 1 - - adb.begin() - APerson(name="Terry S").save() - adb.commit() - - assert len(APerson.nodes) == 2 - - -@adb.transaction -def in_a_tx(*names): - for n in names: - APerson(name=n).save() - - -def test_transaction_decorator(): - adb.install_labels(APerson) - for p in APerson.nodes: - p.delete() - - # should work - in_a_tx("Roger") - assert True - - # should bail but raise correct error - with raises(UniqueProperty): - in_a_tx("Jim", "Roger") - - assert "Jim" not in [p.name for p in APerson.nodes] - - -def test_transaction_as_a_context(): - with adb.transaction: - APerson(name="Tim").save() - - assert APerson.nodes.filter(name="Tim") - - with raises(UniqueProperty): - with adb.transaction: - APerson(name="Tim").save() - - -def test_query_inside_transaction(): - for p in APerson.nodes: - p.delete() - - with adb.transaction: - APerson(name="Alice").save() - APerson(name="Bob").save() - - assert len([p.name for p in APerson.nodes]) == 2 - - -def test_read_transaction(): - APerson(name="Johnny").save() - - with adb.read_transaction: - people = APerson.nodes.all() - assert people - - with raises(TransactionError): - with adb.read_transaction: - with raises(ClientError) as e: - APerson(name="Gina").save() - assert e.value.code == "Neo.ClientError.Statement.AccessMode" - - -def test_write_transaction(): - with adb.write_transaction: - APerson(name="Amelia").save() - - amelia = APerson.nodes.get(name="Amelia") - assert amelia - - -def double_transaction(): - adb.begin() - with raises(SystemError, match=r"Transaction in progress"): - adb.begin() - - adb.rollback() - - -@adb.transaction.with_bookmark -def in_a_tx(*names): - for n in names: - APerson(name=n).save() - - -def test_bookmark_transaction_decorator(): - for p in APerson.nodes: - p.delete() - - # should work - result, bookmarks = in_a_tx("Ruth", bookmarks=None) - assert result is None - assert isinstance(bookmarks, Bookmarks) - - # should bail but raise correct error - with raises(UniqueProperty): - in_a_tx("Jane", "Ruth") - - assert "Jane" not in [p.name for p in APerson.nodes] - - -def test_bookmark_transaction_as_a_context(): - with adb.transaction as transaction: - APerson(name="Tanya").save() - assert isinstance(transaction.last_bookmark, Bookmarks) - - assert APerson.nodes.filter(name="Tanya") - - with raises(UniqueProperty): - with adb.transaction as transaction: - APerson(name="Tanya").save() - assert not hasattr(transaction, "last_bookmark") - - -@pytest.fixture -def spy_on_db_begin(monkeypatch): - spy_calls = [] - original_begin = adb.begin - - def begin_spy(*args, **kwargs): - spy_calls.append((args, kwargs)) - return original_begin(*args, **kwargs) - - monkeypatch.setattr(adb, "begin", begin_spy) - return spy_calls - - -def test_bookmark_passed_in_to_context(spy_on_db_begin): - transaction = adb.transaction - with transaction: - pass - - assert spy_on_db_begin[-1] == ((), {"access_mode": None, "bookmarks": None}) - last_bookmark = transaction.last_bookmark - - transaction.bookmarks = last_bookmark - with transaction: - pass - assert spy_on_db_begin[-1] == ( - (), - {"access_mode": None, "bookmarks": last_bookmark}, - ) - - -def test_query_inside_bookmark_transaction(): - for p in APerson.nodes: - p.delete() - - with adb.transaction as transaction: - APerson(name="Alice").save() - APerson(name="Bob").save() - - assert len([p.name for p in APerson.nodes]) == 2 - - assert isinstance(transaction.last_bookmark, Bookmarks) From 45120eabb2cd54c9664572ea004909748f12c08e Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Wed, 3 Jan 2024 11:50:54 +0100 Subject: [PATCH 27/73] Rename sync test folder --- bin/make-unasync | 2 +- test/{sync => sync_}/__init__.py | 0 test/{sync => sync_}/conftest.py | 0 test/{sync => sync_}/test_alias.py | 0 test/{sync => sync_}/test_batch.py | 0 test/{sync => sync_}/test_cardinality.py | 0 test/{sync => sync_}/test_connection.py | 0 test/{sync => sync_}/test_cypher.py | 0 test/{sync => sync_}/test_database_management.py | 0 test/{sync => sync_}/test_dbms_awareness.py | 0 test/{sync => sync_}/test_driver_options.py | 0 test/{sync => sync_}/test_exceptions.py | 0 test/{sync => sync_}/test_hooks.py | 0 test/{sync => sync_}/test_indexing.py | 0 test/{sync => sync_}/test_issue112.py | 0 test/{sync => sync_}/test_issue283.py | 0 test/{sync => sync_}/test_issue600.py | 0 test/{sync => sync_}/test_label_drop.py | 0 test/{sync => sync_}/test_label_install.py | 0 test/{sync => sync_}/test_match_api.py | 0 test/{sync => sync_}/test_migration_neo4j_5.py | 0 test/{sync => sync_}/test_models.py | 0 test/{sync => sync_}/test_multiprocessing.py | 0 test/{sync => sync_}/test_paths.py | 0 test/{sync => sync_}/test_properties.py | 0 test/{sync => sync_}/test_relationship_models.py | 0 test/{sync => sync_}/test_relationships.py | 0 test/{sync => sync_}/test_relative_relationships.py | 0 28 files changed, 1 insertion(+), 1 deletion(-) rename test/{sync => sync_}/__init__.py (100%) rename test/{sync => sync_}/conftest.py (100%) rename test/{sync => sync_}/test_alias.py (100%) rename test/{sync => sync_}/test_batch.py (100%) rename test/{sync => sync_}/test_cardinality.py (100%) rename test/{sync => sync_}/test_connection.py (100%) rename test/{sync => sync_}/test_cypher.py (100%) rename test/{sync => sync_}/test_database_management.py (100%) rename test/{sync => sync_}/test_dbms_awareness.py (100%) rename test/{sync => sync_}/test_driver_options.py (100%) rename test/{sync => sync_}/test_exceptions.py (100%) rename test/{sync => sync_}/test_hooks.py (100%) rename test/{sync => sync_}/test_indexing.py (100%) rename test/{sync => sync_}/test_issue112.py (100%) rename test/{sync => sync_}/test_issue283.py (100%) rename test/{sync => sync_}/test_issue600.py (100%) rename test/{sync => sync_}/test_label_drop.py (100%) rename test/{sync => sync_}/test_label_install.py (100%) rename test/{sync => sync_}/test_match_api.py (100%) rename test/{sync => sync_}/test_migration_neo4j_5.py (100%) rename test/{sync => sync_}/test_models.py (100%) rename test/{sync => sync_}/test_multiprocessing.py (100%) rename test/{sync => sync_}/test_paths.py (100%) rename test/{sync => sync_}/test_properties.py (100%) rename test/{sync => sync_}/test_relationship_models.py (100%) rename test/{sync => sync_}/test_relationships.py (100%) rename test/{sync => sync_}/test_relative_relationships.py (100%) diff --git a/bin/make-unasync b/bin/make-unasync index 7b796bfa..5acdf5e7 100755 --- a/bin/make-unasync +++ b/bin/make-unasync @@ -16,7 +16,7 @@ SYNC_DIR = ROOT_DIR / "neomodel" / "sync_" ASYNC_CONTRIB_DIR = ROOT_DIR / "neomodel" / "contrib" / "async_" SYNC_CONTRIB_DIR = ROOT_DIR / "neomodel" / "contrib" / "sync_" ASYNC_INTEGRATION_TEST_DIR = ROOT_DIR / "test" / "async_" -SYNC_INTEGRATION_TEST_DIR = ROOT_DIR / "test" / "sync" +SYNC_INTEGRATION_TEST_DIR = ROOT_DIR / "test" / "sync_" UNASYNC_SUFFIX = ".unasync" PY_FILE_EXTENSIONS = {".py"} diff --git a/test/sync/__init__.py b/test/sync_/__init__.py similarity index 100% rename from test/sync/__init__.py rename to test/sync_/__init__.py diff --git a/test/sync/conftest.py b/test/sync_/conftest.py similarity index 100% rename from test/sync/conftest.py rename to test/sync_/conftest.py diff --git a/test/sync/test_alias.py b/test/sync_/test_alias.py similarity index 100% rename from test/sync/test_alias.py rename to test/sync_/test_alias.py diff --git a/test/sync/test_batch.py b/test/sync_/test_batch.py similarity index 100% rename from test/sync/test_batch.py rename to test/sync_/test_batch.py diff --git a/test/sync/test_cardinality.py b/test/sync_/test_cardinality.py similarity index 100% rename from test/sync/test_cardinality.py rename to test/sync_/test_cardinality.py diff --git a/test/sync/test_connection.py b/test/sync_/test_connection.py similarity index 100% rename from test/sync/test_connection.py rename to test/sync_/test_connection.py diff --git a/test/sync/test_cypher.py b/test/sync_/test_cypher.py similarity index 100% rename from test/sync/test_cypher.py rename to test/sync_/test_cypher.py diff --git a/test/sync/test_database_management.py b/test/sync_/test_database_management.py similarity index 100% rename from test/sync/test_database_management.py rename to test/sync_/test_database_management.py diff --git a/test/sync/test_dbms_awareness.py b/test/sync_/test_dbms_awareness.py similarity index 100% rename from test/sync/test_dbms_awareness.py rename to test/sync_/test_dbms_awareness.py diff --git a/test/sync/test_driver_options.py b/test/sync_/test_driver_options.py similarity index 100% rename from test/sync/test_driver_options.py rename to test/sync_/test_driver_options.py diff --git a/test/sync/test_exceptions.py b/test/sync_/test_exceptions.py similarity index 100% rename from test/sync/test_exceptions.py rename to test/sync_/test_exceptions.py diff --git a/test/sync/test_hooks.py b/test/sync_/test_hooks.py similarity index 100% rename from test/sync/test_hooks.py rename to test/sync_/test_hooks.py diff --git a/test/sync/test_indexing.py b/test/sync_/test_indexing.py similarity index 100% rename from test/sync/test_indexing.py rename to test/sync_/test_indexing.py diff --git a/test/sync/test_issue112.py b/test/sync_/test_issue112.py similarity index 100% rename from test/sync/test_issue112.py rename to test/sync_/test_issue112.py diff --git a/test/sync/test_issue283.py b/test/sync_/test_issue283.py similarity index 100% rename from test/sync/test_issue283.py rename to test/sync_/test_issue283.py diff --git a/test/sync/test_issue600.py b/test/sync_/test_issue600.py similarity index 100% rename from test/sync/test_issue600.py rename to test/sync_/test_issue600.py diff --git a/test/sync/test_label_drop.py b/test/sync_/test_label_drop.py similarity index 100% rename from test/sync/test_label_drop.py rename to test/sync_/test_label_drop.py diff --git a/test/sync/test_label_install.py b/test/sync_/test_label_install.py similarity index 100% rename from test/sync/test_label_install.py rename to test/sync_/test_label_install.py diff --git a/test/sync/test_match_api.py b/test/sync_/test_match_api.py similarity index 100% rename from test/sync/test_match_api.py rename to test/sync_/test_match_api.py diff --git a/test/sync/test_migration_neo4j_5.py b/test/sync_/test_migration_neo4j_5.py similarity index 100% rename from test/sync/test_migration_neo4j_5.py rename to test/sync_/test_migration_neo4j_5.py diff --git a/test/sync/test_models.py b/test/sync_/test_models.py similarity index 100% rename from test/sync/test_models.py rename to test/sync_/test_models.py diff --git a/test/sync/test_multiprocessing.py b/test/sync_/test_multiprocessing.py similarity index 100% rename from test/sync/test_multiprocessing.py rename to test/sync_/test_multiprocessing.py diff --git a/test/sync/test_paths.py b/test/sync_/test_paths.py similarity index 100% rename from test/sync/test_paths.py rename to test/sync_/test_paths.py diff --git a/test/sync/test_properties.py b/test/sync_/test_properties.py similarity index 100% rename from test/sync/test_properties.py rename to test/sync_/test_properties.py diff --git a/test/sync/test_relationship_models.py b/test/sync_/test_relationship_models.py similarity index 100% rename from test/sync/test_relationship_models.py rename to test/sync_/test_relationship_models.py diff --git a/test/sync/test_relationships.py b/test/sync_/test_relationships.py similarity index 100% rename from test/sync/test_relationships.py rename to test/sync_/test_relationships.py diff --git a/test/sync/test_relative_relationships.py b/test/sync_/test_relative_relationships.py similarity index 100% rename from test/sync/test_relative_relationships.py rename to test/sync_/test_relative_relationships.py From f3a64291e92ef73211ed9266f7ed033593acb0a5 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Wed, 3 Jan 2024 13:30:09 +0100 Subject: [PATCH 28/73] More tests and isort tests --- test/async_/test_cardinality.py | 4 +- test/async_/test_connection.py | 4 +- test/async_/test_database_management.py | 3 +- test/async_/test_dbms_awareness.py | 3 +- test/async_/test_driver_options.py | 3 +- test/async_/test_hooks.py | 1 + test/async_/test_indexing.py | 3 +- test/async_/test_issue112.py | 1 + test/async_/test_issue283.py | 12 +- test/async_/test_issue600.py | 3 +- test/async_/test_label_drop.py | 3 +- test/async_/test_label_install.py | 3 +- test/async_/test_migration_neo4j_5.py | 3 +- test/async_/test_models.py | 2 +- test/async_/test_paths.py | 1 + test/async_/test_properties.py | 2 +- test/async_/test_relationships.py | 3 +- test/async_/test_relative_relationships.py | 5 +- test/async_/test_transactions.py | 196 +++++++++++++++++++++ test/sync_/test_alias.py | 2 +- test/sync_/test_batch.py | 12 +- test/sync_/test_cardinality.py | 20 +-- test/sync_/test_connection.py | 14 +- test/sync_/test_cypher.py | 14 +- test/sync_/test_database_management.py | 7 +- test/sync_/test_dbms_awareness.py | 3 +- test/sync_/test_driver_options.py | 5 +- test/sync_/test_exceptions.py | 2 +- test/sync_/test_hooks.py | 3 +- test/sync_/test_indexing.py | 12 +- test/sync_/test_issue112.py | 1 + test/sync_/test_issue283.py | 122 +++++-------- test/sync_/test_issue600.py | 3 +- test/sync_/test_label_drop.py | 5 +- test/sync_/test_label_install.py | 15 +- test/sync_/test_match_api.py | 27 +-- test/sync_/test_migration_neo4j_5.py | 7 +- test/sync_/test_models.py | 8 +- test/sync_/test_multiprocessing.py | 2 +- test/sync_/test_paths.py | 5 +- test/sync_/test_properties.py | 8 +- test/sync_/test_relationship_models.py | 10 +- test/sync_/test_relationships.py | 9 +- test/sync_/test_relative_relationships.py | 5 +- test/sync_/test_transactions.py | 196 +++++++++++++++++++++ test/test_scripts.py | 2 +- 46 files changed, 555 insertions(+), 219 deletions(-) create mode 100644 test/async_/test_transactions.py create mode 100644 test/sync_/test_transactions.py diff --git a/test/async_/test_cardinality.py b/test/async_/test_cardinality.py index bafe919d..e72fa912 100644 --- a/test/async_/test_cardinality.py +++ b/test/async_/test_cardinality.py @@ -1,4 +1,5 @@ from test._async_compat import mark_async_test + from pytest import raises from neomodel import ( @@ -6,14 +7,13 @@ AsyncOneOrMore, AsyncRelationshipTo, AsyncStructuredNode, + AsyncZeroOrMore, AsyncZeroOrOne, AttemptedCardinalityViolation, CardinalityViolation, IntegerProperty, StringProperty, - AsyncZeroOrMore, ) - from neomodel.async_.core import adb diff --git a/test/async_/test_connection.py b/test/async_/test_connection.py index b82ca0f4..94d9b2dc 100644 --- a/test/async_/test_connection.py +++ b/test/async_/test_connection.py @@ -1,9 +1,9 @@ import os - from test._async_compat import mark_async_test from test.conftest import NEO4J_PASSWORD, NEO4J_URL, NEO4J_USERNAME + import pytest -from neo4j import AsyncGraphDatabase, AsyncDriver +from neo4j import AsyncDriver, AsyncGraphDatabase from neo4j.debug import watch from neomodel import AsyncStructuredNode, StringProperty, config diff --git a/test/async_/test_database_management.py b/test/async_/test_database_management.py index 6d2ace9f..68dcccc9 100644 --- a/test/async_/test_database_management.py +++ b/test/async_/test_database_management.py @@ -1,5 +1,6 @@ -import pytest from test._async_compat import mark_async_test + +import pytest from neo4j.exceptions import AuthError from neomodel import ( diff --git a/test/async_/test_dbms_awareness.py b/test/async_/test_dbms_awareness.py index f9f7b7b2..be9e376b 100644 --- a/test/async_/test_dbms_awareness.py +++ b/test/async_/test_dbms_awareness.py @@ -1,6 +1,7 @@ -from pytest import mark from test._async_compat import mark_async_test +from pytest import mark + from neomodel.async_.core import adb from neomodel.util import version_tag_to_integer diff --git a/test/async_/test_driver_options.py b/test/async_/test_driver_options.py index 64d5b85d..dff17c75 100644 --- a/test/async_/test_driver_options.py +++ b/test/async_/test_driver_options.py @@ -1,5 +1,6 @@ -import pytest from test._async_compat import mark_async_test + +import pytest from neo4j.exceptions import ClientError from pytest import raises diff --git a/test/async_/test_hooks.py b/test/async_/test_hooks.py index 12643e78..09a87403 100644 --- a/test/async_/test_hooks.py +++ b/test/async_/test_hooks.py @@ -1,4 +1,5 @@ from test._async_compat import mark_async_test + from neomodel import AsyncStructuredNode, StringProperty HOOKS_CALLED = {} diff --git a/test/async_/test_indexing.py b/test/async_/test_indexing.py index 6fd51488..2933f045 100644 --- a/test/async_/test_indexing.py +++ b/test/async_/test_indexing.py @@ -1,6 +1,7 @@ +from test._async_compat import mark_async_test + import pytest from pytest import raises -from test._async_compat import mark_async_test from neomodel import ( AsyncStructuredNode, diff --git a/test/async_/test_issue112.py b/test/async_/test_issue112.py index 12940992..8ba1ce03 100644 --- a/test/async_/test_issue112.py +++ b/test/async_/test_issue112.py @@ -1,4 +1,5 @@ from test._async_compat import mark_async_test + from neomodel import AsyncRelationshipTo, AsyncStructuredNode diff --git a/test/async_/test_issue283.py b/test/async_/test_issue283.py index 7444f0a4..1f62d80c 100644 --- a/test/async_/test_issue283.py +++ b/test/async_/test_issue283.py @@ -9,24 +9,24 @@ idea remains the same: "Instantiate the correct type of node at the end of a relationship as specified by the model" """ -from test._async_compat import mark_async_test import random +from test._async_compat import mark_async_test import pytest from neomodel import ( - AsyncStructuredRel, + AsyncRelationshipTo, AsyncStructuredNode, + AsyncStructuredRel, DateTimeProperty, FloatProperty, + RelationshipClassNotDefined, + RelationshipClassRedefined, StringProperty, UniqueIdProperty, - AsyncRelationshipTo, - RelationshipClassRedefined, - RelationshipClassNotDefined, ) from neomodel.async_.core import adb -from neomodel.exceptions import NodeClassNotDefined, NodeClassAlreadyDefined +from neomodel.exceptions import NodeClassAlreadyDefined, NodeClassNotDefined try: basestring diff --git a/test/async_/test_issue600.py b/test/async_/test_issue600.py index a35fb8f6..5f66f39e 100644 --- a/test/async_/test_issue600.py +++ b/test/async_/test_issue600.py @@ -5,7 +5,8 @@ """ from test._async_compat import mark_async_test -from neomodel import AsyncStructuredNode, AsyncRelationship, AsyncStructuredRel + +from neomodel import AsyncRelationship, AsyncStructuredNode, AsyncStructuredRel try: basestring diff --git a/test/async_/test_label_drop.py b/test/async_/test_label_drop.py index 834f47e3..fa4b6106 100644 --- a/test/async_/test_label_drop.py +++ b/test/async_/test_label_drop.py @@ -1,6 +1,7 @@ -from neo4j.exceptions import ClientError from test._async_compat import mark_async_test +from neo4j.exceptions import ClientError + from neomodel import AsyncStructuredNode, StringProperty from neomodel.async_.core import adb diff --git a/test/async_/test_label_install.py b/test/async_/test_label_install.py index 09520b70..3ccac1b7 100644 --- a/test/async_/test_label_install.py +++ b/test/async_/test_label_install.py @@ -1,6 +1,7 @@ -import pytest from test._async_compat import mark_async_test +import pytest + from neomodel import ( AsyncRelationshipTo, AsyncStructuredNode, diff --git a/test/async_/test_migration_neo4j_5.py b/test/async_/test_migration_neo4j_5.py index a836b969..ceda47b7 100644 --- a/test/async_/test_migration_neo4j_5.py +++ b/test/async_/test_migration_neo4j_5.py @@ -1,6 +1,7 @@ -import pytest from test._async_compat import mark_async_test +import pytest + from neomodel import ( AsyncRelationshipTo, AsyncStructuredNode, diff --git a/test/async_/test_models.py b/test/async_/test_models.py index 4a83a311..39b9026b 100644 --- a/test/async_/test_models.py +++ b/test/async_/test_models.py @@ -1,8 +1,8 @@ from __future__ import print_function from datetime import datetime - from test._async_compat import mark_async_test + from pytest import raises from neomodel import ( diff --git a/test/async_/test_paths.py b/test/async_/test_paths.py index ba664203..91120675 100644 --- a/test/async_/test_paths.py +++ b/test/async_/test_paths.py @@ -1,4 +1,5 @@ from test._async_compat import mark_async_test + from neomodel import ( AsyncNeomodelPath, AsyncRelationshipTo, diff --git a/test/async_/test_properties.py b/test/async_/test_properties.py index 65239fac..9156250b 100644 --- a/test/async_/test_properties.py +++ b/test/async_/test_properties.py @@ -1,6 +1,6 @@ from datetime import date, datetime - from test._async_compat import mark_async_test + from pytest import mark, raises from pytz import timezone diff --git a/test/async_/test_relationships.py b/test/async_/test_relationships.py index 40b8649c..f1ae8fdd 100644 --- a/test/async_/test_relationships.py +++ b/test/async_/test_relationships.py @@ -1,6 +1,7 @@ -from pytest import raises from test._async_compat import mark_async_test +from pytest import raises + from neomodel import ( AsyncOne, AsyncRelationship, diff --git a/test/async_/test_relative_relationships.py b/test/async_/test_relative_relationships.py index 5b8dbefa..371be944 100644 --- a/test/async_/test_relative_relationships.py +++ b/test/async_/test_relative_relationships.py @@ -1,6 +1,7 @@ -from neomodel import AsyncRelationshipTo, AsyncStructuredNode, StringProperty -from test.async_.test_relationships import Country from test._async_compat import mark_async_test +from test.async_.test_relationships import Country + +from neomodel import AsyncRelationshipTo, AsyncStructuredNode, StringProperty class Cat(AsyncStructuredNode): diff --git a/test/async_/test_transactions.py b/test/async_/test_transactions.py new file mode 100644 index 00000000..39997d67 --- /dev/null +++ b/test/async_/test_transactions.py @@ -0,0 +1,196 @@ +from test._async_compat import mark_async_test + +import pytest +from neo4j.api import Bookmarks +from neo4j.exceptions import ClientError, TransactionError +from pytest import raises + +from neomodel import AsyncStructuredNode, StringProperty, UniqueProperty +from neomodel.async_.core import adb + + +class APerson(AsyncStructuredNode): + name = StringProperty(unique_index=True) + + +@mark_async_test +async def test_rollback_and_commit_transaction(): + for p in await APerson.nodes.all(): + await p.delete() + + await APerson(name="Roger").save() + + await adb.begin() + await APerson(name="Terry S").save() + await adb.rollback() + + assert len(await APerson.nodes.all()) == 1 + + await adb.begin() + await APerson(name="Terry S").save() + await adb.commit() + + assert len(await APerson.nodes.all()) == 2 + + +@adb.transaction +async def in_a_tx(*names): + for n in names: + await APerson(name=n).save() + + +# TODO : understand how to make @adb.transaction work with async +@mark_async_test +async def test_transaction_decorator(): + await adb.install_labels(APerson) + for p in await APerson.nodes.all(): + await p.delete() + + # should work + await in_a_tx("Roger") + assert True + + # should bail but raise correct error + with raises(UniqueProperty): + await in_a_tx("Jim", "Roger") + + assert "Jim" not in [p.name async for p in await APerson.nodes.all()] + + +@mark_async_test +async def test_transaction_as_a_context(): + with adb.transaction: + await APerson(name="Tim").save() + + assert await APerson.nodes.filter(name="Tim").all() + + with raises(UniqueProperty): + with adb.transaction: + await APerson(name="Tim").save() + + +@mark_async_test +async def test_query_inside_transaction(): + for p in await APerson.nodes.all(): + await p.delete() + + with adb.transaction: + await APerson(name="Alice").save() + await APerson(name="Bob").save() + + assert len([p.name for p in await APerson.nodes.all()]) == 2 + + +@mark_async_test +async def test_read_transaction(): + await APerson(name="Johnny").save() + + with adb.read_transaction: + people = await APerson.nodes.all() + assert people + + with raises(TransactionError): + with adb.read_transaction: + with raises(ClientError) as e: + await APerson(name="Gina").save() + assert e.value.code == "Neo.ClientError.Statement.AccessMode" + + +@mark_async_test +async def test_write_transaction(): + with adb.write_transaction: + await APerson(name="Amelia").save() + + amelia = await APerson.nodes.get(name="Amelia") + assert amelia + + +@mark_async_test +async def double_transaction(): + await adb.begin() + with raises(SystemError, match=r"Transaction in progress"): + await adb.begin() + + await adb.rollback() + + +@adb.transaction.with_bookmark +async def in_a_tx(*names): + for n in names: + await APerson(name=n).save() + + +@mark_async_test +async def test_bookmark_transaction_decorator(): + for p in await APerson.nodes.all(): + await p.delete() + + # should work + result, bookmarks = await in_a_tx("Ruth", bookmarks=None) + assert result is None + assert isinstance(bookmarks, Bookmarks) + + # should bail but raise correct error + with raises(UniqueProperty): + await in_a_tx("Jane", "Ruth") + + assert "Jane" not in [p.name for p in await APerson.nodes.all()] + + +@mark_async_test +async def test_bookmark_transaction_as_a_context(): + with adb.transaction as transaction: + APerson(name="Tanya").save() + assert isinstance(transaction.last_bookmark, Bookmarks) + + assert APerson.nodes.filter(name="Tanya") + + with raises(UniqueProperty): + with adb.transaction as transaction: + APerson(name="Tanya").save() + assert not hasattr(transaction, "last_bookmark") + + +@pytest.fixture +async def spy_on_db_begin(monkeypatch): + spy_calls = [] + original_begin = adb.begin + + def begin_spy(*args, **kwargs): + spy_calls.append((args, kwargs)) + return original_begin(*args, **kwargs) + + monkeypatch.setattr(adb, "begin", begin_spy) + return spy_calls + + +@mark_async_test +async def test_bookmark_passed_in_to_context(spy_on_db_begin): + transaction = adb.transaction + with transaction: + pass + + assert (await spy_on_db_begin)[-1] == ((), {"access_mode": None, "bookmarks": None}) + last_bookmark = transaction.last_bookmark + + transaction.bookmarks = last_bookmark + with transaction: + pass + assert spy_on_db_begin[-1] == ( + (), + {"access_mode": None, "bookmarks": last_bookmark}, + ) + + +@mark_async_test +async def test_query_inside_bookmark_transaction(): + for p in await APerson.nodes.all(): + await p.delete() + + with adb.transaction as transaction: + await APerson(name="Alice").save() + await APerson(name="Bob").save() + + assert len([p.name for p in await APerson.nodes.all()]) == 2 + + assert isinstance(transaction.last_bookmark, Bookmarks) diff --git a/test/sync_/test_alias.py b/test/sync_/test_alias.py index c0d084a2..f266eb82 100644 --- a/test/sync_/test_alias.py +++ b/test/sync_/test_alias.py @@ -1,6 +1,6 @@ from test._async_compat import mark_sync_test -from neomodel import AliasProperty, StructuredNode, StringProperty +from neomodel import AliasProperty, StringProperty, StructuredNode class MagicProperty(AliasProperty): diff --git a/test/sync_/test_batch.py b/test/sync_/test_batch.py index 8d5586e9..ca28626f 100644 --- a/test/sync_/test_batch.py +++ b/test/sync_/test_batch.py @@ -3,11 +3,11 @@ from pytest import raises from neomodel import ( + IntegerProperty, RelationshipFrom, RelationshipTo, - StructuredNode, - IntegerProperty, StringProperty, + StructuredNode, UniqueIdProperty, config, ) @@ -24,15 +24,11 @@ class UniqueUser(StructuredNode): @mark_sync_test def test_unique_id_property_batch(): - users = UniqueUser.create( - {"name": "bob", "age": 2}, {"name": "ben", "age": 3} - ) + users = UniqueUser.create({"name": "bob", "age": 2}, {"name": "ben", "age": 3}) assert users[0].uid != users[1].uid - users = UniqueUser.get_or_create( - {"uid": users[0].uid}, {"name": "bill", "age": 4} - ) + users = UniqueUser.get_or_create({"uid": users[0].uid}, {"name": "bill", "age": 4}) assert users[0].name == "bob" assert users[1].uid diff --git a/test/sync_/test_cardinality.py b/test/sync_/test_cardinality.py index f3c27360..2971fb33 100644 --- a/test/sync_/test_cardinality.py +++ b/test/sync_/test_cardinality.py @@ -1,19 +1,19 @@ from test._async_compat import mark_sync_test + from pytest import raises from neomodel import ( - One, - OneOrMore, - RelationshipTo, - StructuredNode, - ZeroOrOne, AttemptedCardinalityViolation, CardinalityViolation, IntegerProperty, + One, + OneOrMore, + RelationshipTo, StringProperty, + StructuredNode, ZeroOrMore, + ZeroOrOne, ) - from neomodel.sync_.core import db @@ -32,13 +32,9 @@ class Car(StructuredNode): class Monkey(StructuredNode): name = StringProperty() dryers = RelationshipTo("HairDryer", "OWNS_DRYER", cardinality=ZeroOrMore) - driver = RelationshipTo( - "ScrewDriver", "HAS_SCREWDRIVER", cardinality=ZeroOrOne - ) + driver = RelationshipTo("ScrewDriver", "HAS_SCREWDRIVER", cardinality=ZeroOrOne) car = RelationshipTo("Car", "HAS_CAR", cardinality=OneOrMore) - toothbrush = RelationshipTo( - "ToothBrush", "HAS_TOOTHBRUSH", cardinality=One - ) + toothbrush = RelationshipTo("ToothBrush", "HAS_TOOTHBRUSH", cardinality=One) class ToothBrush(StructuredNode): diff --git a/test/sync_/test_connection.py b/test/sync_/test_connection.py index 4f4e57c7..666321a1 100644 --- a/test/sync_/test_connection.py +++ b/test/sync_/test_connection.py @@ -1,12 +1,12 @@ import os - from test._async_compat import mark_sync_test from test.conftest import NEO4J_PASSWORD, NEO4J_URL, NEO4J_USERNAME + import pytest -from neo4j import GraphDatabase, Driver +from neo4j import Driver, GraphDatabase from neo4j.debug import watch -from neomodel import StructuredNode, StringProperty, config +from neomodel import StringProperty, StructuredNode, config from neomodel.sync_.core import db @@ -52,9 +52,7 @@ def test_set_connection_driver_works(): # Test connection using a driver db.set_connection( - driver=GraphDatabase().driver( - NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD) - ) + driver=GraphDatabase().driver(NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) ) assert Pastry(name="Croissant").save() @@ -105,9 +103,7 @@ def test_connect_to_non_default_database(): # driver init db.set_connection( - driver=GraphDatabase().driver( - NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD) - ) + driver=GraphDatabase().driver(NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) ) assert get_current_database_name() == "pastries" diff --git a/test/sync_/test_cypher.py b/test/sync_/test_cypher.py index ac1b026c..46beed7e 100644 --- a/test/sync_/test_cypher.py +++ b/test/sync_/test_cypher.py @@ -6,7 +6,7 @@ from numpy import ndarray from pandas import DataFrame, Series -from neomodel import StructuredNode, StringProperty +from neomodel import StringProperty, StructuredNode from neomodel.sync_.core import db @@ -94,9 +94,7 @@ def test_pandas_integration(): # Test to_dataframe df = to_dataframe( - db.cypher_query( - "MATCH (a:UserPandas) RETURN a.name AS name, a.email AS email" - ) + db.cypher_query("MATCH (a:UserPandas) RETURN a.name AS name, a.email AS email") ) assert isinstance(df, DataFrame) @@ -105,9 +103,7 @@ def test_pandas_integration(): # Also test passing an index and dtype to to_dataframe df = to_dataframe( - db.cypher_query( - "MATCH (a:UserPandas) RETURN a.name AS name, a.email AS email" - ), + db.cypher_query("MATCH (a:UserPandas) RETURN a.name AS name, a.email AS email"), index=df["email"], dtype=str, ) @@ -115,9 +111,7 @@ def test_pandas_integration(): assert df.index.inferred_type == "string" # Next test to_series - series = to_series( - db.cypher_query("MATCH (a:UserPandas) RETURN a.name AS name") - ) + series = to_series(db.cypher_query("MATCH (a:UserPandas) RETURN a.name AS name")) assert isinstance(series, Series) assert series.shape == (2,) diff --git a/test/sync_/test_database_management.py b/test/sync_/test_database_management.py index 9b3f8bf2..82811b93 100644 --- a/test/sync_/test_database_management.py +++ b/test/sync_/test_database_management.py @@ -1,13 +1,14 @@ -import pytest from test._async_compat import mark_sync_test + +import pytest from neo4j.exceptions import AuthError from neomodel import ( + IntegerProperty, RelationshipTo, + StringProperty, StructuredNode, StructuredRel, - IntegerProperty, - StringProperty, ) from neomodel.sync_.core import db diff --git a/test/sync_/test_dbms_awareness.py b/test/sync_/test_dbms_awareness.py index b2776af0..1d815eff 100644 --- a/test/sync_/test_dbms_awareness.py +++ b/test/sync_/test_dbms_awareness.py @@ -1,6 +1,7 @@ -from pytest import mark from test._async_compat import mark_sync_test +from pytest import mark + from neomodel.sync_.core import db from neomodel.util import version_tag_to_integer diff --git a/test/sync_/test_driver_options.py b/test/sync_/test_driver_options.py index cedb1ae0..5e5e12b9 100644 --- a/test/sync_/test_driver_options.py +++ b/test/sync_/test_driver_options.py @@ -1,10 +1,11 @@ -import pytest from test._async_compat import mark_sync_test + +import pytest from neo4j.exceptions import ClientError from pytest import raises -from neomodel.sync_.core import db from neomodel.exceptions import FeatureNotSupported +from neomodel.sync_.core import db @mark_sync_test diff --git a/test/sync_/test_exceptions.py b/test/sync_/test_exceptions.py index a422db87..fe8cfe36 100644 --- a/test/sync_/test_exceptions.py +++ b/test/sync_/test_exceptions.py @@ -1,7 +1,7 @@ import pickle from test._async_compat import mark_sync_test -from neomodel import StructuredNode, DoesNotExist, StringProperty +from neomodel import DoesNotExist, StringProperty, StructuredNode class EPerson(StructuredNode): diff --git a/test/sync_/test_hooks.py b/test/sync_/test_hooks.py index b3cbe864..a6f742e0 100644 --- a/test/sync_/test_hooks.py +++ b/test/sync_/test_hooks.py @@ -1,5 +1,6 @@ from test._async_compat import mark_sync_test -from neomodel import StructuredNode, StringProperty + +from neomodel import StringProperty, StructuredNode HOOKS_CALLED = {} diff --git a/test/sync_/test_indexing.py b/test/sync_/test_indexing.py index 1eda3d21..db1c3256 100644 --- a/test/sync_/test_indexing.py +++ b/test/sync_/test_indexing.py @@ -1,15 +1,11 @@ +from test._async_compat import mark_sync_test + import pytest from pytest import raises -from test._async_compat import mark_sync_test -from neomodel import ( - StructuredNode, - IntegerProperty, - StringProperty, - UniqueProperty, -) -from neomodel.sync_.core import db +from neomodel import IntegerProperty, StringProperty, StructuredNode, UniqueProperty from neomodel.exceptions import ConstraintValidationFailed +from neomodel.sync_.core import db class Human(StructuredNode): diff --git a/test/sync_/test_issue112.py b/test/sync_/test_issue112.py index 26605018..f580f146 100644 --- a/test/sync_/test_issue112.py +++ b/test/sync_/test_issue112.py @@ -1,4 +1,5 @@ from test._async_compat import mark_sync_test + from neomodel import RelationshipTo, StructuredNode diff --git a/test/sync_/test_issue283.py b/test/sync_/test_issue283.py index 652a0e69..6fcc9b99 100644 --- a/test/sync_/test_issue283.py +++ b/test/sync_/test_issue283.py @@ -9,24 +9,24 @@ idea remains the same: "Instantiate the correct type of node at the end of a relationship as specified by the model" """ -from test._async_compat import mark_sync_test import random +from test._async_compat import mark_sync_test import pytest from neomodel import ( - StructuredRel, - StructuredNode, DateTimeProperty, FloatProperty, + RelationshipClassNotDefined, + RelationshipClassRedefined, + RelationshipTo, StringProperty, + StructuredNode, + StructuredRel, UniqueIdProperty, - RelationshipTo, - RelationshipClassRedefined, - RelationshipClassNotDefined, ) +from neomodel.exceptions import NodeClassAlreadyDefined, NodeClassNotDefined from neomodel.sync_.core import db -from neomodel.exceptions import NodeClassNotDefined, NodeClassAlreadyDefined try: basestring @@ -99,17 +99,11 @@ def test_automatic_result_resolution(): """ # Create a few entities - A = ( - TechnicalPerson.get_or_create( - {"name": "Grumpy", "expertise": "Grumpiness"} - ) - )[0] - B = ( - TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"}) - )[0] - C = ( - TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"}) - )[0] + A = (TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"}))[ + 0 + ] + B = (TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"}))[0] + C = (TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"}))[0] # Add connections A.friends_with.connect(B) @@ -135,25 +129,13 @@ def test_recursive_automatic_result_resolution(): # Create a few entities A = ( - TechnicalPerson.get_or_create( - {"name": "Grumpier", "expertise": "Grumpiness"} - ) - )[0] - B = ( - TechnicalPerson.get_or_create( - {"name": "Happier", "expertise": "Grumpiness"} - ) - )[0] - C = ( - TechnicalPerson.get_or_create( - {"name": "Sleepier", "expertise": "Pillows"} - ) - )[0] - D = ( - TechnicalPerson.get_or_create( - {"name": "Sneezier", "expertise": "Pillows"} - ) + TechnicalPerson.get_or_create({"name": "Grumpier", "expertise": "Grumpiness"}) )[0] + B = (TechnicalPerson.get_or_create({"name": "Happier", "expertise": "Grumpiness"}))[ + 0 + ] + C = (TechnicalPerson.get_or_create({"name": "Sleepier", "expertise": "Pillows"}))[0] + D = (TechnicalPerson.get_or_create({"name": "Sneezier", "expertise": "Pillows"}))[0] # Retrieve mixed results, both at the top level and nested L, _ = db.cypher_query( @@ -189,17 +171,11 @@ def test_validation_with_inheritance_from_db(): # Create a few entities # Technical Persons - A = ( - TechnicalPerson.get_or_create( - {"name": "Grumpy", "expertise": "Grumpiness"} - ) - )[0] - B = ( - TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"}) - )[0] - C = ( - TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"}) - )[0] + A = (TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"}))[ + 0 + ] + B = (TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"}))[0] + C = (TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"}))[0] # Pilot Persons D = ( @@ -253,17 +229,11 @@ def test_validation_enforcement_to_db(): # Create a few entities # Technical Persons - A = ( - TechnicalPerson.get_or_create( - {"name": "Grumpy", "expertise": "Grumpiness"} - ) - )[0] - B = ( - TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"}) - )[0] - C = ( - TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"}) - )[0] + A = (TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"}))[ + 0 + ] + B = (TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"}))[0] + C = (TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"}))[0] # Pilot Persons D = ( @@ -313,11 +283,9 @@ class RandomPerson(BasePerson): randomness = FloatProperty(default=random.random) # A Technical Person... - A = ( - TechnicalPerson.get_or_create( - {"name": "Grumpy", "expertise": "Grumpiness"} - ) - )[0] + A = (TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"}))[ + 0 + ] # A Random Person... B = (RandomPerson.get_or_create({"name": "Mad Hatter"}))[0] @@ -329,11 +297,9 @@ class RandomPerson(BasePerson): del db._NODE_CLASS_REGISTRY[frozenset(["RandomPerson", "BasePerson"])] # Now try to instantiate a RandomPerson - A = ( - TechnicalPerson.get_or_create( - {"name": "Grumpy", "expertise": "Grumpiness"} - ) - )[0] + A = (TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"}))[ + 0 + ] with pytest.raises( NodeClassNotDefined, match=r"Node with labels .* does not resolve to any of the known objects.*", @@ -360,15 +326,11 @@ class UltraTechnicalPerson(SuperTechnicalPerson): ultraness = FloatProperty(default=3.1415928) # Create a TechnicalPerson... - A = ( - TechnicalPerson.get_or_create( - {"name": "Grumpy", "expertise": "Grumpiness"} - ) - )[0] + A = (TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"}))[ + 0 + ] # ...that is connected to an UltraTechnicalPerson - F = UltraTechnicalPerson( - name="Chewbaka", expertise="Aarrr wgh ggwaaah" - ).save() + F = UltraTechnicalPerson(name="Chewbaka", expertise="Aarrr wgh ggwaaah").save() A.friends_with.connect(F) # Forget about the UltraTechnicalPerson @@ -386,11 +348,9 @@ class UltraTechnicalPerson(SuperTechnicalPerson): # Recall a TechnicalPerson and enumerate its friends. # One of them is UltraTechnicalPerson which would be returned as a valid # node to a friends_with query but is currently unknown to the node class registry. - A = ( - TechnicalPerson.get_or_create( - {"name": "Grumpy", "expertise": "Grumpiness"} - ) - )[0] + A = (TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"}))[ + 0 + ] with pytest.raises(NodeClassNotDefined): friends = A.friends_with.all() for some_friend in friends: diff --git a/test/sync_/test_issue600.py b/test/sync_/test_issue600.py index d88d8eb0..f6b5a10b 100644 --- a/test/sync_/test_issue600.py +++ b/test/sync_/test_issue600.py @@ -5,7 +5,8 @@ """ from test._async_compat import mark_sync_test -from neomodel import StructuredNode, Relationship, StructuredRel + +from neomodel import Relationship, StructuredNode, StructuredRel try: basestring diff --git a/test/sync_/test_label_drop.py b/test/sync_/test_label_drop.py index 016f72c1..55db5fec 100644 --- a/test/sync_/test_label_drop.py +++ b/test/sync_/test_label_drop.py @@ -1,7 +1,8 @@ -from neo4j.exceptions import ClientError from test._async_compat import mark_sync_test -from neomodel import StructuredNode, StringProperty +from neo4j.exceptions import ClientError + +from neomodel import StringProperty, StructuredNode from neomodel.sync_.core import db diff --git a/test/sync_/test_label_install.py b/test/sync_/test_label_install.py index 22d70348..60309dfc 100644 --- a/test/sync_/test_label_install.py +++ b/test/sync_/test_label_install.py @@ -1,15 +1,16 @@ -import pytest from test._async_compat import mark_sync_test +import pytest + from neomodel import ( RelationshipTo, + StringProperty, StructuredNode, StructuredRel, - StringProperty, UniqueIdProperty, ) -from neomodel.sync_.core import db from neomodel.exceptions import ConstraintValidationFailed, FeatureNotSupported +from neomodel.sync_.core import db class NodeWithIndex(StructuredNode): @@ -117,9 +118,7 @@ def test_install_labels_db_property(capsys): _drop_constraints_for_label_and_property("SomeNotUniqueNode", "id") -@pytest.mark.skipif( - db.version_is_higher_than("5.7"), reason="Not supported before 5.7" -) +@pytest.mark.skipif(db.version_is_higher_than("5.7"), reason="Not supported before 5.7") def test_relationship_unique_index_not_supported(): class UniqueIndexRelationship(StructuredRel): name = StringProperty(unique_index=True) @@ -168,9 +167,7 @@ class NodeWithUniqueIndexRelationship(StructuredNode): rel2 = node1.has_rel.connect(node3, {"name": "rel1"}) -def _drop_constraints_for_label_and_property( - label: str = None, property: str = None -): +def _drop_constraints_for_label_and_property(label: str = None, property: str = None): results, meta = db.cypher_query("SHOW CONSTRAINTS") results_as_dict = [dict(zip(meta, row)) for row in results] constraint_names = [ diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 8fd02eb6..7dd3d489 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -5,22 +5,17 @@ from neomodel import ( INCOMING, - RelationshipFrom, - RelationshipTo, - StructuredNode, - StructuredRel, DateTimeProperty, IntegerProperty, Q, + RelationshipFrom, + RelationshipTo, StringProperty, -) -from neomodel.sync_.match import ( - NodeSet, - QueryBuilder, - Traversal, - Optional, + StructuredNode, + StructuredRel, ) from neomodel.exceptions import MultipleNodesReturned +from neomodel.sync_.match import NodeSet, Optional, QueryBuilder, Traversal class SupplierRel(StructuredRel): @@ -36,9 +31,7 @@ class Supplier(StructuredNode): class Species(StructuredNode): name = StringProperty() - coffees = RelationshipFrom( - "Coffee", "COFFEE SPECIES", model=StructuredRel - ) + coffees = RelationshipFrom("Coffee", "COFFEE SPECIES", model=StructuredRel) class Coffee(StructuredNode): @@ -125,9 +118,7 @@ def test_simple_traverse_with_filter(): tesco = Supplier(name="Sainsburys", delivery_cost=2).save() nescafe.suppliers.connect(tesco) - qb = QueryBuilder( - NodeSet(source=nescafe).suppliers.match(since__lt=datetime.now()) - ) + qb = QueryBuilder(NodeSet(source=nescafe).suppliers.match(since__lt=datetime.now())) results = qb.build_ast()._execute() @@ -489,9 +480,7 @@ def test_traversal_filter_left_hand_statement(): nescafe_gold.suppliers.connect(lidl) lidl_supplier = ( - NodeSet(Coffee.nodes.filter(price=11).suppliers) - .filter(delivery_cost=3) - .all() + NodeSet(Coffee.nodes.filter(price=11).suppliers).filter(delivery_cost=3).all() ) assert lidl in lidl_supplier diff --git a/test/sync_/test_migration_neo4j_5.py b/test/sync_/test_migration_neo4j_5.py index 3198abf8..8bc2680d 100644 --- a/test/sync_/test_migration_neo4j_5.py +++ b/test/sync_/test_migration_neo4j_5.py @@ -1,12 +1,13 @@ -import pytest from test._async_compat import mark_sync_test +import pytest + from neomodel import ( + IntegerProperty, RelationshipTo, + StringProperty, StructuredNode, StructuredRel, - IntegerProperty, - StringProperty, ) from neomodel.sync_.core import db diff --git a/test/sync_/test_models.py b/test/sync_/test_models.py index 13db07fe..dc3ff735 100644 --- a/test/sync_/test_models.py +++ b/test/sync_/test_models.py @@ -1,19 +1,19 @@ from __future__ import print_function from datetime import datetime - from test._async_compat import mark_sync_test + from pytest import raises from neomodel import ( - StructuredNode, - StructuredRel, DateProperty, IntegerProperty, StringProperty, + StructuredNode, + StructuredRel, ) -from neomodel.sync_.core import db from neomodel.exceptions import RequiredProperty, UniqueProperty +from neomodel.sync_.core import db class User(StructuredNode): diff --git a/test/sync_/test_multiprocessing.py b/test/sync_/test_multiprocessing.py index 830c5d6a..861b0af2 100644 --- a/test/sync_/test_multiprocessing.py +++ b/test/sync_/test_multiprocessing.py @@ -1,7 +1,7 @@ from multiprocessing.pool import ThreadPool as Pool from test._async_compat import mark_sync_test -from neomodel import StructuredNode, StringProperty +from neomodel import StringProperty, StructuredNode from neomodel.sync_.core import db diff --git a/test/sync_/test_paths.py b/test/sync_/test_paths.py index 13201439..b8f325f8 100644 --- a/test/sync_/test_paths.py +++ b/test/sync_/test_paths.py @@ -1,11 +1,12 @@ from test._async_compat import mark_sync_test + from neomodel import ( + IntegerProperty, NeomodelPath, RelationshipTo, + StringProperty, StructuredNode, StructuredRel, - IntegerProperty, - StringProperty, UniqueIdProperty, ) from neomodel.sync_.core import db diff --git a/test/sync_/test_properties.py b/test/sync_/test_properties.py index 2d2b3873..581924aa 100644 --- a/test/sync_/test_properties.py +++ b/test/sync_/test_properties.py @@ -1,11 +1,10 @@ from datetime import date, datetime - from test._async_compat import mark_sync_test + from pytest import mark, raises from pytz import timezone from neomodel import StructuredNode -from neomodel.sync_.core import db from neomodel.exceptions import ( DeflateError, InflateError, @@ -25,6 +24,7 @@ StringProperty, UniqueIdProperty, ) +from neomodel.sync_.core import db from neomodel.util import _get_node_properties @@ -247,9 +247,7 @@ class TestDBNamePropertyNode(StructuredNode): assert not "name_" in node_properties assert not hasattr(x, "name") assert hasattr(x, "name_") - assert (TestDBNamePropertyNode.nodes.filter(name_="jim").all())[ - 0 - ].name_ == x.name_ + assert (TestDBNamePropertyNode.nodes.filter(name_="jim").all())[0].name_ == x.name_ assert (TestDBNamePropertyNode.nodes.get(name_="jim")).name_ == x.name_ x.delete() diff --git a/test/sync_/test_relationship_models.py b/test/sync_/test_relationship_models.py index f1ecbe7e..5b2e75d7 100644 --- a/test/sync_/test_relationship_models.py +++ b/test/sync_/test_relationship_models.py @@ -5,13 +5,13 @@ from pytest import raises from neomodel import ( + DateTimeProperty, + DeflateError, Relationship, RelationshipTo, + StringProperty, StructuredNode, StructuredRel, - DateTimeProperty, - DeflateError, - StringProperty, ) HOOKS_CALLED = {"pre_save": 0, "post_save": 0} @@ -70,9 +70,7 @@ def test_direction_connect_with_rel_model(): paul = Badger(name="Paul the badger").save() ian = Stoat(name="Ian the stoat").save() - rel = ian.hates.connect( - paul, {"reason": "thinks paul should bath more often"} - ) + rel = ian.hates.connect(paul, {"reason": "thinks paul should bath more often"}) assert isinstance(rel.since, datetime) assert isinstance(rel, FriendRel) assert rel.reason.startswith("thinks") diff --git a/test/sync_/test_relationships.py b/test/sync_/test_relationships.py index 462c966e..44c6010a 100644 --- a/test/sync_/test_relationships.py +++ b/test/sync_/test_relationships.py @@ -1,16 +1,17 @@ -from pytest import raises from test._async_compat import mark_sync_test +from pytest import raises + from neomodel import ( + IntegerProperty, One, + Q, Relationship, RelationshipFrom, RelationshipTo, + StringProperty, StructuredNode, StructuredRel, - IntegerProperty, - Q, - StringProperty, ) from neomodel.sync_.core import db diff --git a/test/sync_/test_relative_relationships.py b/test/sync_/test_relative_relationships.py index 07a2e6ca..a01e28f9 100644 --- a/test/sync_/test_relative_relationships.py +++ b/test/sync_/test_relative_relationships.py @@ -1,6 +1,7 @@ -from neomodel import RelationshipTo, StructuredNode, StringProperty -from test.sync_.test_relationships import Country from test._async_compat import mark_sync_test +from test.sync_.test_relationships import Country + +from neomodel import RelationshipTo, StringProperty, StructuredNode class Cat(StructuredNode): diff --git a/test/sync_/test_transactions.py b/test/sync_/test_transactions.py new file mode 100644 index 00000000..309637f6 --- /dev/null +++ b/test/sync_/test_transactions.py @@ -0,0 +1,196 @@ +from test._async_compat import mark_sync_test + +import pytest +from neo4j.api import Bookmarks +from neo4j.exceptions import ClientError, TransactionError +from pytest import raises + +from neomodel import StringProperty, StructuredNode, UniqueProperty +from neomodel.sync_.core import db + + +class APerson(StructuredNode): + name = StringProperty(unique_index=True) + + +@mark_sync_test +def test_rollback_and_commit_transaction(): + for p in APerson.nodes.all(): + p.delete() + + APerson(name="Roger").save() + + db.begin() + APerson(name="Terry S").save() + db.rollback() + + assert len(APerson.nodes.all()) == 1 + + db.begin() + APerson(name="Terry S").save() + db.commit() + + assert len(APerson.nodes.all()) == 2 + + +@db.transaction +def in_a_tx(*names): + for n in names: + APerson(name=n).save() + + +# TODO : understand how to make @adb.transaction work with async +@mark_sync_test +def test_transaction_decorator(): + db.install_labels(APerson) + for p in APerson.nodes.all(): + p.delete() + + # should work + in_a_tx("Roger") + assert True + + # should bail but raise correct error + with raises(UniqueProperty): + in_a_tx("Jim", "Roger") + + assert "Jim" not in [p.name for p in APerson.nodes.all()] + + +@mark_sync_test +def test_transaction_as_a_context(): + with db.transaction: + APerson(name="Tim").save() + + assert APerson.nodes.filter(name="Tim").all() + + with raises(UniqueProperty): + with db.transaction: + APerson(name="Tim").save() + + +@mark_sync_test +def test_query_inside_transaction(): + for p in APerson.nodes.all(): + p.delete() + + with db.transaction: + APerson(name="Alice").save() + APerson(name="Bob").save() + + assert len([p.name for p in APerson.nodes.all()]) == 2 + + +@mark_sync_test +def test_read_transaction(): + APerson(name="Johnny").save() + + with db.read_transaction: + people = APerson.nodes.all() + assert people + + with raises(TransactionError): + with db.read_transaction: + with raises(ClientError) as e: + APerson(name="Gina").save() + assert e.value.code == "Neo.ClientError.Statement.AccessMode" + + +@mark_sync_test +def test_write_transaction(): + with db.write_transaction: + APerson(name="Amelia").save() + + amelia = APerson.nodes.get(name="Amelia") + assert amelia + + +@mark_sync_test +def double_transaction(): + db.begin() + with raises(SystemError, match=r"Transaction in progress"): + db.begin() + + db.rollback() + + +@db.transaction.with_bookmark +def in_a_tx(*names): + for n in names: + APerson(name=n).save() + + +@mark_sync_test +def test_bookmark_transaction_decorator(): + for p in APerson.nodes.all(): + p.delete() + + # should work + result, bookmarks = in_a_tx("Ruth", bookmarks=None) + assert result is None + assert isinstance(bookmarks, Bookmarks) + + # should bail but raise correct error + with raises(UniqueProperty): + in_a_tx("Jane", "Ruth") + + assert "Jane" not in [p.name for p in APerson.nodes.all()] + + +@mark_sync_test +def test_bookmark_transaction_as_a_context(): + with db.transaction as transaction: + APerson(name="Tanya").save() + assert isinstance(transaction.last_bookmark, Bookmarks) + + assert APerson.nodes.filter(name="Tanya") + + with raises(UniqueProperty): + with db.transaction as transaction: + APerson(name="Tanya").save() + assert not hasattr(transaction, "last_bookmark") + + +@pytest.fixture +def spy_on_db_begin(monkeypatch): + spy_calls = [] + original_begin = db.begin + + def begin_spy(*args, **kwargs): + spy_calls.append((args, kwargs)) + return original_begin(*args, **kwargs) + + monkeypatch.setattr(db, "begin", begin_spy) + return spy_calls + + +@mark_sync_test +def test_bookmark_passed_in_to_context(spy_on_db_begin): + transaction = db.transaction + with transaction: + pass + + assert (spy_on_db_begin)[-1] == ((), {"access_mode": None, "bookmarks": None}) + last_bookmark = transaction.last_bookmark + + transaction.bookmarks = last_bookmark + with transaction: + pass + assert spy_on_db_begin[-1] == ( + (), + {"access_mode": None, "bookmarks": last_bookmark}, + ) + + +@mark_sync_test +def test_query_inside_bookmark_transaction(): + for p in APerson.nodes.all(): + p.delete() + + with db.transaction as transaction: + APerson(name="Alice").save() + APerson(name="Bob").save() + + assert len([p.name for p in APerson.nodes.all()]) == 2 + + assert isinstance(transaction.last_bookmark, Bookmarks) diff --git a/test/test_scripts.py b/test/test_scripts.py index 9f0690d9..5583af47 100644 --- a/test/test_scripts.py +++ b/test/test_scripts.py @@ -4,9 +4,9 @@ from neomodel import ( RelationshipTo, + StringProperty, StructuredNode, StructuredRel, - StringProperty, config, ) from neomodel.sync_.core import db From c55ad9103954601406043421af4c6e2192d1d174 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Wed, 3 Jan 2024 14:08:31 +0100 Subject: [PATCH 29/73] Fix import error test contrib --- test/test_contrib/test_spatial_properties.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_contrib/test_spatial_properties.py b/test/test_contrib/test_spatial_properties.py index 89b8641a..0bd854e9 100644 --- a/test/test_contrib/test_spatial_properties.py +++ b/test/test_contrib/test_spatial_properties.py @@ -4,7 +4,6 @@ For more information please see: https://github.com/neo4j-contrib/neomodel/issues/374 """ -import os import random import neo4j.spatial @@ -12,7 +11,7 @@ import neomodel import neomodel.contrib.spatial_properties -from neomodel.test_spatial_datatypes import ( +from .test_spatial_datatypes import ( basic_type_assertions, check_and_skip_neo4j_least_version, ) From 681f4f6ba7f776a064deb173d9d15f10d0021b09 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Wed, 3 Jan 2024 17:03:58 +0100 Subject: [PATCH 30/73] Migrate test_contrib to async --- bin/make-unasync | 7 ++++ test/test_contrib/{ => async_}/__init__.py | 0 .../async_/test_semi_structured.py | 35 +++++++++++++++++++ .../{ => async_}/test_spatial_datatypes.py | 0 .../{ => async_}/test_spatial_properties.py | 25 ++++++++----- test/test_contrib/test_semi_structured.py | 30 ---------------- 6 files changed, 58 insertions(+), 39 deletions(-) rename test/test_contrib/{ => async_}/__init__.py (100%) create mode 100644 test/test_contrib/async_/test_semi_structured.py rename test/test_contrib/{ => async_}/test_spatial_datatypes.py (100%) rename test/test_contrib/{ => async_}/test_spatial_properties.py (94%) delete mode 100644 test/test_contrib/test_semi_structured.py diff --git a/bin/make-unasync b/bin/make-unasync index 5acdf5e7..5c8fca0b 100755 --- a/bin/make-unasync +++ b/bin/make-unasync @@ -17,6 +17,8 @@ ASYNC_CONTRIB_DIR = ROOT_DIR / "neomodel" / "contrib" / "async_" SYNC_CONTRIB_DIR = ROOT_DIR / "neomodel" / "contrib" / "sync_" ASYNC_INTEGRATION_TEST_DIR = ROOT_DIR / "test" / "async_" SYNC_INTEGRATION_TEST_DIR = ROOT_DIR / "test" / "sync_" +ASYNC_INTEGRATION_TEST_CONTRIB_DIR = ROOT_DIR / "test" / "test_contrib" / "async_" +SYNC_INTEGRATION_TEST_CONTRIB_DIR = ROOT_DIR / "test" / "test_contrib" / "sync_" UNASYNC_SUFFIX = ".unasync" PY_FILE_EXTENSIONS = {".py"} @@ -234,6 +236,11 @@ def apply_unasync(files): todir=str(SYNC_INTEGRATION_TEST_DIR), additional_replacements=additional_test_replacements, ), + CustomRule( + fromdir=str(ASYNC_INTEGRATION_TEST_CONTRIB_DIR), + todir=str(SYNC_INTEGRATION_TEST_CONTRIB_DIR), + additional_replacements=additional_test_replacements, + ), ] if not files: diff --git a/test/test_contrib/__init__.py b/test/test_contrib/async_/__init__.py similarity index 100% rename from test/test_contrib/__init__.py rename to test/test_contrib/async_/__init__.py diff --git a/test/test_contrib/async_/test_semi_structured.py b/test/test_contrib/async_/test_semi_structured.py new file mode 100644 index 00000000..3b88fcad --- /dev/null +++ b/test/test_contrib/async_/test_semi_structured.py @@ -0,0 +1,35 @@ +from test._async_compat import mark_async_test + +from neomodel import IntegerProperty, StringProperty +from neomodel.contrib import AsyncSemiStructuredNode + + +class UserProf(AsyncSemiStructuredNode): + email = StringProperty(unique_index=True, required=True) + age = IntegerProperty(index=True) + + +class Dummy(AsyncSemiStructuredNode): + pass + + +@mark_async_test +async def test_to_save_to_model_with_required_only(): + u = UserProf(email="dummy@test.com") + assert await u.save() + + +@mark_async_test +async def test_save_to_model_with_extras(): + u = UserProf(email="jim@test.com", age=3, bar=99) + u.foo = True + assert await u.save() + u = await UserProf.nodes.get(age=3) + assert u.foo is True + assert u.bar == 99 + + +@mark_async_test +async def test_save_empty_model(): + dummy = Dummy() + assert await dummy.save() diff --git a/test/test_contrib/test_spatial_datatypes.py b/test/test_contrib/async_/test_spatial_datatypes.py similarity index 100% rename from test/test_contrib/test_spatial_datatypes.py rename to test/test_contrib/async_/test_spatial_datatypes.py diff --git a/test/test_contrib/test_spatial_properties.py b/test/test_contrib/async_/test_spatial_properties.py similarity index 94% rename from test/test_contrib/test_spatial_properties.py rename to test/test_contrib/async_/test_spatial_properties.py index 0bd854e9..a6103639 100644 --- a/test/test_contrib/test_spatial_properties.py +++ b/test/test_contrib/async_/test_spatial_properties.py @@ -5,12 +5,14 @@ """ import random +from test._async_compat import mark_async_test import neo4j.spatial import pytest import neomodel import neomodel.contrib.spatial_properties + from .test_spatial_datatypes import ( basic_type_assertions, check_and_skip_neo4j_least_version, @@ -154,7 +156,8 @@ def test_deflate(): ) -def test_default_value(): +@mark_async_test +async def test_default_value(): """ Tests that the default value passing mechanism works as expected with NeomodelPoint values. :return: @@ -181,10 +184,12 @@ class LocalisableEntity(neomodel.AsyncStructuredNode): ) # Save an object - an_object = LocalisableEntity().save() + an_object = await LocalisableEntity().save() coords = an_object.location.coords[0] # Retrieve it - retrieved_object = LocalisableEntity.nodes.get(identifier=an_object.identifier) + retrieved_object = await LocalisableEntity.nodes.get( + identifier=an_object.identifier + ) # Check against an independently created value assert ( retrieved_object.location @@ -192,7 +197,8 @@ class LocalisableEntity(neomodel.AsyncStructuredNode): ), ("Default value assignment failed.") -def test_array_of_points(): +@mark_async_test +async def test_array_of_points(): """ Tests that Arrays of Points work as expected. @@ -214,14 +220,14 @@ class AnotherLocalisableEntity(neomodel.AsyncStructuredNode): 340, "This version does not support spatial data types." ) - an_object = AnotherLocalisableEntity( + an_object = await AnotherLocalisableEntity( locations=[ neomodel.contrib.spatial_properties.NeomodelPoint((0.0, 0.0)), neomodel.contrib.spatial_properties.NeomodelPoint((1.0, 0.0)), ] ).save() - retrieved_object = AnotherLocalisableEntity.nodes.get( + retrieved_object = await AnotherLocalisableEntity.nodes.get( identifier=an_object.identifier ) @@ -234,7 +240,8 @@ class AnotherLocalisableEntity(neomodel.AsyncStructuredNode): ], "Array of Points incorrect values." -def test_simple_storage_retrieval(): +@mark_async_test +async def test_simple_storage_retrieval(): """ Performs a simple Create, Retrieve via .save(), .get() which, due to the way Q objects operate, tests the __copy__, __deepcopy__ operations of NeomodelPoint. @@ -251,12 +258,12 @@ class TestStorageRetrievalProperty(neomodel.AsyncStructuredNode): 340, "This version does not support spatial data types." ) - a_restaurant = TestStorageRetrievalProperty( + a_restaurant = await TestStorageRetrievalProperty( description="Milliways", location=neomodel.contrib.spatial_properties.NeomodelPoint((0, 0)), ).save() - a_property = TestStorageRetrievalProperty.nodes.get( + a_property = await TestStorageRetrievalProperty.nodes.get( location=neomodel.contrib.spatial_properties.NeomodelPoint((0, 0)) ) diff --git a/test/test_contrib/test_semi_structured.py b/test/test_contrib/test_semi_structured.py deleted file mode 100644 index fe73a2bd..00000000 --- a/test/test_contrib/test_semi_structured.py +++ /dev/null @@ -1,30 +0,0 @@ -from neomodel import IntegerProperty, StringProperty -from neomodel.contrib import SemiStructuredNode - - -class UserProf(SemiStructuredNode): - email = StringProperty(unique_index=True, required=True) - age = IntegerProperty(index=True) - - -class Dummy(SemiStructuredNode): - pass - - -def test_to_save_to_model_with_required_only(): - u = UserProf(email="dummy@test.com") - assert u.save() - - -def test_save_to_model_with_extras(): - u = UserProf(email="jim@test.com", age=3, bar=99) - u.foo = True - assert u.save() - u = UserProf.nodes.get(age=3) - assert u.foo is True - assert u.bar == 99 - - -def test_save_empty_model(): - dummy = Dummy() - assert dummy.save() From a38e7ec8b4b631e8e0628a68785184b3577463e0 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Wed, 3 Jan 2024 17:05:20 +0100 Subject: [PATCH 31/73] Fix make-unasync for test_contrib --- bin/make-unasync | 1 + test/test_contrib/sync_/__init__.py | 0 .../sync_/test_semi_structured.py | 35 ++ .../sync_/test_spatial_datatypes.py | 397 ++++++++++++++++++ .../sync_/test_spatial_properties.py | 289 +++++++++++++ 5 files changed, 722 insertions(+) create mode 100644 test/test_contrib/sync_/__init__.py create mode 100644 test/test_contrib/sync_/test_semi_structured.py create mode 100644 test/test_contrib/sync_/test_spatial_datatypes.py create mode 100644 test/test_contrib/sync_/test_spatial_properties.py diff --git a/bin/make-unasync b/bin/make-unasync index 5c8fca0b..72a66074 100755 --- a/bin/make-unasync +++ b/bin/make-unasync @@ -247,6 +247,7 @@ def apply_unasync(files): paths = list(ASYNC_DIR.rglob("*")) paths += list(ASYNC_CONTRIB_DIR.rglob("*")) paths += list(ASYNC_INTEGRATION_TEST_DIR.rglob("*")) + paths += list(ASYNC_INTEGRATION_TEST_CONTRIB_DIR.rglob("*")) else: paths = [ROOT_DIR / Path(f) for f in files] filtered_paths = [] diff --git a/test/test_contrib/sync_/__init__.py b/test/test_contrib/sync_/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/test_contrib/sync_/test_semi_structured.py b/test/test_contrib/sync_/test_semi_structured.py new file mode 100644 index 00000000..f4b9746b --- /dev/null +++ b/test/test_contrib/sync_/test_semi_structured.py @@ -0,0 +1,35 @@ +from test._async_compat import mark_sync_test + +from neomodel import IntegerProperty, StringProperty +from neomodel.contrib import SemiStructuredNode + + +class UserProf(SemiStructuredNode): + email = StringProperty(unique_index=True, required=True) + age = IntegerProperty(index=True) + + +class Dummy(SemiStructuredNode): + pass + + +@mark_sync_test +def test_to_save_to_model_with_required_only(): + u = UserProf(email="dummy@test.com") + assert u.save() + + +@mark_sync_test +def test_save_to_model_with_extras(): + u = UserProf(email="jim@test.com", age=3, bar=99) + u.foo = True + assert u.save() + u = UserProf.nodes.get(age=3) + assert u.foo is True + assert u.bar == 99 + + +@mark_sync_test +def test_save_empty_model(): + dummy = Dummy() + assert dummy.save() diff --git a/test/test_contrib/sync_/test_spatial_datatypes.py b/test/test_contrib/sync_/test_spatial_datatypes.py new file mode 100644 index 00000000..b35ced3a --- /dev/null +++ b/test/test_contrib/sync_/test_spatial_datatypes.py @@ -0,0 +1,397 @@ +""" +Provides a test case for data types required by issue 374 - "Support for Point property type". + +At the moment, only one new datatype is offered: NeomodelPoint + +For more information please see: https://github.com/neo4j-contrib/neomodel/issues/374 +""" + +import os + +import pytest +import shapely + +import neomodel +import neomodel.contrib.spatial_properties +from neomodel.util import version_tag_to_integer + + +def check_and_skip_neo4j_least_version(required_least_neo4j_version, message): + """ + Checks if the NEO4J_VERSION is at least `required_least_neo4j_version` and skips a test if not. + + WARNING: If the NEO4J_VERSION variable is not set, this function returns True, allowing the test to go ahead. + + :param required_least_neo4j_version: The least version to check. This must be the numberic representation of the + version. That is: '3.4.0' would be passed as 340. + :type required_least_neo4j_version: int + :param message: An informative message as to why the calling test had to be skipped. + :type message: str + :return: A boolean value of True if the version reported is at least `required_least_neo4j_version` + """ + if "NEO4J_VERSION" in os.environ: + if ( + version_tag_to_integer(os.environ["NEO4J_VERSION"]) + < required_least_neo4j_version + ): + pytest.skip( + "Neo4j version: {}. {}." + "Skipping test.".format(os.environ["NEO4J_VERSION"], message) + ) + + +def basic_type_assertions( + ground_truth, tested_object, test_description, check_neo4j_points=False +): + """ + Tests that `tested_object` has been created as intended. + + :param ground_truth: The object as it is supposed to have been created. + :type ground_truth: NeomodelPoint or neo4j.v1.spatial.Point + :param tested_object: The object as it results from one of the contructors. + :type tested_object: NeomodelPoint or neo4j.v1.spatial.Point + :param test_description: A brief description of the test being performed. + :type test_description: str + :param check_neo4j_points: Whether to assert between NeomodelPoint or neo4j.v1.spatial.Point objects. + :type check_neo4j_points: bool + :return: + """ + if check_neo4j_points: + assert isinstance( + tested_object, type(ground_truth) + ), "{} did not return Neo4j Point".format(test_description) + assert ( + tested_object.srid == ground_truth.srid + ), "{} does not have the expected SRID({})".format( + test_description, ground_truth.srid + ) + assert len(tested_object) == len( + ground_truth + ), "Dimensionality mismatch. Expected {}, had {}".format( + len(ground_truth.coords), len(tested_object.coords) + ) + else: + assert isinstance( + tested_object, type(ground_truth) + ), "{} did not return NeomodelPoint".format(test_description) + assert ( + tested_object.crs == ground_truth.crs + ), "{} does not have the expected CRS({})".format( + test_description, ground_truth.crs + ) + assert len(tested_object.coords[0]) == len( + ground_truth.coords[0] + ), "Dimensionality mismatch. Expected {}, had {}".format( + len(ground_truth.coords[0]), len(tested_object.coords[0]) + ) + + +# Object Construction +def test_coord_constructor(): + """ + Tests all the possible ways by which a NeomodelPoint can be instantiated successfully via passing coordinates. + :return: + """ + + # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test. + check_and_skip_neo4j_least_version( + 340, "This version does not support spatial data types." + ) + + # Implicit cartesian point with coords + ground_truth_object = neomodel.contrib.spatial_properties.NeomodelPoint((0.0, 0.0)) + new_point = neomodel.contrib.spatial_properties.NeomodelPoint((0.0, 0.0)) + basic_type_assertions( + ground_truth_object, + new_point, + "Implicit 2d cartesian point instantiation", + ) + + ground_truth_object = neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 0.0, 0.0) + ) + new_point = neomodel.contrib.spatial_properties.NeomodelPoint((0.0, 0.0, 0.0)) + basic_type_assertions( + ground_truth_object, + new_point, + "Implicit 3d cartesian point instantiation", + ) + + # Explicit geographical point with coords + ground_truth_object = neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 0.0), crs="wgs-84" + ) + new_point = neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 0.0), crs="wgs-84" + ) + basic_type_assertions( + ground_truth_object, + new_point, + "Explicit 2d geographical point with tuple of coords instantiation", + ) + + ground_truth_object = neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 0.0, 0.0), crs="wgs-84-3d" + ) + new_point = neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 0.0, 0.0), crs="wgs-84-3d" + ) + basic_type_assertions( + ground_truth_object, + new_point, + "Explicit 3d geographical point with tuple of coords instantiation", + ) + + # Cartesian point with named arguments + ground_truth_object = neomodel.contrib.spatial_properties.NeomodelPoint( + x=0.0, y=0.0 + ) + new_point = neomodel.contrib.spatial_properties.NeomodelPoint(x=0.0, y=0.0) + basic_type_assertions( + ground_truth_object, + new_point, + "Cartesian 2d point with named arguments", + ) + + ground_truth_object = neomodel.contrib.spatial_properties.NeomodelPoint( + x=0.0, y=0.0, z=0.0 + ) + new_point = neomodel.contrib.spatial_properties.NeomodelPoint(x=0.0, y=0.0, z=0.0) + basic_type_assertions( + ground_truth_object, + new_point, + "Cartesian 3d point with named arguments", + ) + + # Geographical point with named arguments + ground_truth_object = neomodel.contrib.spatial_properties.NeomodelPoint( + longitude=0.0, latitude=0.0 + ) + new_point = neomodel.contrib.spatial_properties.NeomodelPoint( + longitude=0.0, latitude=0.0 + ) + basic_type_assertions( + ground_truth_object, + new_point, + "Geographical 2d point with named arguments", + ) + + ground_truth_object = neomodel.contrib.spatial_properties.NeomodelPoint( + longitude=0.0, latitude=0.0, height=0.0 + ) + new_point = neomodel.contrib.spatial_properties.NeomodelPoint( + longitude=0.0, latitude=0.0, height=0.0 + ) + basic_type_assertions( + ground_truth_object, + new_point, + "Geographical 3d point with named arguments", + ) + + +def test_copy_constructors(): + """ + Tests all the possible ways by which a NeomodelPoint can be instantiated successfully via a copy constructor call. + + :return: + """ + + # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test. + check_and_skip_neo4j_least_version( + 340, "This version does not support spatial data types." + ) + + # Instantiate from Shapely point + + # Implicit cartesian from shapely point + ground_truth = neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 0.0), crs="cartesian" + ) + shapely_point = shapely.geometry.Point((0.0, 0.0)) + new_point = neomodel.contrib.spatial_properties.NeomodelPoint(shapely_point) + basic_type_assertions( + ground_truth, new_point, "Implicit cartesian by shapely Point" + ) + + # Explicit geographical by shapely point + ground_truth = neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 0.0, 0.0), crs="wgs-84-3d" + ) + shapely_point = shapely.geometry.Point((0.0, 0.0, 0.0)) + new_point = neomodel.contrib.spatial_properties.NeomodelPoint( + shapely_point, crs="wgs-84-3d" + ) + basic_type_assertions( + ground_truth, new_point, "Explicit geographical by shapely Point" + ) + + # Copy constructor for NeomodelPoints + ground_truth = neomodel.contrib.spatial_properties.NeomodelPoint((0.0, 0.0)) + other_neomodel_point = neomodel.contrib.spatial_properties.NeomodelPoint((0.0, 0.0)) + new_point = neomodel.contrib.spatial_properties.NeomodelPoint(other_neomodel_point) + basic_type_assertions(ground_truth, new_point, "NeomodelPoint copy constructor") + + +def test_prohibited_constructor_forms(): + """ + Tests all the possible forms by which construction of NeomodelPoints should fail. + + :return: + """ + + # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test. + check_and_skip_neo4j_least_version( + 340, "This version does not support spatial data types." + ) + + # Absurd CRS + with pytest.raises(ValueError, match=r"Invalid CRS\(blue_hotel\)"): + _ = neomodel.contrib.spatial_properties.NeomodelPoint((0, 0), crs="blue_hotel") + + # Absurd coord dimensionality + with pytest.raises( + ValueError, + ): + _ = neomodel.contrib.spatial_properties.NeomodelPoint( + (0, 0, 0, 0, 0, 0, 0), crs="cartesian" + ) + + # Absurd datatype passed to copy constructor + with pytest.raises( + TypeError, + ): + _ = neomodel.contrib.spatial_properties.NeomodelPoint( + "it don't mean a thing if it ain't got that swing", + crs="cartesian", + ) + + # Trying to instantiate a point with any of BOTH x,y,z or longitude, latitude, height + with pytest.raises(ValueError, match="Invalid instantiation via arguments"): + _ = neomodel.contrib.spatial_properties.NeomodelPoint( + x=0.0, + y=0.0, + longitude=0.0, + latitude=2.0, + height=-2.0, + crs="cartesian", + ) + + # Trying to instantiate a point with absolutely NO parameters + with pytest.raises(ValueError, match="Invalid instantiation via no arguments"): + _ = neomodel.contrib.spatial_properties.NeomodelPoint() + + +def test_property_accessors_depending_on_crs_shapely_lt_2(): + """ + Tests that points are accessed via their respective accessors. + + :return: + """ + + # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test. + check_and_skip_neo4j_least_version( + 340, "This version does not support spatial data types." + ) + + # Check the version of Shapely installed to run the appropriate tests: + try: + from shapely import __version__ + except ImportError: + pytest.skip("Shapely not installed") + + if int("".join(__version__.split(".")[0:3])) >= 200: + pytest.skip("Shapely 2 is installed, skipping earlier version test") + + # Geometrical points only have x,y,z coordinates + new_point = neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 0.0, 0.0), crs="cartesian-3d" + ) + with pytest.raises(AttributeError, match=r'Invalid coordinate \("longitude"\)'): + new_point.longitude + with pytest.raises(AttributeError, match=r'Invalid coordinate \("latitude"\)'): + new_point.latitude + with pytest.raises(AttributeError, match=r'Invalid coordinate \("height"\)'): + new_point.height + + # Geographical points only have longitude, latitude, height coordinates + new_point = neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 0.0, 0.0), crs="wgs-84-3d" + ) + with pytest.raises(AttributeError, match=r'Invalid coordinate \("x"\)'): + new_point.x + with pytest.raises(AttributeError, match=r'Invalid coordinate \("y"\)'): + new_point.y + with pytest.raises(AttributeError, match=r'Invalid coordinate \("z"\)'): + new_point.z + + +def test_property_accessors_depending_on_crs_shapely_gte_2(): + """ + Tests that points are accessed via their respective accessors. + + :return: + """ + + # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test. + check_and_skip_neo4j_least_version( + 340, "This version does not support spatial data types." + ) + + # Check the version of Shapely installed to run the appropriate tests: + try: + from shapely import __version__ + except ImportError: + pytest.skip("Shapely not installed") + + if int("".join(__version__.split(".")[0:3])) < 200: + pytest.skip("Shapely < 2.0.0 is installed, skipping test") + # Geometrical points only have x,y,z coordinates + new_point = neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 0.0, 0.0), crs="cartesian-3d" + ) + with pytest.raises(TypeError, match=r'Invalid coordinate \("longitude"\)'): + new_point.longitude + with pytest.raises(TypeError, match=r'Invalid coordinate \("latitude"\)'): + new_point.latitude + with pytest.raises(TypeError, match=r'Invalid coordinate \("height"\)'): + new_point.height + + # Geographical points only have longitude, latitude, height coordinates + new_point = neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 0.0, 0.0), crs="wgs-84-3d" + ) + with pytest.raises(TypeError, match=r'Invalid coordinate \("x"\)'): + new_point.x + with pytest.raises(TypeError, match=r'Invalid coordinate \("y"\)'): + new_point.y + with pytest.raises(TypeError, match=r'Invalid coordinate \("z"\)'): + new_point.z + + +def test_property_accessors(): + """ + Tests that points are accessed via their respective accessors and that these accessors return the right values. + + :return: + """ + + # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test. + check_and_skip_neo4j_least_version( + 340, "This version does not support spatial data types." + ) + + # Geometrical points + new_point = neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 1.0, 2.0), crs="cartesian-3d" + ) + assert new_point.x == 0.0, "Expected x coordinate to be 0.0" + assert new_point.y == 1.0, "Expected y coordinate to be 1.0" + assert new_point.z == 2.0, "Expected z coordinate to be 2.0" + + # Geographical points + new_point = neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 1.0, 2.0), crs="wgs-84-3d" + ) + assert new_point.longitude == 0.0, "Expected longitude to be 0.0" + assert new_point.latitude == 1.0, "Expected latitude to be 1.0" + assert new_point.height == 2.0, "Expected height to be 2.0" diff --git a/test/test_contrib/sync_/test_spatial_properties.py b/test/test_contrib/sync_/test_spatial_properties.py new file mode 100644 index 00000000..f33f4fb6 --- /dev/null +++ b/test/test_contrib/sync_/test_spatial_properties.py @@ -0,0 +1,289 @@ +""" +Provides a test case for issue 374 - "Support for Point property type". + +For more information please see: https://github.com/neo4j-contrib/neomodel/issues/374 +""" + +import random +from test._async_compat import mark_sync_test + +import neo4j.spatial +import pytest + +import neomodel +import neomodel.contrib.spatial_properties + +from .test_spatial_datatypes import ( + basic_type_assertions, + check_and_skip_neo4j_least_version, +) + + +def test_spatial_point_property(): + """ + Tests that specific modes of instantiation fail as expected. + + :return: + """ + + # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test. + check_and_skip_neo4j_least_version( + 340, "This version does not support spatial data types." + ) + + with pytest.raises(ValueError, match=r"Invalid CRS\(None\)"): + a_point_property = neomodel.contrib.spatial_properties.PointProperty() + + with pytest.raises(ValueError, match=r"Invalid CRS\(crs_isaak\)"): + a_point_property = neomodel.contrib.spatial_properties.PointProperty( + crs="crs_isaak" + ) + + with pytest.raises(TypeError, match="Invalid default value"): + a_point_property = neomodel.contrib.spatial_properties.PointProperty( + default=(0.0, 0.0), crs="cartesian" + ) + + +def test_inflate(): + """ + Tests that the marshalling from neo4j to neomodel data types works as expected. + + :return: + """ + + # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test. + check_and_skip_neo4j_least_version( + 340, "This version does not support spatial data types." + ) + + # The test is repeatable enough to try and standardise it. The same test is repeated with the assertions in + # `basic_type_assertions` and different messages to be able to localise the exception. + # + # Array of points to inflate and messages when things go wrong + values_from_db = [ + ( + neo4j.spatial.CartesianPoint((0.0, 0.0)), + "Expected Neomodel 2d cartesian point when inflating 2d cartesian neo4j point", + ), + ( + neo4j.spatial.CartesianPoint((0.0, 0.0, 0.0)), + "Expected Neomodel 3d cartesian point when inflating 3d cartesian neo4j point", + ), + ( + neo4j.spatial.WGS84Point((0.0, 0.0)), + "Expected Neomodel 2d geographical point when inflating 2d geographical neo4j point", + ), + ( + neo4j.spatial.WGS84Point((0.0, 0.0, 0.0)), + "Expected Neomodel 3d geographical point inflating 3d geographical neo4j point", + ), + ] + + # Run the above tests + for a_value in values_from_db: + expected_point = neomodel.contrib.spatial_properties.NeomodelPoint( + tuple(a_value[0]), + crs=neomodel.contrib.spatial_properties.SRID_TO_CRS[a_value[0].srid], + ) + inflated_point = neomodel.contrib.spatial_properties.PointProperty( + crs=neomodel.contrib.spatial_properties.SRID_TO_CRS[a_value[0].srid] + ).inflate(a_value[0]) + basic_type_assertions( + expected_point, + inflated_point, + "{}, received {}".format(a_value[1], inflated_point), + ) + + +def test_deflate(): + """ + Tests that the marshalling from neomodel to neo4j data types works as expected + :return: + """ + # Please see inline comments in `test_inflate`. This test function is 90% to that one with very minor differences. + # + + # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test. + check_and_skip_neo4j_least_version( + 340, "This version does not support spatial data types." + ) + + CRS_TO_SRID = dict( + [ + (value, key) + for key, value in neomodel.contrib.spatial_properties.SRID_TO_CRS.items() + ] + ) + # Values to construct and expect during deflation + values_from_neomodel = [ + ( + neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 0.0), crs="cartesian" + ), + "Expected Neo4J 2d cartesian point when deflating Neomodel 2d cartesian point", + ), + ( + neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 0.0, 0.0), crs="cartesian-3d" + ), + "Expected Neo4J 3d cartesian point when deflating Neomodel 3d cartesian point", + ), + ( + neomodel.contrib.spatial_properties.NeomodelPoint((0.0, 0.0), crs="wgs-84"), + "Expected Neo4J 2d geographical point when deflating Neomodel 2d geographical point", + ), + ( + neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 0.0, 0.0), crs="wgs-84-3d" + ), + "Expected Neo4J 3d geographical point when deflating Neomodel 3d geographical point", + ), + ] + + # Run the above tests. + for a_value in values_from_neomodel: + expected_point = neo4j.spatial.Point(tuple(a_value[0].coords[0])) + expected_point.srid = CRS_TO_SRID[a_value[0].crs] + deflated_point = neomodel.contrib.spatial_properties.PointProperty( + crs=a_value[0].crs + ).deflate(a_value[0]) + basic_type_assertions( + expected_point, + deflated_point, + "{}, received {}".format(a_value[1], deflated_point), + check_neo4j_points=True, + ) + + +@mark_sync_test +def test_default_value(): + """ + Tests that the default value passing mechanism works as expected with NeomodelPoint values. + :return: + """ + + def get_some_point(): + return neomodel.contrib.spatial_properties.NeomodelPoint( + (random.random(), random.random()) + ) + + class LocalisableEntity(neomodel.StructuredNode): + """ + A very simple entity to try out the default value assignment. + """ + + identifier = neomodel.UniqueIdProperty() + location = neomodel.contrib.spatial_properties.PointProperty( + crs="cartesian", default=get_some_point + ) + + # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test. + check_and_skip_neo4j_least_version( + 340, "This version does not support spatial data types." + ) + + # Save an object + an_object = LocalisableEntity().save() + coords = an_object.location.coords[0] + # Retrieve it + retrieved_object = LocalisableEntity.nodes.get(identifier=an_object.identifier) + # Check against an independently created value + assert ( + retrieved_object.location + == neomodel.contrib.spatial_properties.NeomodelPoint(coords) + ), ("Default value assignment failed.") + + +@mark_sync_test +def test_array_of_points(): + """ + Tests that Arrays of Points work as expected. + + :return: + """ + + class AnotherLocalisableEntity(neomodel.StructuredNode): + """ + A very simple entity with an array of locations + """ + + identifier = neomodel.UniqueIdProperty() + locations = neomodel.ArrayProperty( + neomodel.contrib.spatial_properties.PointProperty(crs="cartesian") + ) + + # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test. + check_and_skip_neo4j_least_version( + 340, "This version does not support spatial data types." + ) + + an_object = AnotherLocalisableEntity( + locations=[ + neomodel.contrib.spatial_properties.NeomodelPoint((0.0, 0.0)), + neomodel.contrib.spatial_properties.NeomodelPoint((1.0, 0.0)), + ] + ).save() + + retrieved_object = AnotherLocalisableEntity.nodes.get( + identifier=an_object.identifier + ) + + assert ( + type(retrieved_object.locations) is list + ), "Array of Points definition failed." + assert retrieved_object.locations == [ + neomodel.contrib.spatial_properties.NeomodelPoint((0.0, 0.0)), + neomodel.contrib.spatial_properties.NeomodelPoint((1.0, 0.0)), + ], "Array of Points incorrect values." + + +@mark_sync_test +def test_simple_storage_retrieval(): + """ + Performs a simple Create, Retrieve via .save(), .get() which, due to the way Q objects operate, tests the + __copy__, __deepcopy__ operations of NeomodelPoint. + :return: + """ + + class TestStorageRetrievalProperty(neomodel.StructuredNode): + uid = neomodel.UniqueIdProperty() + description = neomodel.StringProperty() + location = neomodel.contrib.spatial_properties.PointProperty(crs="cartesian") + + # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test. + check_and_skip_neo4j_least_version( + 340, "This version does not support spatial data types." + ) + + a_restaurant = TestStorageRetrievalProperty( + description="Milliways", + location=neomodel.contrib.spatial_properties.NeomodelPoint((0, 0)), + ).save() + + a_property = TestStorageRetrievalProperty.nodes.get( + location=neomodel.contrib.spatial_properties.NeomodelPoint((0, 0)) + ) + + assert a_restaurant.description == a_property.description + + +def test_equality_with_other_objects(): + """ + Performs equality tests and ensures tha ``NeomodelPoint`` can be compared with ShapelyPoint and NeomodelPoint only. + """ + try: + import shapely.geometry + from shapely import __version__ + except ImportError: + pytest.skip("Shapely module not present") + + if int("".join(__version__.split(".")[0:3])) < 200: + pytest.skip(f"Shapely 2.0 not present (Current version is {__version__}") + + assert neomodel.contrib.spatial_properties.NeomodelPoint( + (0, 0) + ) == neomodel.contrib.spatial_properties.NeomodelPoint(x=0, y=0) + assert neomodel.contrib.spatial_properties.NeomodelPoint( + (0, 0) + ) == shapely.geometry.Point((0, 0)) From fa6b38f9e80e1f8ed1d3e3d1069453989ba390f5 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Wed, 3 Jan 2024 17:06:59 +0100 Subject: [PATCH 32/73] Update pre-commit-config --- .pre-commit-config.yaml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 88cf1dbe..dd58a3c9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,4 +9,8 @@ repos: name: unasync entry: bin/make-unasync language: system - files: "^(neomodel/async_|test/async_)/.*" \ No newline at end of file + files: "^(neomodel/async_|test/async_)/.*" + - repo: https://github.com/psf/black + rev: 23.3.0 + hooks: + - id: black \ No newline at end of file From 6ac80941fdc4c4e59d834f3d749e1bd1d6f8b6e8 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Wed, 3 Jan 2024 17:16:57 +0100 Subject: [PATCH 33/73] Move test_contrib --- bin/make-unasync | 16 ++++++++-------- .../async_ => async_/test_contrib}/__init__.py | 0 .../test_contrib}/test_semi_structured.py | 0 .../test_contrib}/test_spatial_datatypes.py | 0 .../test_contrib}/test_spatial_properties.py | 0 .../sync_ => sync_/test_contrib}/__init__.py | 0 .../test_contrib}/test_semi_structured.py | 0 .../test_contrib}/test_spatial_datatypes.py | 0 .../test_contrib}/test_spatial_properties.py | 0 9 files changed, 8 insertions(+), 8 deletions(-) rename test/{test_contrib/async_ => async_/test_contrib}/__init__.py (100%) rename test/{test_contrib/async_ => async_/test_contrib}/test_semi_structured.py (100%) rename test/{test_contrib/async_ => async_/test_contrib}/test_spatial_datatypes.py (100%) rename test/{test_contrib/async_ => async_/test_contrib}/test_spatial_properties.py (100%) rename test/{test_contrib/sync_ => sync_/test_contrib}/__init__.py (100%) rename test/{test_contrib/sync_ => sync_/test_contrib}/test_semi_structured.py (100%) rename test/{test_contrib/sync_ => sync_/test_contrib}/test_spatial_datatypes.py (100%) rename test/{test_contrib/sync_ => sync_/test_contrib}/test_spatial_properties.py (100%) diff --git a/bin/make-unasync b/bin/make-unasync index 72a66074..734efc7c 100755 --- a/bin/make-unasync +++ b/bin/make-unasync @@ -17,8 +17,8 @@ ASYNC_CONTRIB_DIR = ROOT_DIR / "neomodel" / "contrib" / "async_" SYNC_CONTRIB_DIR = ROOT_DIR / "neomodel" / "contrib" / "sync_" ASYNC_INTEGRATION_TEST_DIR = ROOT_DIR / "test" / "async_" SYNC_INTEGRATION_TEST_DIR = ROOT_DIR / "test" / "sync_" -ASYNC_INTEGRATION_TEST_CONTRIB_DIR = ROOT_DIR / "test" / "test_contrib" / "async_" -SYNC_INTEGRATION_TEST_CONTRIB_DIR = ROOT_DIR / "test" / "test_contrib" / "sync_" +# ASYNC_INTEGRATION_TEST_CONTRIB_DIR = ROOT_DIR / "test" / "test_contrib" / "async_" +# SYNC_INTEGRATION_TEST_CONTRIB_DIR = ROOT_DIR / "test" / "test_contrib" / "sync_" UNASYNC_SUFFIX = ".unasync" PY_FILE_EXTENSIONS = {".py"} @@ -236,18 +236,18 @@ def apply_unasync(files): todir=str(SYNC_INTEGRATION_TEST_DIR), additional_replacements=additional_test_replacements, ), - CustomRule( - fromdir=str(ASYNC_INTEGRATION_TEST_CONTRIB_DIR), - todir=str(SYNC_INTEGRATION_TEST_CONTRIB_DIR), - additional_replacements=additional_test_replacements, - ), + # CustomRule( + # fromdir=str(ASYNC_INTEGRATION_TEST_CONTRIB_DIR), + # todir=str(SYNC_INTEGRATION_TEST_CONTRIB_DIR), + # additional_replacements=additional_test_replacements, + # ), ] if not files: paths = list(ASYNC_DIR.rglob("*")) paths += list(ASYNC_CONTRIB_DIR.rglob("*")) paths += list(ASYNC_INTEGRATION_TEST_DIR.rglob("*")) - paths += list(ASYNC_INTEGRATION_TEST_CONTRIB_DIR.rglob("*")) + # paths += list(ASYNC_INTEGRATION_TEST_CONTRIB_DIR.rglob("*")) else: paths = [ROOT_DIR / Path(f) for f in files] filtered_paths = [] diff --git a/test/test_contrib/async_/__init__.py b/test/async_/test_contrib/__init__.py similarity index 100% rename from test/test_contrib/async_/__init__.py rename to test/async_/test_contrib/__init__.py diff --git a/test/test_contrib/async_/test_semi_structured.py b/test/async_/test_contrib/test_semi_structured.py similarity index 100% rename from test/test_contrib/async_/test_semi_structured.py rename to test/async_/test_contrib/test_semi_structured.py diff --git a/test/test_contrib/async_/test_spatial_datatypes.py b/test/async_/test_contrib/test_spatial_datatypes.py similarity index 100% rename from test/test_contrib/async_/test_spatial_datatypes.py rename to test/async_/test_contrib/test_spatial_datatypes.py diff --git a/test/test_contrib/async_/test_spatial_properties.py b/test/async_/test_contrib/test_spatial_properties.py similarity index 100% rename from test/test_contrib/async_/test_spatial_properties.py rename to test/async_/test_contrib/test_spatial_properties.py diff --git a/test/test_contrib/sync_/__init__.py b/test/sync_/test_contrib/__init__.py similarity index 100% rename from test/test_contrib/sync_/__init__.py rename to test/sync_/test_contrib/__init__.py diff --git a/test/test_contrib/sync_/test_semi_structured.py b/test/sync_/test_contrib/test_semi_structured.py similarity index 100% rename from test/test_contrib/sync_/test_semi_structured.py rename to test/sync_/test_contrib/test_semi_structured.py diff --git a/test/test_contrib/sync_/test_spatial_datatypes.py b/test/sync_/test_contrib/test_spatial_datatypes.py similarity index 100% rename from test/test_contrib/sync_/test_spatial_datatypes.py rename to test/sync_/test_contrib/test_spatial_datatypes.py diff --git a/test/test_contrib/sync_/test_spatial_properties.py b/test/sync_/test_contrib/test_spatial_properties.py similarity index 100% rename from test/test_contrib/sync_/test_spatial_properties.py rename to test/sync_/test_contrib/test_spatial_properties.py From 8787a51cffb4c1b4f536b461c5a4233d4379c60c Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Thu, 4 Jan 2024 09:33:54 +0100 Subject: [PATCH 34/73] Remove commented out code --- bin/make-unasync | 8 -------- 1 file changed, 8 deletions(-) diff --git a/bin/make-unasync b/bin/make-unasync index 734efc7c..5acdf5e7 100755 --- a/bin/make-unasync +++ b/bin/make-unasync @@ -17,8 +17,6 @@ ASYNC_CONTRIB_DIR = ROOT_DIR / "neomodel" / "contrib" / "async_" SYNC_CONTRIB_DIR = ROOT_DIR / "neomodel" / "contrib" / "sync_" ASYNC_INTEGRATION_TEST_DIR = ROOT_DIR / "test" / "async_" SYNC_INTEGRATION_TEST_DIR = ROOT_DIR / "test" / "sync_" -# ASYNC_INTEGRATION_TEST_CONTRIB_DIR = ROOT_DIR / "test" / "test_contrib" / "async_" -# SYNC_INTEGRATION_TEST_CONTRIB_DIR = ROOT_DIR / "test" / "test_contrib" / "sync_" UNASYNC_SUFFIX = ".unasync" PY_FILE_EXTENSIONS = {".py"} @@ -236,18 +234,12 @@ def apply_unasync(files): todir=str(SYNC_INTEGRATION_TEST_DIR), additional_replacements=additional_test_replacements, ), - # CustomRule( - # fromdir=str(ASYNC_INTEGRATION_TEST_CONTRIB_DIR), - # todir=str(SYNC_INTEGRATION_TEST_CONTRIB_DIR), - # additional_replacements=additional_test_replacements, - # ), ] if not files: paths = list(ASYNC_DIR.rglob("*")) paths += list(ASYNC_CONTRIB_DIR.rglob("*")) paths += list(ASYNC_INTEGRATION_TEST_DIR.rglob("*")) - # paths += list(ASYNC_INTEGRATION_TEST_CONTRIB_DIR.rglob("*")) else: paths = [ROOT_DIR / Path(f) for f in files] filtered_paths = [] From 94b91709ee3d2818dd7447278db692af690a40e6 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Thu, 4 Jan 2024 11:44:08 +0100 Subject: [PATCH 35/73] Update doc for async --- doc/source/configuration.rst | 16 ++++----- doc/source/cypher.rst | 6 ++-- doc/source/extending.rst | 2 +- doc/source/getting_started.rst | 36 +++++++++++++++++-- doc/source/index.rst | 10 ++++++ doc/source/module_documentation.rst | 6 ++-- doc/source/module_documentation_sync.rst | 6 ++-- doc/source/queries.rst | 16 +++++++++ doc/source/transactions.rst | 6 ++-- neomodel/scripts/neomodel_inspect_database.py | 2 +- neomodel/scripts/neomodel_install_labels.py | 2 ++ neomodel/scripts/neomodel_remove_labels.py | 2 ++ 12 files changed, 84 insertions(+), 26 deletions(-) diff --git a/doc/source/configuration.rst b/doc/source/configuration.rst index ed1af11b..ba590be2 100644 --- a/doc/source/configuration.rst +++ b/doc/source/configuration.rst @@ -66,7 +66,7 @@ Change/Close the connection Optionally, you can change the connection at any time by calling ``set_connection``:: - from neomodel import db + from neomodel.sync_.core import db # Using URL - auto-managed db.set_connection(url='bolt://neo4j:neo4j@localhost:7687') @@ -78,7 +78,7 @@ The new connection url will be applied to the current thread or process. Since Neo4j version 5, driver auto-close is deprecated. Make sure to close the connection anytime you want to replace it, as well as at the end of your application's lifecycle by calling ``close_connection``:: - from neomodel import db + from neomodel.sync_.core import db db.close_connection() # If you then want a new connection @@ -119,14 +119,7 @@ with something like: :: Enable automatic index and constraint creation ---------------------------------------------- -After the definition of a `StructuredNode`, Neomodel can install the corresponding -constraints and indexes at compile time. However this method is only recommended for testing:: - - from neomodel import config - # before loading your node definitions - config.AUTO_INSTALL_LABELS = True - -Neomodel also provides the :ref:`neomodel_install_labels` script for this task, +Neomodel provides the :ref:`neomodel_install_labels` script for this task, however if you want to handle this manually see below. Install indexes and constraints for a single class:: @@ -146,6 +139,9 @@ Or for an entire 'schema' :: # + Creating unique constraint for name on label User for class yourapp.models.User # ... +.. note:: + config.AUTO_INSTALL_LABELS has been removed from neomodel in version 6.0 + Require timezones on DateTimeProperty ------------------------------------- diff --git a/doc/source/cypher.rst b/doc/source/cypher.rst index f8c7ccaf..ed8f422c 100644 --- a/doc/source/cypher.rst +++ b/doc/source/cypher.rst @@ -19,7 +19,7 @@ Stand alone Outside of a `StructuredNode`:: # for standalone queries - from neomodel import db + from neomodel.sync_.core import db results, meta = db.cypher_query(query, params, resolve_objects=True) The ``resolve_objects`` parameter automatically inflates the returned nodes to their defined classes (this is turned **off** by default). See :ref:`automatic_class_resolution` for details and possible pitfalls. @@ -40,7 +40,7 @@ First, you need to install pandas by yourself. We do not include it by default t You can use the `pandas` integration to return a `DataFrame` or `Series` object:: - from neomodel import db + from neomodel.sync_.core import db from neomodel.integration.pandas import to_dataframe, to_series df = to_dataframe(db.cypher_query("MATCH (a:Person) RETURN a.name AS name, a.born AS born")) @@ -59,7 +59,7 @@ First, you need to install numpy by yourself. We do not include it by default to You can use the `numpy` integration to return a `ndarray` object:: - from neomodel import db + from neomodel.sync_.core import db from neomodel.integration.numpy import to_ndarray array = to_ndarray(db.cypher_query("MATCH (a:Person) RETURN a.name AS name, a.born AS born")) diff --git a/doc/source/extending.rst b/doc/source/extending.rst index 32067009..df50f84d 100644 --- a/doc/source/extending.rst +++ b/doc/source/extending.rst @@ -39,7 +39,7 @@ labels, the `__optional_labels__` property must be defined as a list of strings: __optional_labels__ = ["SuperSaver", "SeniorDiscount"] balance = IntegerProperty(index=True) -.. warning:: The size of the node class mapping grows exponentially with optional labels. Use with some caution. +.. note:: The size of the node class mapping grows exponentially with optional labels. Use with some caution. Mixins diff --git a/doc/source/getting_started.rst b/doc/source/getting_started.rst index 25e999e7..6aa8a421 100644 --- a/doc/source/getting_started.rst +++ b/doc/source/getting_started.rst @@ -22,6 +22,7 @@ Querying the graph neomodel is mainly used as an OGM (see next section), but you can also use it for direct Cypher queries : :: + from neomodel.sync_.core import db results, meta = db.cypher_query("RETURN 'Hello World' as message") @@ -104,7 +105,7 @@ and cardinality will be default (ZeroOrMore). Finally, relationship cardinality is guessed from the database by looking at existing relationships, so it might guess wrong on edge cases. -.. warning:: +.. note:: The script relies on the method apoc.meta.cypher.types to parse property types. So APOC must be installed on your Neo4j server for this script to work. @@ -246,7 +247,7 @@ the following syntax:: Person.nodes.all().fetch_relations('city__country', Optional('country')) -.. warning:: +.. note:: This feature is still a work in progress for extending path traversal and fecthing. It currently stops at returning the resolved objects as they are returned in Cypher. @@ -256,3 +257,34 @@ the following syntax:: If you want to go further in the resolution process, you have to develop your own parser (for now). + +Async neomodel +============== + +neomodel supports asynchronous operations using the async support of neo4j driver. The examples below take a few of the above examples, +but rewritten for async:: + + from neomodel.async_.core import adb + results, meta = await adb.cypher_query("RETURN 'Hello World' as message") + +OGM with async :: + + # Note that properties do not change, but nodes and relationships now have an Async prefix + from neomodel import (AsyncStructuredNode, StringProperty, IntegerProperty, + UniqueIdProperty, AsyncRelationshipTo) + + class Country(AsyncStructuredNode): + code = StringProperty(unique_index=True, required=True) + + class City(AsyncStructuredNode): + name = StringProperty(required=True) + country = AsyncRelationshipTo(Country, 'FROM_COUNTRY') + + # Operations that interact with the database are now async + # Return all nodes + all_nodes = await Country.nodes.all() + + # Relationships + germany = await Country(code='DE').save() + await jim.country.connect(germany) + diff --git a/doc/source/index.rst b/doc/source/index.rst index 397e838f..b4e6f1f7 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -9,6 +9,7 @@ An Object Graph Mapper (OGM) for the Neo4j_ graph database, built on the awesome - Enforce your schema through cardinality restrictions. - Full transaction support. - Thread safe. +- Async support. - pre/post save/delete hooks. - Django integration via django_neomodel_ @@ -40,6 +41,15 @@ To install from github:: $ pip install git+git://github.com/neo4j-contrib/neomodel.git@HEAD#egg=neomodel-dev +.. note:: + + **Breaking changes in 6.0** + + Introducing support for asynchronous programming to neomodel required to introduce some breaking changes: + + - Replace `from neomodel import db` with `from neomodel.sync_.core import db` + - config.AUTO_INSTALL_LABELS has been removed. Please use the `neomodel_install_labels` (:ref:`neomodel_install_labels`) command instead. + Contents ======== diff --git a/doc/source/module_documentation.rst b/doc/source/module_documentation.rst index 937060a5..364e207e 100644 --- a/doc/source/module_documentation.rst +++ b/doc/source/module_documentation.rst @@ -1,6 +1,6 @@ -========================== -Async/sync independent API -========================== +============ +General API +============ Properties ========== diff --git a/doc/source/module_documentation_sync.rst b/doc/source/module_documentation_sync.rst index 9eb642fe..5214a89a 100644 --- a/doc/source/module_documentation_sync.rst +++ b/doc/source/module_documentation_sync.rst @@ -1,6 +1,6 @@ -================= -API Documentation -================= +====================== +Sync API Documentation +====================== Core ==== diff --git a/doc/source/queries.rst b/doc/source/queries.rst index c3e37629..4c77a791 100644 --- a/doc/source/queries.rst +++ b/doc/source/queries.rst @@ -240,3 +240,19 @@ relationships to their relationship models *if such a model exists*. In other wo relationships with data (such as ``PersonLivesInCity`` above) will be instantiated to their respective objects or ``StrucuredRel`` otherwise. Relationships do not "reload" their end-points (unless this is required). + +Async neomodel - Caveats +======================== + +Python does not support async dunder methods. This means that we had to implement some overrides for those. +See the example below:: + + # This will not work as it uses the synchronous __bool__ method + assert await Customer.nodes.filter(prop="value") + + # Do this instead + assert await Customer.nodes.filter(prop="value").check_bool() + assert await Customer.nodes.filter(prop="value").check_nonzero() + + # Note : no changes are needed for sync so this still works : + assert Customer.nodes.filter(prop="value") diff --git a/doc/source/transactions.rst b/doc/source/transactions.rst index dfa97ee6..1ca80b08 100644 --- a/doc/source/transactions.rst +++ b/doc/source/transactions.rst @@ -13,7 +13,7 @@ Basic usage Transactions can be used via a context manager:: - from neomodel import db + from neomodel.sync_.core import db with db.transaction: Person(name='Bob').save() @@ -171,7 +171,7 @@ Impersonation Impersonation (`see Neo4j driver documentation ``) can be enabled via a context manager:: - from neomodel import db + from neomodel.sync_.core import db with db.impersonate(user="writeuser"): Person(name='Bob').save() @@ -186,7 +186,7 @@ or as a function decorator:: This can be mixed with other context manager like transactions:: - from neomodel import db + from neomodel.sync_.core import db @db.impersonate(user="tempuser") # Both transactions will be run as the same impersonated user diff --git a/neomodel/scripts/neomodel_inspect_database.py b/neomodel/scripts/neomodel_inspect_database.py index f5254e53..3147ebdf 100644 --- a/neomodel/scripts/neomodel_inspect_database.py +++ b/neomodel/scripts/neomodel_inspect_database.py @@ -1,7 +1,7 @@ """ .. _neomodel_inspect_database: -``_neomodel_inspect_database`` +``neomodel_inspect_database`` --------------------------- :: diff --git a/neomodel/scripts/neomodel_install_labels.py b/neomodel/scripts/neomodel_install_labels.py index 8e553396..8aa7a73b 100755 --- a/neomodel/scripts/neomodel_install_labels.py +++ b/neomodel/scripts/neomodel_install_labels.py @@ -14,6 +14,8 @@ If a connection URL is not specified, the tool will look up the environment variable NEO4J_BOLT_URL. If that environment variable is not set, the tool will attempt to connect to the default URL bolt://neo4j:neo4j@localhost:7687 + + Note : this script only has a synchronous mode. positional arguments: diff --git a/neomodel/scripts/neomodel_remove_labels.py b/neomodel/scripts/neomodel_remove_labels.py index 14199b0b..79e79390 100755 --- a/neomodel/scripts/neomodel_remove_labels.py +++ b/neomodel/scripts/neomodel_remove_labels.py @@ -14,6 +14,8 @@ If a connection URL is not specified, the tool will look up the environment variable NEO4J_BOLT_URL. If that environment variable is not set, the tool will attempt to connect to the default URL bolt://neo4j:neo4j@localhost:7687 + + Note : this script only has a synchronous mode. options: -h, --help show this help message and exit From 3b76299c5786fa8d83f91836c6cbdaa113794d0a Mon Sep 17 00:00:00 2001 From: Jake Rosenfeld Date: Sat, 6 Jan 2024 14:03:48 +0100 Subject: [PATCH 36/73] centralized inflate function for nodes and relationships --- neomodel/contrib/semi_structured.py | 6 ++-- neomodel/core.py | 22 ++++-------- neomodel/properties.py | 33 ++++++++++++++++- neomodel/relationship.py | 10 +----- neomodel/relationship_manager.py | 4 +-- neomodel/util.py | 8 +++-- test/test_properties.py | 56 ++++++++++++++++++++++++----- 7 files changed, 97 insertions(+), 42 deletions(-) diff --git a/neomodel/contrib/semi_structured.py b/neomodel/contrib/semi_structured.py index 9c719983..ec803b05 100644 --- a/neomodel/contrib/semi_structured.py +++ b/neomodel/contrib/semi_structured.py @@ -1,6 +1,6 @@ +from neomodel import get_graph_entity_properties from neomodel.core import StructuredNode from neomodel.exceptions import DeflateConflict, InflateConflict -from neomodel.util import _get_node_properties class SemiStructuredNode(StructuredNode): @@ -33,7 +33,7 @@ def inflate(cls, node): props = {} node_properties = {} for key, prop in cls.__all_properties__: - node_properties = _get_node_properties(node) + node_properties = get_graph_entity_properties(node) if key in node_properties: props[key] = prop.inflate(node_properties[key], node) elif prop.has_default: @@ -57,7 +57,7 @@ def inflate(cls, node): def deflate(cls, node_props, obj=None, skip_empty=False): deflated = super().deflate(node_props, obj, skip_empty=skip_empty) for key in [k for k in node_props if k not in deflated]: - if hasattr(cls, key) and (getattr(cls,key).required or not skip_empty): + if hasattr(cls, key) and (getattr(cls, key).required or not skip_empty): raise DeflateConflict(cls, key, deflated[key], obj.element_id) node_props.update(deflated) diff --git a/neomodel/core.py b/neomodel/core.py index 415a97af..a677e153 100644 --- a/neomodel/core.py +++ b/neomodel/core.py @@ -12,7 +12,12 @@ ) from neomodel.hooks import hooks from neomodel.properties import Property, PropertyManager -from neomodel.util import Database, _get_node_properties, _UnsavedNode, classproperty +from neomodel.util import ( + Database, + _UnsavedNode, + classproperty, + get_graph_entity_properties, +) db = Database() @@ -666,20 +671,7 @@ def inflate(cls, node): snode = cls() snode.element_id_property = node else: - node_properties = _get_node_properties(node) - props = {} - for key, prop in cls.__all_properties__: - # map property name from database to object property - db_property = prop.db_property or key - - if db_property in node_properties: - props[key] = prop.inflate(node_properties[db_property], node) - elif prop.has_default: - props[key] = prop.default_value() - else: - props[key] = None - - snode = cls(**props) + snode = super().inflate(node) snode.element_id_property = node.element_id return snode diff --git a/neomodel/properties.py b/neomodel/properties.py index e28b4ead..6d0df28c 100644 --- a/neomodel/properties.py +++ b/neomodel/properties.py @@ -85,7 +85,15 @@ def __properties__(self): @classmethod def deflate(cls, properties, obj=None, skip_empty=False): - # deflate dict ready to be stored + """ + Deflate the properties of a PropertyManager subclass (a user-defined StructuredNode or StructuredRel) so that it + can be put into a neo4j.graph.Entity (a neo4j.graph.Node or neo4j.graph.Relationship) for storage. properties + can be constructed manually, or fetched from a PropertyManager subclass using __properties__. + + Includes mapping from python class attribute name -> database property name (see Property.db_property). + + Ignores any properties that are not defined as python attributes in the class definition. + """ deflated = {} for name, property in cls.defined_properties(aliases=False, rels=False).items(): db_property = property.db_property or name @@ -99,6 +107,29 @@ def deflate(cls, properties, obj=None, skip_empty=False): deflated[db_property] = None return deflated + @classmethod + def inflate(cls, graph_entity): + """ + Inflate the properties of a neo4j.graph.Entity (a neo4j.graph.Node or neo4j.graph.Relationship) into an instance + of cls. + + Includes mapping from database property name (see Property.db_property) -> python class attribute name. + + Ignores any properties that are not defined as python attributes in the class definition. + """ + inflated = {} + for name, property in cls.defined_properties(aliases=False, rels=False).items(): + db_property = property.db_property or name + if db_property in graph_entity: + inflated[name] = property.inflate( + graph_entity[db_property], graph_entity + ) + elif property.has_default: + inflated[name] = property.default_value() + else: + inflated[name] = None + return cls(**inflated) + @classmethod def defined_properties(cls, aliases=True, properties=True, rels=True): from .relationship_manager import RelationshipDefinition diff --git a/neomodel/relationship.py b/neomodel/relationship.py index 8df56c47..96d9f556 100644 --- a/neomodel/relationship.py +++ b/neomodel/relationship.py @@ -158,15 +158,7 @@ def inflate(cls, rel): :param rel: :return: StructuredRel """ - props = {} - for key, prop in cls.defined_properties(aliases=False, rels=False).items(): - if key in rel: - props[key] = prop.inflate(rel[key], obj=rel) - elif prop.has_default: - props[key] = prop.default_value() - else: - props[key] = None - srel = cls(**props) + srel = super().inflate(rel) srel._start_node_element_id_property = rel.start_node.element_id srel._end_node_element_id_property = rel.end_node.element_id srel.element_id_property = rel.element_id diff --git a/neomodel/relationship_manager.py b/neomodel/relationship_manager.py index 1e9cf79e..3b9645a4 100644 --- a/neomodel/relationship_manager.py +++ b/neomodel/relationship_manager.py @@ -15,7 +15,7 @@ _rel_merge_helper, ) from .relationship import StructuredRel -from .util import _get_node_properties, enumerate_traceback +from .util import enumerate_traceback, get_graph_entity_properties # basestring python 3.x fallback try: @@ -231,7 +231,7 @@ def reconnect(self, old_node, new_node): {"old": old_node.element_id}, ) if result: - node_properties = _get_node_properties(result[0][0]) + node_properties = get_graph_entity_properties(result[0][0]) existing_properties = node_properties.keys() else: raise NotConnected("reconnect", self.source, old_node) diff --git a/neomodel/util.py b/neomodel/util.py index 74e88250..142f50b8 100644 --- a/neomodel/util.py +++ b/neomodel/util.py @@ -662,9 +662,11 @@ def __str__(self): return self.__repr__() -def _get_node_properties(node): - """Get the properties from a neo4j.vx.types.graph.Node object.""" - return node._properties +def get_graph_entity_properties(entity): + """ + Get the properties from a neo4j.graph.Entity (neo4j.graph.Node or neo4j.graph.Relationship) object. + """ + return entity._properties def enumerate_traceback(initial_frame): diff --git a/test/test_properties.py b/test/test_properties.py index 454ada26..be3d3543 100644 --- a/test/test_properties.py +++ b/test/test_properties.py @@ -3,7 +3,15 @@ from pytest import mark, raises from pytz import timezone -from neomodel import StructuredNode, config, db +from neomodel import ( + Relationship, + StructuredNode, + StructuredRel, + config, + db, + get_graph_entity_properties, +) +from neomodel.contrib import SemiStructuredNode from neomodel.exceptions import ( DeflateError, InflateError, @@ -23,7 +31,6 @@ StringProperty, UniqueIdProperty, ) -from neomodel.util import _get_node_properties config.AUTO_INSTALL_LABELS = True @@ -225,9 +232,19 @@ class DefaultTestValueThree(StructuredNode): assert x.uid == "123" +class TestDBNamePropertyRel(StructuredRel): + known_for = StringProperty(db_property="knownFor") + + +# This must be defined outside of the test, otherwise the `Relationship` definition cannot look up +# `TestDBNamePropertyNode` +class TestDBNamePropertyNode(StructuredNode): + name_ = StringProperty(db_property="name") + knows = Relationship("TestDBNamePropertyNode", "KNOWS", model=TestDBNamePropertyRel) + + def test_independent_property_name(): - class TestDBNamePropertyNode(StructuredNode): - name_ = StringProperty(db_property="name") + # -- test node -- x = TestDBNamePropertyNode() x.name_ = "jim" @@ -235,16 +252,37 @@ class TestDBNamePropertyNode(StructuredNode): # check database property name on low level results, meta = db.cypher_query("MATCH (n:TestDBNamePropertyNode) RETURN n") - node_properties = _get_node_properties(results[0][0]) + node_properties = get_graph_entity_properties(results[0][0]) assert node_properties["name"] == "jim" + assert "name_" not in node_properties - node_properties = _get_node_properties(results[0][0]) - assert not "name_" in node_properties + # check python class property name at a high level assert not hasattr(x, "name") assert hasattr(x, "name_") assert TestDBNamePropertyNode.nodes.filter(name_="jim").all()[0].name_ == x.name_ assert TestDBNamePropertyNode.nodes.get(name_="jim").name_ == x.name_ + # -- test relationship -- + + r = x.knows.connect(x) + r.known_for = "10 years" + r.save() + + # check database property name on low level + results, meta = db.cypher_query( + "MATCH (:TestDBNamePropertyNode)-[r:KNOWS]->(:TestDBNamePropertyNode) RETURN r" + ) + rel_properties = get_graph_entity_properties(results[0][0]) + assert rel_properties["knownFor"] == "10 years" + assert not "known_for" in node_properties + + # check python class property name at a high level + assert not hasattr(r, "knownFor") + assert hasattr(r, "known_for") + assert x.knows.relationship(x).known_for == r.known_for + + # -- cleanup -- + x.delete() @@ -260,7 +298,7 @@ class TestNode(StructuredNode): # check database property name on low level results, meta = db.cypher_query("MATCH (n:TestNode) RETURN n") - node_properties = _get_node_properties(results[0][0]) + node_properties = get_graph_entity_properties(results[0][0]) assert node_properties["name"] == "jim" assert "name_" not in node_properties @@ -407,7 +445,7 @@ class ConstrainedTestNode(StructuredNode): # check database property name on low level results, meta = db.cypher_query("MATCH (n:ConstrainedTestNode) RETURN n") - node_properties = _get_node_properties(results[0][0]) + node_properties = get_graph_entity_properties(results[0][0]) assert node_properties["unique_required_property"] == "unique and required" # delete node afterwards From 70ba08c51637458343982805c579a3e4bb3f2e7b Mon Sep 17 00:00:00 2001 From: Jake Rosenfeld Date: Sat, 6 Jan 2024 14:10:43 +0100 Subject: [PATCH 37/73] centralized function for getting correct db property name --- neomodel/core.py | 6 +++--- neomodel/match.py | 4 +++- neomodel/properties.py | 11 +++++++++-- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/neomodel/core.py b/neomodel/core.py index a677e153..46cbc211 100644 --- a/neomodel/core.py +++ b/neomodel/core.py @@ -192,7 +192,7 @@ def _create_relationship_constraint(relationship_type: str, property_name: str, def _install_node(cls, name, property, quiet, stdout): # Create indexes and constraints for node property - db_property = property.db_property or name + db_property = property.get_db_property_name(name) if property.index: if not quiet: stdout.write( @@ -220,7 +220,7 @@ def _install_relationship(cls, relationship, quiet, stdout): for prop_name, property in relationship_cls.defined_properties( aliases=False, rels=False ).items(): - db_property = property.db_property or prop_name + db_property = property.get_db_property_name(prop_name) if property.index: if not quiet: stdout.write( @@ -451,7 +451,7 @@ def _build_merge_query( n_merge_labels = ":".join(cls.inherited_labels()) n_merge_prm = ", ".join( ( - f"{getattr(cls, p).db_property or p}: params.create.{getattr(cls, p).db_property or p}" + f"{getattr(cls, p).get_db_property_name(p)}: params.create.{getattr(cls, p).get_db_property_name(p)}" for p in cls.__required_properties__ ) ) diff --git a/neomodel/match.py b/neomodel/match.py index fb47f568..6994cfce 100644 --- a/neomodel/match.py +++ b/neomodel/match.py @@ -247,7 +247,9 @@ def process_filter_args(cls, kwargs): ) # map property to correct property name in the database - db_property = cls.defined_properties(rels=False)[prop].db_property or prop + db_property = cls.defined_properties(rels=False)[prop].get_db_property_name( + prop + ) output[db_property] = (operator, deflated_value) diff --git a/neomodel/properties.py b/neomodel/properties.py index 6d0df28c..39472d94 100644 --- a/neomodel/properties.py +++ b/neomodel/properties.py @@ -96,7 +96,7 @@ def deflate(cls, properties, obj=None, skip_empty=False): """ deflated = {} for name, property in cls.defined_properties(aliases=False, rels=False).items(): - db_property = property.db_property or name + db_property = property.get_db_property_name(name) if properties.get(name) is not None: deflated[db_property] = property.deflate(properties[name], obj) elif property.has_default: @@ -119,7 +119,7 @@ def inflate(cls, graph_entity): """ inflated = {} for name, property in cls.defined_properties(aliases=False, rels=False).items(): - db_property = property.db_property or name + db_property = property.get_db_property_name(name) if db_property in graph_entity: inflated[name] = property.inflate( graph_entity[db_property], graph_entity @@ -241,6 +241,13 @@ def default_value(self): return self.default raise ValueError("No default value specified") + def get_db_property_name(self, attribute_name): + """ + Returns the name that should be used for the property in the database. This is db_property if supplied upon + construction, otherwise the given attribute_name from the model is used. + """ + return self.db_property or attribute_name + @property def is_indexed(self): return self.unique_index or self.index From 3932c5fd33c6a0bad78d5f40d835d24f1d15db70 Mon Sep 17 00:00:00 2001 From: Jake Rosenfeld Date: Sat, 6 Jan 2024 15:33:00 +0100 Subject: [PATCH 38/73] reworked inflate/deflate of semi-structured node to respect db_property --- neomodel/contrib/semi_structured.py | 64 ++++++++++++----------- neomodel/core.py | 7 +-- test/test_contrib/test_semi_structured.py | 48 ++++++++++++++++- test/test_properties.py | 44 +++++++++++++--- 4 files changed, 118 insertions(+), 45 deletions(-) diff --git a/neomodel/contrib/semi_structured.py b/neomodel/contrib/semi_structured.py index ec803b05..9e93674d 100644 --- a/neomodel/contrib/semi_structured.py +++ b/neomodel/contrib/semi_structured.py @@ -1,6 +1,6 @@ -from neomodel import get_graph_entity_properties from neomodel.core import StructuredNode from neomodel.exceptions import DeflateConflict, InflateConflict +from neomodel.util import get_graph_entity_properties class SemiStructuredNode(StructuredNode): @@ -25,40 +25,44 @@ def hello(self): @classmethod def inflate(cls, node): - # support lazy loading - if isinstance(node, str) or isinstance(node, int): - snode = cls() - snode.element_id_property = node - else: - props = {} - node_properties = {} - for key, prop in cls.__all_properties__: - node_properties = get_graph_entity_properties(node) - if key in node_properties: - props[key] = prop.inflate(node_properties[key], node) - elif prop.has_default: - props[key] = prop.default_value() - else: - props[key] = None - # handle properties not defined on the class - for free_key in (x for x in node_properties if x not in props): - if hasattr(cls, free_key): - raise InflateConflict( - cls, free_key, node_properties[free_key], node.element_id - ) - props[free_key] = node_properties[free_key] + # Inflate all properties registered in the class definition + snode = super().inflate(node) - snode = cls(**props) - snode.element_id_property = node.element_id + # Node can be a string or int for lazy loading (See StructuredNode.inflate). In that case, `node` has nothing + # that can be unpacked further. + if not hasattr(node, "items"): + return snode + + # Inflate all extra properties not registered in the class definition + registered_db_property_names = { + property.get_db_property_name(name) + for name, property in cls.defined_properties( + aliases=False, rels=False + ).items() + } + extra_keys = node.keys() - registered_db_property_names + for extra_key in extra_keys: + value = node[extra_key] + if hasattr(cls, extra_key): + raise InflateConflict(cls, extra_key, value, snode.element_id) + setattr(snode, extra_key, value) return snode @classmethod def deflate(cls, node_props, obj=None, skip_empty=False): + # Deflate all properties registered in the class definition deflated = super().deflate(node_props, obj, skip_empty=skip_empty) - for key in [k for k in node_props if k not in deflated]: - if hasattr(cls, key) and (getattr(cls, key).required or not skip_empty): - raise DeflateConflict(cls, key, deflated[key], obj.element_id) - node_props.update(deflated) - return node_props + # Deflate all extra properties not registered in the class definition + registered_names = cls.defined_properties(aliases=False, rels=False).keys() + extra_keys = node_props.keys() - registered_names + for extra_key in extra_keys: + value = node_props[extra_key] + if hasattr(cls, extra_key): + raise DeflateConflict( + cls, extra_key, value, node_props.get("element_id") + ) + deflated[extra_key] = node_props[extra_key] + + return deflated diff --git a/neomodel/core.py b/neomodel/core.py index 46cbc211..2be6558a 100644 --- a/neomodel/core.py +++ b/neomodel/core.py @@ -12,12 +12,7 @@ ) from neomodel.hooks import hooks from neomodel.properties import Property, PropertyManager -from neomodel.util import ( - Database, - _UnsavedNode, - classproperty, - get_graph_entity_properties, -) +from neomodel.util import Database, _UnsavedNode, classproperty db = Database() diff --git a/test/test_contrib/test_semi_structured.py b/test/test_contrib/test_semi_structured.py index fe73a2bd..19809647 100644 --- a/test/test_contrib/test_semi_structured.py +++ b/test/test_contrib/test_semi_structured.py @@ -1,4 +1,13 @@ -from neomodel import IntegerProperty, StringProperty +import neo4j.graph +import pytest + +from neomodel import ( + DeflateConflict, + InflateConflict, + IntegerProperty, + StringProperty, + db, +) from neomodel.contrib import SemiStructuredNode @@ -28,3 +37,40 @@ def test_save_to_model_with_extras(): def test_save_empty_model(): dummy = Dummy() assert dummy.save() + + +def test_inflate_conflict(): + class PersonForInflateTest(SemiStructuredNode): + name = StringProperty() + age = IntegerProperty() + + def hello(self): + print("Hi my names " + self.name) + + # An ok model + props = {"name": "Jim", "age": 8, "weight": 11} + db.cypher_query("CREATE (n:PersonForInflateTest $props)", {"props": props}) + jim = PersonForInflateTest.nodes.get(name="Jim") + assert jim.name == "Jim" + assert jim.age == 8 + assert jim.weight == 11 + + # A model that conflicts on `hello` + props = {"name": "Tim", "age": 8, "hello": "goodbye"} + db.cypher_query("CREATE (n:PersonForInflateTest $props)", {"props": props}) + with pytest.raises(InflateConflict): + PersonForInflateTest.nodes.get(name="Tim") + + +def test_deflate_conflict(): + class PersonForDeflateTest(SemiStructuredNode): + name = StringProperty() + age = IntegerProperty() + + def hello(self): + print("Hi my names " + self.name) + + tim = PersonForDeflateTest(name="Tim", age=8, weight=11).save() + tim.hello = "Hi" + with pytest.raises(DeflateConflict): + tim.save() diff --git a/test/test_properties.py b/test/test_properties.py index be3d3543..579beb36 100644 --- a/test/test_properties.py +++ b/test/test_properties.py @@ -3,14 +3,7 @@ from pytest import mark, raises from pytz import timezone -from neomodel import ( - Relationship, - StructuredNode, - StructuredRel, - config, - db, - get_graph_entity_properties, -) +from neomodel import Relationship, StructuredNode, StructuredRel, config, db from neomodel.contrib import SemiStructuredNode from neomodel.exceptions import ( DeflateError, @@ -31,6 +24,7 @@ StringProperty, UniqueIdProperty, ) +from neomodel.util import get_graph_entity_properties config.AUTO_INSTALL_LABELS = True @@ -286,6 +280,40 @@ def test_independent_property_name(): x.delete() +def test_independent_property_name_for_semi_structured(): + class TestDBNamePropertySemiStructuredNode(SemiStructuredNode): + title_ = StringProperty(db_property="title") + + semi = TestDBNamePropertySemiStructuredNode(title_="sir", extra="data") + semi.save() + + # check database property name on low level + results, meta = db.cypher_query( + "MATCH (n:TestDBNamePropertySemiStructuredNode) RETURN n" + ) + node_properties = get_graph_entity_properties(results[0][0]) + assert node_properties["title"] == "sir" + assert not "title_" in node_properties + assert node_properties["extra"] == "data" + + # check python class property name at a high level + assert hasattr(semi, "title_") + assert not hasattr(semi, "title") + assert hasattr(semi, "extra") + from_filter = TestDBNamePropertySemiStructuredNode.nodes.filter(title_="sir").all()[ + 0 + ] + assert from_filter.title_ == "sir" + assert not hasattr(from_filter, "title") + assert from_filter.extra == "data" + from_get = TestDBNamePropertySemiStructuredNode.nodes.get(title_="sir") + assert from_get.title_ == "sir" + assert not hasattr(from_get, "title") + assert from_get.extra == "data" + + semi.delete() + + def test_independent_property_name_get_or_create(): class TestNode(StructuredNode): uid = UniqueIdProperty() From cc91226693d3ac6a193c57188fce891b81c42694 Mon Sep 17 00:00:00 2001 From: Giovanni Savarese Date: Wed, 10 Jan 2024 17:15:54 +0100 Subject: [PATCH 39/73] Add wraps to wrapper defined in TransactionProxy --- neomodel/util.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/neomodel/util.py b/neomodel/util.py index 74e88250..0b5d5b68 100644 --- a/neomodel/util.py +++ b/neomodel/util.py @@ -3,6 +3,7 @@ import sys import time import warnings +from functools import wraps from threading import local from typing import Optional, Sequence from urllib.parse import quote, unquote, urlparse @@ -578,6 +579,7 @@ def __exit__(self, exc_type, exc_value, traceback): self.last_bookmark = self.db.commit() def __call__(self, func): + @wraps(func) def wrapper(*args, **kwargs): with self: return func(*args, **kwargs) From 885b4b1e5cadf68b84be1861ec53d43a5e6af2c7 Mon Sep 17 00:00:00 2001 From: OlehC Date: Tue, 16 Jan 2024 01:32:18 +0200 Subject: [PATCH 40/73] fix(core): compare objects by id in memory if they are different github: [https://github.com/neo4j-contrib/neomodel/issues/778](Structured Node __eq__ works incorrectly for non-persisted nodes.) --- neomodel/core.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/neomodel/core.py b/neomodel/core.py index 415a97af..71c0dbe9 100644 --- a/neomodel/core.py +++ b/neomodel/core.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import sys import warnings from itertools import combinations +from typing import Any from neo4j.exceptions import ClientError @@ -378,12 +381,22 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def __eq__(self, other): + def __eq__(self, other: StructuredNode | Any) -> bool: + """ + Compare two node objects. + + If both nodes were persisted, compare them by their element_id. + Otherwise, compare them using object id in memory. + + If `other` is not a node, always return False. + """ if not isinstance(other, (StructuredNode,)): return False - if hasattr(self, "element_id") and hasattr(other, "element_id"): + + if self.was_persisted and other.was_persisted: return self.element_id == other.element_id - return False + + return id(self) == id(other) def __ne__(self, other): return not self.__eq__(other) @@ -427,6 +440,13 @@ def id(self): "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()." ) + @property + def was_persisted(self) -> bool: + """ + Shows status of node in the database. False, if node hasn't been saved yet, True otherwise. + """ + return self.element_id is not None + # methods @classmethod From d06d15f6dcffc4a935aebb85423ffe2f845178e9 Mon Sep 17 00:00:00 2001 From: Isaias Caporusso Date: Sat, 17 Feb 2024 04:25:03 -0300 Subject: [PATCH 41/73] feat(match): allow filtering by IN the ArrayProperty --- neomodel/match.py | 28 ++++++++++++++++++++++++---- test/test_match_api.py | 21 +++++++++++++++++++++ 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/neomodel/match.py b/neomodel/match.py index fb47f568..961b5789 100644 --- a/neomodel/match.py +++ b/neomodel/match.py @@ -7,7 +7,7 @@ from .core import StructuredNode, db from .exceptions import MultipleNodesReturned from .match_q import Q, QBase -from .properties import AliasProperty +from .properties import AliasProperty, ArrayProperty OUTGOING, INCOMING, EITHER = 1, -1, 0 @@ -150,6 +150,7 @@ def _rel_merge_helper( # special operators _SPECIAL_OPERATOR_IN = "IN" +_SPECIAL_OPERATOR_ARRAY_IN = "any(x IN {ident}.{prop} WHERE x IN {val})" _SPECIAL_OPERATOR_INSENSITIVE = "(?i)" _SPECIAL_OPERATOR_ISNULL = "IS NULL" _SPECIAL_OPERATOR_ISNOTNULL = "IS NOT NULL" @@ -261,7 +262,11 @@ def transform_operator_to_filter(operator, filter_key, filter_value, property_ob raise ValueError( f"Value must be a tuple or list for IN operation {filter_key}={filter_value}" ) - deflated_value = [property_obj.deflate(v) for v in filter_value] + if isinstance(property_obj, ArrayProperty): + deflated_value = property_obj.deflate(filter_value) + operator = _SPECIAL_OPERATOR_ARRAY_IN + else: + deflated_value = [property_obj.deflate(v) for v in filter_value] elif operator == _SPECIAL_OPERATOR_ISNULL: if not isinstance(filter_value, bool): raise ValueError( @@ -572,7 +577,14 @@ def _parse_q_filters(self, ident, q, source_class): statement = f"{ident}.{prop} {operator}" else: place_holder = self._register_place_holder(ident + "_" + prop) - statement = f"{ident}.{prop} {operator} ${place_holder}" + if operator == _SPECIAL_OPERATOR_ARRAY_IN: + statement = operator.format( + ident=ident, + prop=prop, + val=f"${place_holder}", + ) + else: + statement = f"{ident}.{prop} {operator} ${place_holder}" self._query_params[place_holder] = val target.append(statement) ret = f" {q.connector} ".join(target) @@ -607,7 +619,15 @@ def build_where_stmt(self, ident, filters, q_filters=None, source_class=None): ) else: place_holder = self._register_place_holder(ident + "_" + prop) - statement = f"{'NOT' if negate else ''} {ident}.{prop} {operator} ${place_holder}" + if operator == _SPECIAL_OPERATOR_ARRAY_IN: + statement = operator.format( + ident=ident, + prop=prop, + val=f"${place_holder}", + ) + statement = f"{'NOT' if negate else ''} {statement}" + else: + statement = f"{'NOT' if negate else ''} {ident}.{prop} {operator} ${place_holder}" self._query_params[place_holder] = val stmts.append(statement) diff --git a/test/test_match_api.py b/test/test_match_api.py index ee6b337e..2d7d1657 100644 --- a/test/test_match_api.py +++ b/test/test_match_api.py @@ -4,6 +4,7 @@ from neomodel import ( INCOMING, + ArrayProperty, DateTimeProperty, IntegerProperty, Q, @@ -31,6 +32,7 @@ class Supplier(StructuredNode): class Species(StructuredNode): name = StringProperty() coffees = RelationshipFrom("Coffee", "COFFEE SPECIES", model=StructuredRel) + tags = ArrayProperty(StringProperty(), default=list) class Coffee(StructuredNode): @@ -502,3 +504,22 @@ def test_fetch_relations(): assert tesco in Supplier.nodes.fetch_relations("coffees__species").filter( name="Sainsburys" ) + + +def test_in_filter_with_array_property(): + tags = ["smoother", "sweeter", "chocolate", "sugar"] + no_match = ["organic"] + arabica = Species(name="Arabica", tags=tags).save() + + assert arabica in Species.nodes.filter( + tags__in=tags + ), "Species not found by tags given" + assert arabica in Species.nodes.filter( + Q(tags__in=tags) + ), "Species not found with Q by tags given" + assert arabica not in Species.nodes.filter( + ~Q(tags__in=tags) + ), "Species found by tags given in negated query" + assert arabica not in Species.nodes.filter( + tags__in=no_match + ), "Species found by tags with not match tags given" From 196ace2a992eab86d368645ac3e866894cb52654 Mon Sep 17 00:00:00 2001 From: Isaias Caporusso Date: Sat, 17 Feb 2024 09:12:01 -0300 Subject: [PATCH 42/73] fix(match): remove cognitive complexity of the transform_operator_to_filter --- neomodel/match.py | 108 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 85 insertions(+), 23 deletions(-) diff --git a/neomodel/match.py b/neomodel/match.py index 961b5789..5dfeeb08 100644 --- a/neomodel/match.py +++ b/neomodel/match.py @@ -255,33 +255,95 @@ def process_filter_args(cls, kwargs): return output +def transform_in_operator_to_filter(operator, filter_key, filter_value, property_obj): + """ + Transform in operator to a cypher filter + + Args: + operator (str): operator to transform + filter_key (str): filter key + filter_value (str): filter value + property_obj (object): property object + + Returns: + tuple: operator, deflated_value + """ + if not isinstance(filter_value, tuple) and not isinstance(filter_value, list): + raise ValueError( + f"Value must be a tuple or list for IN operation {filter_key}={filter_value}" + ) + if isinstance(property_obj, ArrayProperty): + deflated_value = property_obj.deflate(filter_value) + operator = _SPECIAL_OPERATOR_ARRAY_IN + else: + deflated_value = [property_obj.deflate(v) for v in filter_value] + + return operator, deflated_value + + +def transform_null_operator_to_filter(filter_key, filter_value): + """ + Transform null operator to a cypher filter + + Args: + filter_key (str): filter key + filter_value (str): filter value + + Returns: + tuple: operator, deflated_value + """ + if not isinstance(filter_value, bool): + raise ValueError(f"Value must be a bool for isnull operation on {filter_key}") + operator = "IS NULL" if filter_value else "IS NOT NULL" + deflated_value = None + return operator, deflated_value + + +def transform_regex_operator_to_filter( + operator, filter_key, filter_value, property_obj +): + """ + Transform regex operator to a cypher filter + + Args: + operator (str): operator to transform + filter_key (str): filter key + filter_value (str): filter value + property_obj (object): property object + + Returns: + tuple: operator, deflated_value + """ + + deflated_value = property_obj.deflate(filter_value) + if not isinstance(deflated_value, str): + raise ValueError(f"Must be a string value for {filter_key}") + if operator in _STRING_REGEX_OPERATOR_TABLE.values(): + deflated_value = re.escape(deflated_value) + deflated_value = operator.format(deflated_value) + operator = _SPECIAL_OPERATOR_REGEX + return operator, deflated_value + + def transform_operator_to_filter(operator, filter_key, filter_value, property_obj): - # handle special operators if operator == _SPECIAL_OPERATOR_IN: - if not isinstance(filter_value, tuple) and not isinstance(filter_value, list): - raise ValueError( - f"Value must be a tuple or list for IN operation {filter_key}={filter_value}" - ) - if isinstance(property_obj, ArrayProperty): - deflated_value = property_obj.deflate(filter_value) - operator = _SPECIAL_OPERATOR_ARRAY_IN - else: - deflated_value = [property_obj.deflate(v) for v in filter_value] + operator, deflated_value = transform_in_operator_to_filter( + operator=operator, + filter_key=filter_key, + filter_value=filter_value, + property_obj=property_obj, + ) elif operator == _SPECIAL_OPERATOR_ISNULL: - if not isinstance(filter_value, bool): - raise ValueError( - f"Value must be a bool for isnull operation on {filter_key}" - ) - operator = "IS NULL" if filter_value else "IS NOT NULL" - deflated_value = None + operator, deflated_value = transform_null_operator_to_filter( + filter_key=filter_key, filter_value=filter_value + ) elif operator in _REGEX_OPERATOR_TABLE.values(): - deflated_value = property_obj.deflate(filter_value) - if not isinstance(deflated_value, str): - raise ValueError(f"Must be a string value for {filter_key}") - if operator in _STRING_REGEX_OPERATOR_TABLE.values(): - deflated_value = re.escape(deflated_value) - deflated_value = operator.format(deflated_value) - operator = _SPECIAL_OPERATOR_REGEX + operator, deflated_value = transform_regex_operator_to_filter( + operator=operator, + filter_key=filter_key, + filter_value=filter_value, + property_obj=property_obj, + ) else: deflated_value = property_obj.deflate(filter_value) From d09e17e38485f76676fe8b58aca70ea804e305fd Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Mon, 4 Mar 2024 16:32:52 +0100 Subject: [PATCH 43/73] Fix await MyLabel.nodes --- neomodel/async_/match.py | 3 +++ neomodel/async_/relationship_manager.py | 3 +++ neomodel/sync_/match.py | 3 +++ neomodel/sync_/relationship_manager.py | 3 +++ test/async_/test_match_api.py | 30 +++++++++++++------------ test/sync_/test_match_api.py | 30 +++++++++++++------------ 6 files changed, 44 insertions(+), 28 deletions(-) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 9b5769df..5ba1ce89 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -824,6 +824,9 @@ def __init__(self, source): self.relations_to_fetch: list = [] + def __await__(self): + return self.all().__await__() + async def _get(self, limit=None, lazy=False, **kwargs): self.filter(**kwargs) if limit: diff --git a/neomodel/async_/relationship_manager.py b/neomodel/async_/relationship_manager.py index 5bfe8dc0..145e4247 100644 --- a/neomodel/async_/relationship_manager.py +++ b/neomodel/async_/relationship_manager.py @@ -69,6 +69,9 @@ def __str__(self): return f"{self.description} in {direction} direction of type {self.definition['relation_type']} on node ({self.source.element_id}) of class '{self.source_class.__name__}'" + async def __await__(self): + return self.all().__await__() + def _check_node(self, obj): """check for valid node i.e correct class and is saved""" if not issubclass(type(obj), self.definition["node_class"]): diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 0609aa05..7c0abb16 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -822,6 +822,9 @@ def __init__(self, source): self.relations_to_fetch: list = [] + def __await__(self): + return self.all().__await__() + def _get(self, limit=None, lazy=False, **kwargs): self.filter(**kwargs) if limit: diff --git a/neomodel/sync_/relationship_manager.py b/neomodel/sync_/relationship_manager.py index 1d31c2ca..683efe6f 100644 --- a/neomodel/sync_/relationship_manager.py +++ b/neomodel/sync_/relationship_manager.py @@ -64,6 +64,9 @@ def __str__(self): return f"{self.description} in {direction} direction of type {self.definition['relation_type']} on node ({self.source.element_id}) of class '{self.source_class.__name__}'" + def __await__(self): + return self.all().__await__() + def _check_node(self, obj): """check for valid node i.e correct class and is saved""" if not issubclass(type(obj), self.definition["node_class"]): diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 4b96875c..8d05f09f 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -53,10 +53,6 @@ class Extension(AsyncStructuredNode): extension = AsyncRelationshipTo("Extension", "extension") -# TODO : Maybe split these tests into separate async and sync (not transpiled) -# That would allow to test "Coffee.nodes" for sync instead of Coffee.nodes.all() - - @mark_async_test async def test_filter_exclude_via_labels(): await Coffee(name="Java", price=99).save() @@ -174,18 +170,18 @@ async def test_len_and_iter_and_bool(): await Coffee(name="Icelands finest").save() - for c in await Coffee.nodes.all(): + for c in await Coffee.nodes: iterations += 1 await c.delete() assert iterations > 0 - assert len(await Coffee.nodes.all()) == 0 + assert len(await Coffee.nodes) == 0 @mark_async_test async def test_slice(): - for c in await Coffee.nodes.all(): + for c in await Coffee.nodes: await c.delete() await Coffee(name="Icelands finest").save() @@ -212,6 +208,8 @@ async def test_issue_208(): await b.suppliers.connect(l, {"courier": "fedex"}) await b.suppliers.connect(a, {"courier": "dhl"}) + # TODO : Find a way to not need the .all() here + # Note : Check AsyncTraversal match assert len(await b.suppliers.match(courier="fedex").all()) assert len(await b.suppliers.match(courier="dhl").all()) @@ -221,15 +219,17 @@ async def test_issue_589(): node1 = await Extension().save() node2 = await Extension().save() await node1.extension.connect(node2) + # TODO : Find a way to not need the .all() here + # Note : Check AsyncRelationshipDefinition (parent of AsyncRelationshipTo / From) assert node2 in await node1.extension.all() -# TODO : Fix the ValueError not raised @mark_async_test async def test_contains(): expensive = await Coffee(price=1000, name="Pricey").save() asda = await Coffee(name="Asda", price=1).save() + # TODO : Find a way to not need the .all() here assert expensive in await Coffee.nodes.filter(price__gt=999).all() assert asda not in await Coffee.nodes.filter(price__gt=999).all() @@ -244,17 +244,18 @@ async def test_contains(): @mark_async_test async def test_order_by(): - for c in await Coffee.nodes.all(): + for c in await Coffee.nodes: await c.delete() c1 = await Coffee(name="Icelands finest", price=5).save() c2 = await Coffee(name="Britains finest", price=10).save() c3 = await Coffee(name="Japans finest", price=35).save() - assert Coffee.nodes.order_by("price")[0].price == 5 - assert Coffee.nodes.order_by("-price")[0].price == 35 + assert (await Coffee.nodes.order_by("price")[0]).price == 5 + assert (await Coffee.nodes.order_by("-price")[0]).price == 35 ns = await Coffee.nodes.order_by("-price") + # TODO : Method fails qb = AsyncQueryBuilder(ns).build_ast() assert qb._ast.order_by ns = ns.order_by(None) @@ -285,7 +286,7 @@ async def test_order_by(): @mark_async_test async def test_extra_filters(): - for c in await Coffee.nodes.all(): + for c in await Coffee.nodes: await c.delete() c1 = await Coffee(name="Icelands finest", price=5, id_=1).save() @@ -293,6 +294,7 @@ async def test_extra_filters(): c3 = await Coffee(name="Japans finest", price=35, id_=3).save() c4 = await Coffee(name="US extra-fine", price=None, id_=4).save() + # TODO : Remove some .all() when filter is updated coffees_5_10 = await Coffee.nodes.filter(price__in=[10, 5]).all() assert len(coffees_5_10) == 2, "unexpected number of results" assert c1 in coffees_5_10, "doesnt contain 5 price coffee" @@ -360,7 +362,7 @@ async def test_empty_filters(): ``NodeSet`` object. """ - for c in await Coffee.nodes.all(): + for c in await Coffee.nodes: await c.delete() c1 = await Coffee(name="Super", price=5, id_=1).save() @@ -387,7 +389,7 @@ async def test_empty_filters(): @mark_async_test async def test_q_filters(): # Test where no children and self.connector != conn ? - for c in await Coffee.nodes.all(): + for c in await Coffee.nodes: await c.delete() c1 = await Coffee(name="Icelands finest", price=5, id_=1).save() diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 7dd3d489..791a1fa2 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -46,10 +46,6 @@ class Extension(StructuredNode): extension = RelationshipTo("Extension", "extension") -# TODO : Maybe split these tests into separate async and sync (not transpiled) -# That would allow to test "Coffee.nodes" for sync instead of Coffee.nodes.all() - - @mark_sync_test def test_filter_exclude_via_labels(): Coffee(name="Java", price=99).save() @@ -165,18 +161,18 @@ def test_len_and_iter_and_bool(): Coffee(name="Icelands finest").save() - for c in Coffee.nodes.all(): + for c in Coffee.nodes: iterations += 1 c.delete() assert iterations > 0 - assert len(Coffee.nodes.all()) == 0 + assert len(Coffee.nodes) == 0 @mark_sync_test def test_slice(): - for c in Coffee.nodes.all(): + for c in Coffee.nodes: c.delete() Coffee(name="Icelands finest").save() @@ -203,6 +199,8 @@ def test_issue_208(): b.suppliers.connect(l, {"courier": "fedex"}) b.suppliers.connect(a, {"courier": "dhl"}) + # TODO : Find a way to not need the .all() here + # Note : Check AsyncTraversal match assert len(b.suppliers.match(courier="fedex").all()) assert len(b.suppliers.match(courier="dhl").all()) @@ -212,15 +210,17 @@ def test_issue_589(): node1 = Extension().save() node2 = Extension().save() node1.extension.connect(node2) + # TODO : Find a way to not need the .all() here + # Note : Check AsyncRelationshipDefinition (parent of AsyncRelationshipTo / From) assert node2 in node1.extension.all() -# TODO : Fix the ValueError not raised @mark_sync_test def test_contains(): expensive = Coffee(price=1000, name="Pricey").save() asda = Coffee(name="Asda", price=1).save() + # TODO : Find a way to not need the .all() here assert expensive in Coffee.nodes.filter(price__gt=999).all() assert asda not in Coffee.nodes.filter(price__gt=999).all() @@ -235,17 +235,18 @@ def test_contains(): @mark_sync_test def test_order_by(): - for c in Coffee.nodes.all(): + for c in Coffee.nodes: c.delete() c1 = Coffee(name="Icelands finest", price=5).save() c2 = Coffee(name="Britains finest", price=10).save() c3 = Coffee(name="Japans finest", price=35).save() - assert Coffee.nodes.order_by("price")[0].price == 5 - assert Coffee.nodes.order_by("-price")[0].price == 35 + assert (Coffee.nodes.order_by("price")[0]).price == 5 + assert (Coffee.nodes.order_by("-price")[0]).price == 35 ns = Coffee.nodes.order_by("-price") + # TODO : Method fails qb = QueryBuilder(ns).build_ast() assert qb._ast.order_by ns = ns.order_by(None) @@ -276,7 +277,7 @@ def test_order_by(): @mark_sync_test def test_extra_filters(): - for c in Coffee.nodes.all(): + for c in Coffee.nodes: c.delete() c1 = Coffee(name="Icelands finest", price=5, id_=1).save() @@ -284,6 +285,7 @@ def test_extra_filters(): c3 = Coffee(name="Japans finest", price=35, id_=3).save() c4 = Coffee(name="US extra-fine", price=None, id_=4).save() + # TODO : Remove some .all() when filter is updated coffees_5_10 = Coffee.nodes.filter(price__in=[10, 5]).all() assert len(coffees_5_10) == 2, "unexpected number of results" assert c1 in coffees_5_10, "doesnt contain 5 price coffee" @@ -351,7 +353,7 @@ def test_empty_filters(): ``NodeSet`` object. """ - for c in Coffee.nodes.all(): + for c in Coffee.nodes: c.delete() c1 = Coffee(name="Super", price=5, id_=1).save() @@ -378,7 +380,7 @@ def test_empty_filters(): @mark_sync_test def test_q_filters(): # Test where no children and self.connector != conn ? - for c in Coffee.nodes.all(): + for c in Coffee.nodes: c.delete() c1 = Coffee(name="Icelands finest", price=5, id_=1).save() From 7025cdbba42740a92d8f530cae421a4a92d4f41e Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Mon, 4 Mar 2024 16:38:46 +0100 Subject: [PATCH 44/73] Update tests --- test/async_/test_batch.py | 2 +- test/async_/test_models.py | 2 +- test/async_/test_relationships.py | 1 - test/async_/test_transactions.py | 24 ++++++++++++------------ test/sync_/test_batch.py | 2 +- test/sync_/test_models.py | 2 +- test/sync_/test_relationships.py | 1 - test/sync_/test_transactions.py | 24 ++++++++++++------------ 8 files changed, 28 insertions(+), 30 deletions(-) diff --git a/test/async_/test_batch.py b/test/async_/test_batch.py index 3b5d76d1..a1d86e21 100644 --- a/test/async_/test_batch.py +++ b/test/async_/test_batch.py @@ -93,7 +93,7 @@ async def test_batch_validation(): @mark_async_test async def test_batch_index_violation(): - for u in await Customer.nodes.all(): + for u in await Customer.nodes: await u.delete() users = await Customer.create( diff --git a/test/async_/test_models.py b/test/async_/test_models.py index 39b9026b..f3c922a3 100644 --- a/test/async_/test_models.py +++ b/test/async_/test_models.py @@ -177,7 +177,7 @@ async def test_not_updated_on_unique_error(): test.email = "jim@bob.com" with raises(UniqueProperty): await test.save() - customers = await Customer2.nodes.all() + customers = await Customer2.nodes assert customers[0].email != customers[1].email assert (await Customer2.nodes.get(email="jim@bob.com")).age == 7 assert (await Customer2.nodes.get(email="jim1@bob.com")).age == 2 diff --git a/test/async_/test_relationships.py b/test/async_/test_relationships.py index f1ae8fdd..1bbac2a1 100644 --- a/test/async_/test_relationships.py +++ b/test/async_/test_relationships.py @@ -110,7 +110,6 @@ async def test_either_direction_connect(): assert isinstance(rels[0], AsyncStructuredRel) -# TODO : Make async-independent test to test .filter and not .filter.all() ? @mark_async_test async def test_search_and_filter_and_exclude(): fred = await PersonWithRels(name="Fred", age=13).save() diff --git a/test/async_/test_transactions.py b/test/async_/test_transactions.py index 39997d67..331e67f2 100644 --- a/test/async_/test_transactions.py +++ b/test/async_/test_transactions.py @@ -15,7 +15,7 @@ class APerson(AsyncStructuredNode): @mark_async_test async def test_rollback_and_commit_transaction(): - for p in await APerson.nodes.all(): + for p in await APerson.nodes: await p.delete() await APerson(name="Roger").save() @@ -24,13 +24,13 @@ async def test_rollback_and_commit_transaction(): await APerson(name="Terry S").save() await adb.rollback() - assert len(await APerson.nodes.all()) == 1 + assert len(await APerson.nodes) == 1 await adb.begin() await APerson(name="Terry S").save() await adb.commit() - assert len(await APerson.nodes.all()) == 2 + assert len(await APerson.nodes) == 2 @adb.transaction @@ -43,7 +43,7 @@ async def in_a_tx(*names): @mark_async_test async def test_transaction_decorator(): await adb.install_labels(APerson) - for p in await APerson.nodes.all(): + for p in await APerson.nodes: await p.delete() # should work @@ -54,7 +54,7 @@ async def test_transaction_decorator(): with raises(UniqueProperty): await in_a_tx("Jim", "Roger") - assert "Jim" not in [p.name async for p in await APerson.nodes.all()] + assert "Jim" not in [p.name async for p in await APerson.nodes] @mark_async_test @@ -71,14 +71,14 @@ async def test_transaction_as_a_context(): @mark_async_test async def test_query_inside_transaction(): - for p in await APerson.nodes.all(): + for p in await APerson.nodes: await p.delete() with adb.transaction: await APerson(name="Alice").save() await APerson(name="Bob").save() - assert len([p.name for p in await APerson.nodes.all()]) == 2 + assert len([p.name for p in await APerson.nodes]) == 2 @mark_async_test @@ -86,7 +86,7 @@ async def test_read_transaction(): await APerson(name="Johnny").save() with adb.read_transaction: - people = await APerson.nodes.all() + people = await APerson.nodes assert people with raises(TransactionError): @@ -122,7 +122,7 @@ async def in_a_tx(*names): @mark_async_test async def test_bookmark_transaction_decorator(): - for p in await APerson.nodes.all(): + for p in await APerson.nodes: await p.delete() # should work @@ -134,7 +134,7 @@ async def test_bookmark_transaction_decorator(): with raises(UniqueProperty): await in_a_tx("Jane", "Ruth") - assert "Jane" not in [p.name for p in await APerson.nodes.all()] + assert "Jane" not in [p.name for p in await APerson.nodes] @mark_async_test @@ -184,13 +184,13 @@ async def test_bookmark_passed_in_to_context(spy_on_db_begin): @mark_async_test async def test_query_inside_bookmark_transaction(): - for p in await APerson.nodes.all(): + for p in await APerson.nodes: await p.delete() with adb.transaction as transaction: await APerson(name="Alice").save() await APerson(name="Bob").save() - assert len([p.name for p in await APerson.nodes.all()]) == 2 + assert len([p.name for p in await APerson.nodes]) == 2 assert isinstance(transaction.last_bookmark, Bookmarks) diff --git a/test/sync_/test_batch.py b/test/sync_/test_batch.py index ca28626f..3eabe65e 100644 --- a/test/sync_/test_batch.py +++ b/test/sync_/test_batch.py @@ -89,7 +89,7 @@ def test_batch_validation(): @mark_sync_test def test_batch_index_violation(): - for u in Customer.nodes.all(): + for u in Customer.nodes: u.delete() users = Customer.create( diff --git a/test/sync_/test_models.py b/test/sync_/test_models.py index dc3ff735..89667f56 100644 --- a/test/sync_/test_models.py +++ b/test/sync_/test_models.py @@ -177,7 +177,7 @@ def test_not_updated_on_unique_error(): test.email = "jim@bob.com" with raises(UniqueProperty): test.save() - customers = Customer2.nodes.all() + customers = Customer2.nodes assert customers[0].email != customers[1].email assert (Customer2.nodes.get(email="jim@bob.com")).age == 7 assert (Customer2.nodes.get(email="jim1@bob.com")).age == 2 diff --git a/test/sync_/test_relationships.py b/test/sync_/test_relationships.py index 44c6010a..13ca9295 100644 --- a/test/sync_/test_relationships.py +++ b/test/sync_/test_relationships.py @@ -110,7 +110,6 @@ def test_either_direction_connect(): assert isinstance(rels[0], StructuredRel) -# TODO : Make async-independent test to test .filter and not .filter.all() ? @mark_sync_test def test_search_and_filter_and_exclude(): fred = PersonWithRels(name="Fred", age=13).save() diff --git a/test/sync_/test_transactions.py b/test/sync_/test_transactions.py index 309637f6..03e42e05 100644 --- a/test/sync_/test_transactions.py +++ b/test/sync_/test_transactions.py @@ -15,7 +15,7 @@ class APerson(StructuredNode): @mark_sync_test def test_rollback_and_commit_transaction(): - for p in APerson.nodes.all(): + for p in APerson.nodes: p.delete() APerson(name="Roger").save() @@ -24,13 +24,13 @@ def test_rollback_and_commit_transaction(): APerson(name="Terry S").save() db.rollback() - assert len(APerson.nodes.all()) == 1 + assert len(APerson.nodes) == 1 db.begin() APerson(name="Terry S").save() db.commit() - assert len(APerson.nodes.all()) == 2 + assert len(APerson.nodes) == 2 @db.transaction @@ -43,7 +43,7 @@ def in_a_tx(*names): @mark_sync_test def test_transaction_decorator(): db.install_labels(APerson) - for p in APerson.nodes.all(): + for p in APerson.nodes: p.delete() # should work @@ -54,7 +54,7 @@ def test_transaction_decorator(): with raises(UniqueProperty): in_a_tx("Jim", "Roger") - assert "Jim" not in [p.name for p in APerson.nodes.all()] + assert "Jim" not in [p.name for p in APerson.nodes] @mark_sync_test @@ -71,14 +71,14 @@ def test_transaction_as_a_context(): @mark_sync_test def test_query_inside_transaction(): - for p in APerson.nodes.all(): + for p in APerson.nodes: p.delete() with db.transaction: APerson(name="Alice").save() APerson(name="Bob").save() - assert len([p.name for p in APerson.nodes.all()]) == 2 + assert len([p.name for p in APerson.nodes]) == 2 @mark_sync_test @@ -86,7 +86,7 @@ def test_read_transaction(): APerson(name="Johnny").save() with db.read_transaction: - people = APerson.nodes.all() + people = APerson.nodes assert people with raises(TransactionError): @@ -122,7 +122,7 @@ def in_a_tx(*names): @mark_sync_test def test_bookmark_transaction_decorator(): - for p in APerson.nodes.all(): + for p in APerson.nodes: p.delete() # should work @@ -134,7 +134,7 @@ def test_bookmark_transaction_decorator(): with raises(UniqueProperty): in_a_tx("Jane", "Ruth") - assert "Jane" not in [p.name for p in APerson.nodes.all()] + assert "Jane" not in [p.name for p in APerson.nodes] @mark_sync_test @@ -184,13 +184,13 @@ def test_bookmark_passed_in_to_context(spy_on_db_begin): @mark_sync_test def test_query_inside_bookmark_transaction(): - for p in APerson.nodes.all(): + for p in APerson.nodes: p.delete() with db.transaction as transaction: APerson(name="Alice").save() APerson(name="Bob").save() - assert len([p.name for p in APerson.nodes.all()]) == 2 + assert len([p.name for p in APerson.nodes]) == 2 assert isinstance(transaction.last_bookmark, Bookmarks) From 65facf297abdf88cc9883caac63ef63d5fdd4c7e Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Mon, 11 Mar 2024 14:38:08 +0100 Subject: [PATCH 45/73] Fix transaction decorator --- neomodel/async_/core.py | 13 +++++++---- neomodel/sync_/core.py | 3 +++ test/async_/test_driver_options.py | 4 ++-- test/async_/test_transactions.py | 37 ++++++++++++++++-------------- test/sync_/test_transactions.py | 9 +++++--- 5 files changed, 39 insertions(+), 27 deletions(-) diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py index 3deee79f..a7c8797b 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/core.py @@ -914,12 +914,14 @@ def __init__(self, db: AsyncDatabase, access_mode=None): self.access_mode = access_mode @ensure_connection - async def __enter__(self): + async def __aenter__(self): + print("aenter called") await self.db.begin(access_mode=self.access_mode, bookmarks=self.bookmarks) self.bookmarks = None return self - async def __exit__(self, exc_type, exc_value, traceback): + async def __aexit__(self, exc_type, exc_value, traceback): + print("aexit called") if exc_value: await self.db.rollback() @@ -933,9 +935,10 @@ async def __exit__(self, exc_type, exc_value, traceback): self.last_bookmark = await self.db.commit() def __call__(self, func): - def wrapper(*args, **kwargs): - with self: - return func(*args, **kwargs) + async def wrapper(*args, **kwargs): + async with self: + print("call called") + return await func(*args, **kwargs) return wrapper diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index 9be231f7..93ef473e 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -911,11 +911,13 @@ def __init__(self, db: Database, access_mode=None): @ensure_connection def __enter__(self): + print("aenter called") self.db.begin(access_mode=self.access_mode, bookmarks=self.bookmarks) self.bookmarks = None return self def __exit__(self, exc_type, exc_value, traceback): + print("aexit called") if exc_value: self.db.rollback() @@ -931,6 +933,7 @@ def __exit__(self, exc_type, exc_value, traceback): def __call__(self, func): def wrapper(*args, **kwargs): with self: + print("call called") return func(*args, **kwargs) return wrapper diff --git a/test/async_/test_driver_options.py b/test/async_/test_driver_options.py index dff17c75..6add5f0a 100644 --- a/test/async_/test_driver_options.py +++ b/test/async_/test_driver_options.py @@ -34,11 +34,11 @@ async def test_impersonate_unauthorized(): ) async def test_impersonate_multiple_transactions(): with adb.impersonate(user="troygreene"): - with adb.transaction: + async with adb.transaction: results, _ = await adb.cypher_query("RETURN 'Doo Wacko !'") assert results[0][0] == "Doo Wacko !" - with adb.transaction: + async with adb.transaction: results, _ = await adb.cypher_query("SHOW CURRENT USER") assert results[0][0] == "troygreene" diff --git a/test/async_/test_transactions.py b/test/async_/test_transactions.py index 331e67f2..139b49f6 100644 --- a/test/async_/test_transactions.py +++ b/test/async_/test_transactions.py @@ -39,7 +39,8 @@ async def in_a_tx(*names): await APerson(name=n).save() -# TODO : understand how to make @adb.transaction work with async +# TODO : This fails with no support for context manager protocol +# Possibly the transaction decorator is the issue @mark_async_test async def test_transaction_decorator(): await adb.install_labels(APerson) @@ -59,13 +60,13 @@ async def test_transaction_decorator(): @mark_async_test async def test_transaction_as_a_context(): - with adb.transaction: + async with adb.transaction: await APerson(name="Tim").save() - assert await APerson.nodes.filter(name="Tim").all() + assert await APerson.nodes.filter(name="Tim") with raises(UniqueProperty): - with adb.transaction: + async with adb.transaction: await APerson(name="Tim").save() @@ -74,7 +75,7 @@ async def test_query_inside_transaction(): for p in await APerson.nodes: await p.delete() - with adb.transaction: + async with adb.transaction: await APerson(name="Alice").save() await APerson(name="Bob").save() @@ -85,12 +86,12 @@ async def test_query_inside_transaction(): async def test_read_transaction(): await APerson(name="Johnny").save() - with adb.read_transaction: + async with adb.read_transaction: people = await APerson.nodes assert people with raises(TransactionError): - with adb.read_transaction: + async with adb.read_transaction: with raises(ClientError) as e: await APerson(name="Gina").save() assert e.value.code == "Neo.ClientError.Statement.AccessMode" @@ -98,7 +99,7 @@ async def test_read_transaction(): @mark_async_test async def test_write_transaction(): - with adb.write_transaction: + async with adb.write_transaction: await APerson(name="Amelia").save() amelia = await APerson.nodes.get(name="Amelia") @@ -120,6 +121,7 @@ async def in_a_tx(*names): await APerson(name=n).save() +# TODO : FIx this once decorator is fixed @mark_async_test async def test_bookmark_transaction_decorator(): for p in await APerson.nodes: @@ -139,22 +141,22 @@ async def test_bookmark_transaction_decorator(): @mark_async_test async def test_bookmark_transaction_as_a_context(): - with adb.transaction as transaction: - APerson(name="Tanya").save() + async with adb.transaction as transaction: + await APerson(name="Tanya").save() assert isinstance(transaction.last_bookmark, Bookmarks) - assert APerson.nodes.filter(name="Tanya") + assert await APerson.nodes.filter(name="Tanya") with raises(UniqueProperty): - with adb.transaction as transaction: - APerson(name="Tanya").save() + async with adb.transaction as transaction: + await APerson(name="Tanya").save() assert not hasattr(transaction, "last_bookmark") @pytest.fixture async def spy_on_db_begin(monkeypatch): spy_calls = [] - original_begin = adb.begin + original_begin = await adb.begin() def begin_spy(*args, **kwargs): spy_calls.append((args, kwargs)) @@ -164,17 +166,18 @@ def begin_spy(*args, **kwargs): return spy_calls +# TODO : Fix this test @mark_async_test async def test_bookmark_passed_in_to_context(spy_on_db_begin): transaction = adb.transaction - with transaction: + async with transaction: pass assert (await spy_on_db_begin)[-1] == ((), {"access_mode": None, "bookmarks": None}) last_bookmark = transaction.last_bookmark transaction.bookmarks = last_bookmark - with transaction: + async with transaction: pass assert spy_on_db_begin[-1] == ( (), @@ -187,7 +190,7 @@ async def test_query_inside_bookmark_transaction(): for p in await APerson.nodes: await p.delete() - with adb.transaction as transaction: + async with adb.transaction as transaction: await APerson(name="Alice").save() await APerson(name="Bob").save() diff --git a/test/sync_/test_transactions.py b/test/sync_/test_transactions.py index 03e42e05..78ad4c58 100644 --- a/test/sync_/test_transactions.py +++ b/test/sync_/test_transactions.py @@ -39,7 +39,8 @@ def in_a_tx(*names): APerson(name=n).save() -# TODO : understand how to make @adb.transaction work with async +# TODO : This fails with no support for context manager protocol +# Possibly the transaction decorator is the issue @mark_sync_test def test_transaction_decorator(): db.install_labels(APerson) @@ -62,7 +63,7 @@ def test_transaction_as_a_context(): with db.transaction: APerson(name="Tim").save() - assert APerson.nodes.filter(name="Tim").all() + assert APerson.nodes.filter(name="Tim") with raises(UniqueProperty): with db.transaction: @@ -120,6 +121,7 @@ def in_a_tx(*names): APerson(name=n).save() +# TODO : FIx this once decorator is fixed @mark_sync_test def test_bookmark_transaction_decorator(): for p in APerson.nodes: @@ -154,7 +156,7 @@ def test_bookmark_transaction_as_a_context(): @pytest.fixture def spy_on_db_begin(monkeypatch): spy_calls = [] - original_begin = db.begin + original_begin = db.begin() def begin_spy(*args, **kwargs): spy_calls.append((args, kwargs)) @@ -164,6 +166,7 @@ def begin_spy(*args, **kwargs): return spy_calls +# TODO : Fix this test @mark_sync_test def test_bookmark_passed_in_to_context(spy_on_db_begin): transaction = db.transaction From 95bdd1acad6b9854d753d2c95b9840bc29eb5a26 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Mon, 11 Mar 2024 17:34:24 +0100 Subject: [PATCH 46/73] Fix return type hint --- neomodel/async_/core.py | 4 ++-- neomodel/sync_/core.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py index a7c8797b..7cf634c5 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/core.py @@ -407,7 +407,7 @@ async def cypher_query( handle_unique=True, retry_on_session_expire=False, resolve_objects=False, - ) -> (list[list], Tuple[str, ...]): + ) -> Tuple[list[list], Tuple[str, ...]]: """ Runs a query on the database and returns a list of results and their headers. @@ -460,7 +460,7 @@ async def _run_cypher_query( handle_unique, retry_on_session_expire, resolve_objects, - ) -> (list[list], Tuple[str, ...]): + ) -> Tuple[list[list], Tuple[str, ...]]: try: # Retrieve the data start = time.time() diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index 93ef473e..d97af6b4 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -405,7 +405,7 @@ def cypher_query( handle_unique=True, retry_on_session_expire=False, resolve_objects=False, - ) -> (list[list], Tuple[str, ...]): + ) -> Tuple[list[list], Tuple[str, ...]]: """ Runs a query on the database and returns a list of results and their headers. @@ -458,7 +458,7 @@ def _run_cypher_query( handle_unique, retry_on_session_expire, resolve_objects, - ) -> (list[list], Tuple[str, ...]): + ) -> Tuple[list[list], Tuple[str, ...]]: try: # Retrieve the data start = time.time() From f428e72c947f8062b2b0815888402e1c4d5c0d4a Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Mon, 11 Mar 2024 17:44:07 +0100 Subject: [PATCH 47/73] Fix skipif --- neomodel/config.py | 2 -- test/async_/test_connection.py | 6 ++---- test/async_/test_dbms_awareness.py | 1 + test/async_/test_driver_options.py | 20 ++++++++------------ test/async_/test_indexing.py | 5 ++--- test/async_/test_label_install.py | 10 ++++++---- test/sync_/test_connection.py | 6 ++---- test/sync_/test_dbms_awareness.py | 1 + test/sync_/test_driver_options.py | 20 ++++++++------------ test/sync_/test_indexing.py | 5 ++--- test/sync_/test_label_install.py | 8 ++++++-- 11 files changed, 38 insertions(+), 46 deletions(-) diff --git a/neomodel/config.py b/neomodel/config.py index 26a0d626..2e527782 100644 --- a/neomodel/config.py +++ b/neomodel/config.py @@ -2,8 +2,6 @@ from neomodel._version import __version__ -AUTO_INSTALL_LABELS = False - # Use this to connect with automatically created driver # The following options are the default ones that will be used as driver config DATABASE_URL = "bolt://neo4j:foobarbaz@localhost:7687" diff --git a/test/async_/test_connection.py b/test/async_/test_connection.py index 94d9b2dc..36c5922b 100644 --- a/test/async_/test_connection.py +++ b/test/async_/test_connection.py @@ -79,11 +79,9 @@ async def test_config_driver_works(): @mark_async_test -@pytest.mark.skipif( - adb.database_edition != "enterprise", - reason="Skipping test for community edition - no multi database in CE", -) async def test_connect_to_non_default_database(): + if not await adb.edition_is_enterprise(): + pytest.skip("Skipping test for community edition - no multi database in CE") database_name = "pastries" await adb.cypher_query(f"CREATE DATABASE {database_name} IF NOT EXISTS") await adb.close_connection() diff --git a/test/async_/test_dbms_awareness.py b/test/async_/test_dbms_awareness.py index be9e376b..239991c7 100644 --- a/test/async_/test_dbms_awareness.py +++ b/test/async_/test_dbms_awareness.py @@ -6,6 +6,7 @@ from neomodel.util import version_tag_to_integer +# TODO : This calling database_version should be async @mark.skipif( adb.database_version != "5.7.0", reason="Testing a specific database version" ) diff --git a/test/async_/test_driver_options.py b/test/async_/test_driver_options.py index 6add5f0a..9408c12b 100644 --- a/test/async_/test_driver_options.py +++ b/test/async_/test_driver_options.py @@ -9,30 +9,27 @@ @mark_async_test -@pytest.mark.skipif( - not adb.edition_is_enterprise(), reason="Skipping test for community edition" -) async def test_impersonate(): + if not await adb.edition_is_enterprise(): + pytest.skip("Skipping test for community edition") with adb.impersonate(user="troygreene"): results, _ = await adb.cypher_query("RETURN 'Doo Wacko !'") assert results[0][0] == "Doo Wacko !" @mark_async_test -@pytest.mark.skipif( - not adb.edition_is_enterprise(), reason="Skipping test for community edition" -) async def test_impersonate_unauthorized(): + if not await adb.edition_is_enterprise(): + pytest.skip("Skipping test for community edition") with adb.impersonate(user="unknownuser"): with raises(ClientError): _ = await adb.cypher_query("RETURN 'Gabagool'") @mark_async_test -@pytest.mark.skipif( - not adb.edition_is_enterprise(), reason="Skipping test for community edition" -) async def test_impersonate_multiple_transactions(): + if not await adb.edition_is_enterprise(): + pytest.skip("Skipping test for community edition") with adb.impersonate(user="troygreene"): async with adb.transaction: results, _ = await adb.cypher_query("RETURN 'Doo Wacko !'") @@ -47,10 +44,9 @@ async def test_impersonate_multiple_transactions(): @mark_async_test -@pytest.mark.skipif( - adb.edition_is_enterprise(), reason="Skipping test for enterprise edition" -) async def test_impersonate_community(): + if await adb.edition_is_enterprise(): + pytest.skip("Skipping test for enterprise edition") with raises(FeatureNotSupported): with adb.impersonate(user="troygreene"): _ = await adb.cypher_query("RETURN 'Gabagoogoo'") diff --git a/test/async_/test_indexing.py b/test/async_/test_indexing.py index 2933f045..177ec0a2 100644 --- a/test/async_/test_indexing.py +++ b/test/async_/test_indexing.py @@ -32,10 +32,9 @@ async def test_unique_error(): @mark_async_test -@pytest.mark.skipif( - not adb.edition_is_enterprise(), reason="Skipping test for community edition" -) async def test_existence_constraint_error(): + if not await adb.edition_is_enterprise(): + pytest.skip("Skipping test for community edition") await adb.cypher_query( "CREATE CONSTRAINT test_existence_constraint FOR (n:Human) REQUIRE n.age IS NOT NULL" ) diff --git a/test/async_/test_label_install.py b/test/async_/test_label_install.py index 3ccac1b7..6078c752 100644 --- a/test/async_/test_label_install.py +++ b/test/async_/test_label_install.py @@ -118,10 +118,10 @@ async def test_install_labels_db_property(capsys): await _drop_constraints_for_label_and_property("SomeNotUniqueNode", "id") -@pytest.mark.skipif( - adb.version_is_higher_than("5.7"), reason="Not supported before 5.7" -) def test_relationship_unique_index_not_supported(): + if adb.version_is_higher_than("5.7"): + pytest.skip("Not supported before 5.7") + class UniqueIndexRelationship(AsyncStructuredRel): name = StringProperty(unique_index=True) @@ -141,8 +141,10 @@ class NodeWithUniqueIndexRelationship(AsyncStructuredNode): @mark_async_test -@pytest.mark.skipif(not adb.version_is_higher_than("5.7"), reason="Supported from 5.7") async def test_relationship_unique_index(): + if not adb.version_is_higher_than("5.7"): + pytest.skip("Not supported before 5.7") + class UniqueIndexRelationshipBis(AsyncStructuredRel): name = StringProperty(unique_index=True) diff --git a/test/sync_/test_connection.py b/test/sync_/test_connection.py index 666321a1..cc77df18 100644 --- a/test/sync_/test_connection.py +++ b/test/sync_/test_connection.py @@ -77,11 +77,9 @@ def test_config_driver_works(): @mark_sync_test -@pytest.mark.skipif( - db.database_edition != "enterprise", - reason="Skipping test for community edition - no multi database in CE", -) def test_connect_to_non_default_database(): + if not db.edition_is_enterprise(): + pytest.skip("Skipping test for community edition - no multi database in CE") database_name = "pastries" db.cypher_query(f"CREATE DATABASE {database_name} IF NOT EXISTS") db.close_connection() diff --git a/test/sync_/test_dbms_awareness.py b/test/sync_/test_dbms_awareness.py index 1d815eff..6694ddfe 100644 --- a/test/sync_/test_dbms_awareness.py +++ b/test/sync_/test_dbms_awareness.py @@ -6,6 +6,7 @@ from neomodel.util import version_tag_to_integer +# TODO : This calling database_version should be async @mark.skipif( db.database_version != "5.7.0", reason="Testing a specific database version" ) diff --git a/test/sync_/test_driver_options.py b/test/sync_/test_driver_options.py index 5e5e12b9..c4deb59c 100644 --- a/test/sync_/test_driver_options.py +++ b/test/sync_/test_driver_options.py @@ -9,30 +9,27 @@ @mark_sync_test -@pytest.mark.skipif( - not db.edition_is_enterprise(), reason="Skipping test for community edition" -) def test_impersonate(): + if not db.edition_is_enterprise(): + pytest.skip("Skipping test for community edition") with db.impersonate(user="troygreene"): results, _ = db.cypher_query("RETURN 'Doo Wacko !'") assert results[0][0] == "Doo Wacko !" @mark_sync_test -@pytest.mark.skipif( - not db.edition_is_enterprise(), reason="Skipping test for community edition" -) def test_impersonate_unauthorized(): + if not db.edition_is_enterprise(): + pytest.skip("Skipping test for community edition") with db.impersonate(user="unknownuser"): with raises(ClientError): _ = db.cypher_query("RETURN 'Gabagool'") @mark_sync_test -@pytest.mark.skipif( - not db.edition_is_enterprise(), reason="Skipping test for community edition" -) def test_impersonate_multiple_transactions(): + if not db.edition_is_enterprise(): + pytest.skip("Skipping test for community edition") with db.impersonate(user="troygreene"): with db.transaction: results, _ = db.cypher_query("RETURN 'Doo Wacko !'") @@ -47,10 +44,9 @@ def test_impersonate_multiple_transactions(): @mark_sync_test -@pytest.mark.skipif( - db.edition_is_enterprise(), reason="Skipping test for enterprise edition" -) def test_impersonate_community(): + if db.edition_is_enterprise(): + pytest.skip("Skipping test for enterprise edition") with raises(FeatureNotSupported): with db.impersonate(user="troygreene"): _ = db.cypher_query("RETURN 'Gabagoogoo'") diff --git a/test/sync_/test_indexing.py b/test/sync_/test_indexing.py index db1c3256..c50a53f6 100644 --- a/test/sync_/test_indexing.py +++ b/test/sync_/test_indexing.py @@ -27,10 +27,9 @@ def test_unique_error(): @mark_sync_test -@pytest.mark.skipif( - not db.edition_is_enterprise(), reason="Skipping test for community edition" -) def test_existence_constraint_error(): + if not db.edition_is_enterprise(): + pytest.skip("Skipping test for community edition") db.cypher_query( "CREATE CONSTRAINT test_existence_constraint FOR (n:Human) REQUIRE n.age IS NOT NULL" ) diff --git a/test/sync_/test_label_install.py b/test/sync_/test_label_install.py index 60309dfc..74235b74 100644 --- a/test/sync_/test_label_install.py +++ b/test/sync_/test_label_install.py @@ -118,8 +118,10 @@ def test_install_labels_db_property(capsys): _drop_constraints_for_label_and_property("SomeNotUniqueNode", "id") -@pytest.mark.skipif(db.version_is_higher_than("5.7"), reason="Not supported before 5.7") def test_relationship_unique_index_not_supported(): + if db.version_is_higher_than("5.7"): + pytest.skip("Not supported before 5.7") + class UniqueIndexRelationship(StructuredRel): name = StringProperty(unique_index=True) @@ -139,8 +141,10 @@ class NodeWithUniqueIndexRelationship(StructuredNode): @mark_sync_test -@pytest.mark.skipif(not db.version_is_higher_than("5.7"), reason="Supported from 5.7") def test_relationship_unique_index(): + if not db.version_is_higher_than("5.7"): + pytest.skip("Not supported before 5.7") + class UniqueIndexRelationshipBis(StructuredRel): name = StringProperty(unique_index=True) From 5b8825c1dad81ec4e7842603e124c3815ff8b001 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Sat, 16 Mar 2024 09:58:06 +0100 Subject: [PATCH 48/73] Propagate async database_version ; fix __await__ dunders --- .pre-commit-config.yaml | 2 +- bin/make-unasync | 6 ++ neomodel/async_/cardinality.py | 6 +- neomodel/async_/core.py | 78 +++++++++++++++---------- neomodel/async_/match.py | 65 +++++++++++++-------- neomodel/async_/relationship.py | 27 +++------ neomodel/async_/relationship_manager.py | 41 +++++++------ neomodel/sync_/core.py | 44 +++++++++----- neomodel/sync_/match.py | 31 +++++++--- neomodel/sync_/relationship.py | 21 ++----- neomodel/sync_/relationship_manager.py | 3 +- run-unasync.sh | 3 + test/async_/conftest.py | 3 +- test/async_/test_cypher.py | 8 ++- test/async_/test_dbms_awareness.py | 15 ++--- test/async_/test_driver_options.py | 8 +-- test/async_/test_match_api.py | 61 +++++++++---------- test/async_/test_migration_neo4j_5.py | 3 +- test/async_/test_models.py | 3 +- test/async_/test_relationships.py | 2 +- test/sync_/conftest.py | 3 +- test/sync_/test_dbms_awareness.py | 15 ++--- test/sync_/test_match_api.py | 41 ++++++------- test/sync_/test_migration_neo4j_5.py | 3 +- test/sync_/test_models.py | 3 +- 25 files changed, 284 insertions(+), 211 deletions(-) create mode 100644 run-unasync.sh diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dd58a3c9..5cabda9d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ repos: hooks: - id: unasync name: unasync - entry: bin/make-unasync + entry: bash run-unasync.sh language: system files: "^(neomodel/async_|test/async_)/.*" - repo: https://github.com/psf/black diff --git a/bin/make-unasync b/bin/make-unasync index 5acdf5e7..a9d1d6b7 100755 --- a/bin/make-unasync +++ b/bin/make-unasync @@ -209,11 +209,17 @@ def apply_unasync(files): "async_": "sync_", "check_bool": "__bool__", "check_nonzero": "__nonzero__", + "check_contains": "__contains__", + "get_item": "__getitem__", + "get_len": "__len__", } additional_test_replacements = { "async_": "sync_", "check_bool": "__bool__", "check_nonzero": "__nonzero__", + "check_contains": "__contains__", + "get_item": "__getitem__", + "get_len": "__len__", "adb": "db", "mark_async_test": "mark_sync_test", "mark_async_session_auto_fixture": "mark_sync_session_auto_fixture", diff --git a/neomodel/async_/cardinality.py b/neomodel/async_/cardinality.py index ff1eb779..17101cec 100644 --- a/neomodel/async_/cardinality.py +++ b/neomodel/async_/cardinality.py @@ -37,7 +37,7 @@ async def connect(self, node, properties=None): :type: dict :return: True / rel instance """ - if await super().__len__(): + if await super().get_len(): raise AttemptedCardinalityViolation( f"Node already has {self} can't connect more" ) @@ -77,7 +77,7 @@ async def disconnect(self, node): :param node: :return: """ - if await super().__len__() < 2: + if await super().get_len() < 2: raise AttemptedCardinalityViolation("One or more expected") return await super().disconnect(node) @@ -130,6 +130,6 @@ async def connect(self, node, properties=None): """ if not hasattr(self.source, "element_id") or self.source.element_id is None: raise ValueError("Node has not been saved cannot connect!") - if await super().__len__(): + if await super().get_len(): raise AttemptedCardinalityViolation("Node already has one relationship") return await super().connect(node, properties) diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py index 7cf634c5..1e555a22 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/core.py @@ -52,6 +52,16 @@ # make sure the connection url has been set prior to executing the wrapped function def ensure_connection(func): + """Decorator that ensures a connection is established before executing the decorated function. + + Args: + func (callable): The function to be decorated. + + Returns: + callable: The decorated function. + + """ + async def wrapper(self, *args, **kwargs): # Sort out where to find url if hasattr(self, "db"): @@ -60,10 +70,10 @@ async def wrapper(self, *args, **kwargs): _db = self if not _db.driver: - if hasattr(config, "DRIVER") and config.DRIVER: - await _db.set_connection(driver=config.DRIVER) - elif config.DATABASE_URL: + if hasattr(config, "DATABASE_URL") and config.DATABASE_URL: await _db.set_connection(url=config.DATABASE_URL) + elif hasattr(config, "DRIVER") and config.DRIVER: + await _db.set_connection(driver=config.DRIVER) return await func(self, *args, **kwargs) @@ -194,17 +204,18 @@ async def close_connection(self): await self.driver.close() self.driver = None + # TODO : Make this async and turn on muck-spreader @property - def database_version(self): + async def database_version(self): if self._database_version is None: - self._update_database_version() + await self._update_database_version() return self._database_version @property - def database_edition(self): + async def database_edition(self): if self._database_edition is None: - self._update_database_version() + await self._update_database_version() return self._database_edition @@ -223,7 +234,7 @@ def write_transaction(self): def read_transaction(self): return AsyncTransactionProxy(self, access_mode="READ") - def impersonate(self, user: str) -> "ImpersonationHandler": + async def impersonate(self, user: str) -> "ImpersonationHandler": """All queries executed within this context manager will be executed as impersonated user Args: @@ -232,7 +243,8 @@ def impersonate(self, user: str) -> "ImpersonationHandler": Returns: ImpersonationHandler: Context manager to set/unset the user to impersonate """ - if self.database_edition != "enterprise": + db_edition = await self.database_edition + if db_edition != "enterprise": raise FeatureNotSupported( "Impersonation is only available in Neo4j Enterprise edition" ) @@ -505,8 +517,9 @@ async def _run_cypher_query( return results, meta - def get_id_method(self) -> str: - if self.database_version.startswith("4"): + async def get_id_method(self) -> str: + db_version = await self.database_version + if db_version.startswith("4"): return "id" else: return "elementId" @@ -551,9 +564,8 @@ async def version_is_higher_than(self, version_tag: str) -> bool: Returns: bool: True if the database version is higher or equal to the given version """ - return version_tag_to_integer(self.database_version) >= version_tag_to_integer( - version_tag - ) + db_version = await self.database_version + return version_tag_to_integer(db_version) >= version_tag_to_integer(version_tag) @ensure_connection async def edition_is_enterprise(self) -> bool: @@ -562,7 +574,8 @@ async def edition_is_enterprise(self) -> bool: Returns: bool: True if the database edition is enterprise """ - return self.database_edition == "enterprise" + edition = await self.database_edition + return edition == "enterprise" async def change_neo4j_password(self, user, new_password): await self.cypher_query(f"ALTER USER {user} SET PASSWORD '{new_password}'") @@ -767,7 +780,7 @@ async def _create_relationship_constraint( raise else: raise FeatureNotSupported( - f"Unique indexes on relationships are not supported in Neo4j version {self.database_version}. Please upgrade to Neo4j 5.7 or higher." + f"Unique indexes on relationships are not supported in Neo4j version {await self.database_version}. Please upgrade to Neo4j 5.7 or higher." ) async def _install_node(self, cls, name, property, quiet, stdout): @@ -1125,14 +1138,11 @@ def nodes(cls): return AsyncNodeSet(cls) + # TODO : Update places where element_id is expected to be an int (where id(n)=$element_id) @property def element_id(self): if hasattr(self, "element_id_property"): - return ( - int(self.element_id_property) - if adb.database_version.startswith("4") - else self.element_id_property - ) + return self.element_id_property return None # Version 4.4 support - id is deprecated in version 5.x @@ -1148,7 +1158,7 @@ def id(self): # methods @classmethod - def _build_merge_query( + async def _build_merge_query( cls, merge_params, update_existing=False, lazy=False, relationship=None ): """ @@ -1187,7 +1197,7 @@ def _build_merge_query( from neomodel.async_.match import _rel_helper query_params["source_id"] = relationship.source.element_id - query = f"MATCH (source:{relationship.source.__label__}) WHERE {adb.get_id_method()}(source) = $source_id\n " + query = f"MATCH (source:{relationship.source.__label__}) WHERE {await adb.get_id_method()}(source) = $source_id\n " query += "WITH source\n UNWIND $merge_params as params \n " query += "MERGE " query += _rel_helper( @@ -1205,7 +1215,7 @@ def _build_merge_query( # close query if lazy: - query += f"RETURN {adb.get_id_method()}(n)" + query += f"RETURN {await adb.get_id_method()}(n)" else: query += "RETURN n" @@ -1236,7 +1246,7 @@ async def create(cls, *props, **kwargs): # close query if lazy: - query += f" RETURN {adb.get_id_method()}(n)" + query += f" RETURN {await adb.get_id_method()}(n)" else: query += " RETURN n" @@ -1285,7 +1295,7 @@ async def create_or_update(cls, *props, **kwargs): ), } ) - query, params = cls._build_merge_query( + query, params = await cls._build_merge_query( create_or_update_params, update_existing=True, relationship=relationship, @@ -1316,7 +1326,11 @@ async def cypher(self, query, params=None): """ self._pre_action_check("cypher") params = params or {} - params.update({"self": self.element_id}) + db_version = await adb.database_version + element_id = ( + int(self.element_id) if db_version.startswith("4") else self.element_id + ) + params.update({"self": element_id}) return await adb.cypher_query(query, params) @hooks @@ -1328,7 +1342,7 @@ async def delete(self): """ self._pre_action_check("delete") await self.cypher( - f"MATCH (self) WHERE {adb.get_id_method()}(self)=$self DETACH DELETE self" + f"MATCH (self) WHERE {await adb.get_id_method()}(self)=$self DETACH DELETE self" ) delattr(self, "element_id_property") self.deleted = True @@ -1357,7 +1371,7 @@ async def get_or_create(cls, *props, **kwargs): get_or_create_params = [ {"create": cls.deflate(p, skip_empty=True)} for p in props ] - query, params = cls._build_merge_query( + query, params = await cls._build_merge_query( get_or_create_params, relationship=relationship, lazy=lazy ) @@ -1439,7 +1453,7 @@ async def labels(self): """ self._pre_action_check("labels") result = await self.cypher( - f"MATCH (n) WHERE {adb.get_id_method()}(n)=$self " "RETURN labels(n)" + f"MATCH (n) WHERE {await adb.get_id_method()}(n)=$self " "RETURN labels(n)" ) return result[0][0][0] @@ -1460,7 +1474,7 @@ async def refresh(self): self._pre_action_check("refresh") if hasattr(self, "element_id"): results = await self.cypher( - f"MATCH (n) WHERE {adb.get_id_method()}(n)=$self RETURN n" + f"MATCH (n) WHERE {await adb.get_id_method()}(n)=$self RETURN n" ) request = results[0] if not request or not request[0]: @@ -1483,7 +1497,7 @@ async def save(self): if hasattr(self, "element_id_property"): # update params = self.deflate(self.__properties__, self) - query = f"MATCH (n) WHERE {adb.get_id_method()}(n)=$self\n" + query = f"MATCH (n) WHERE {await adb.get_id_method()}(n)=$self\n" if params: query += "SET " diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 5ba1ce89..76f058a1 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -362,12 +362,12 @@ def __init__(self, node_set): self._ident_count = 0 self._node_counters = defaultdict(int) - def build_ast(self): + async def build_ast(self): if hasattr(self.node_set, "relations_to_fetch"): for relation in self.node_set.relations_to_fetch: self.build_traversal_from_path(relation, self.node_set.source) - self.build_source(self.node_set) + await self.build_source(self.node_set) if hasattr(self.node_set, "skip"): self._ast.skip = self.node_set.skip @@ -376,16 +376,16 @@ def build_ast(self): return self - def build_source(self, source): + async def build_source(self, source): if isinstance(source, AsyncTraversal): - return self.build_traversal(source) + return await self.build_traversal(source) if isinstance(source, AsyncNodeSet): if inspect.isclass(source.source) and issubclass( source.source, AsyncStructuredNode ): ident = self.build_label(source.source.__label__.lower(), source.source) else: - ident = self.build_source(source.source) + ident = await self.build_source(source.source) self.build_additional_match(ident, source) @@ -402,7 +402,7 @@ def build_source(self, source): return ident if isinstance(source, AsyncStructuredNode): - return self.build_node(source) + return await self.build_node(source) raise ValueError("Unknown source type " + repr(source)) def create_ident(self): @@ -416,7 +416,7 @@ def build_order_by(self, ident, source): else: self._ast.order_by = [f"{ident}.{p}" for p in source.order_by_elements] - def build_traversal(self, traversal): + async def build_traversal(self, traversal): """ traverse a relationship from a node to a set of nodes """ @@ -425,7 +425,7 @@ def build_traversal(self, traversal): # build source rel_ident = self.create_ident() - lhs_ident = self.build_source(traversal.source) + lhs_ident = await self.build_source(traversal.source) traversal_ident = f"{traversal.name}_{rel_ident}" rhs_ident = traversal_ident + rhs_label self._ast.return_clause = traversal_ident @@ -496,12 +496,12 @@ def build_traversal_from_path(self, relation: dict, source_class) -> str: self._ast.match.append(stmt) return rhs_name - def build_node(self, node): + async def build_node(self, node): ident = node.__class__.__name__.lower() place_holder = self._register_place_holder(ident) # Hack to emulate START to lookup a node by id - _node_lookup = f"MATCH ({ident}) WHERE {adb.get_id_method()}({ident})=${place_holder} WITH {ident}" + _node_lookup = f"MATCH ({ident}) WHERE {await adb.get_id_method()}({ident})=${place_holder} WITH {ident}" self._ast.lookup = _node_lookup self._query_params[place_holder] = node.element_id @@ -688,7 +688,9 @@ async def _contains(self, node_element_id): self._ast.return_clause = self._ast.additional_return[0] ident = self._ast.return_clause place_holder = self._register_place_holder(ident + "_contains") - self._ast.where.append(f"{adb.get_id_method()}({ident}) = ${place_holder}") + self._ast.where.append( + f"{await adb.get_id_method()}({ident}) = ${place_holder}" + ) self._query_params[place_holder] = node_element_id return await self._count() >= 1 @@ -697,11 +699,11 @@ async def _execute(self, lazy=False): # inject id() into return or return_set if self._ast.return_clause: self._ast.return_clause = ( - f"{adb.get_id_method()}({self._ast.return_clause})" + f"{await adb.get_id_method()}({self._ast.return_clause})" ) else: self._ast.additional_return = [ - f"{adb.get_id_method()}({item})" + f"{await adb.get_id_method()}({item})" for item in self._ast.additional_return ] query = self.build_query() @@ -733,14 +735,24 @@ async def all(self, lazy=False): :return: list of nodes :rtype: list """ - return await self.query_cls(self).build_ast()._execute(lazy) + ast = await self.query_cls(self).build_ast() + return await ast._execute(lazy) async def __aiter__(self): - async for i in await self.query_cls(self).build_ast()._execute(): + ast = await self.query_cls(self).build_ast() + async for i in await ast._execute(): yield i - async def __len__(self): - return await self.query_cls(self).build_ast()._count() + # TODO : Add tests for sync to check that len(Label.nodes) is still working + # Because async tests will now check for Coffee.nodes.get_len() + # Also add documenation for get_len, check_bool, etc... + # Documentation should explain that in sync, assert node1.extension is more efficient than + # assert node1.extension.check_bool() because it counts using native Cypher + # Same for len(Extension.nodes) vs Extension.nodes.__len__ + # With note that async does not have a choice + async def get_len(self): + ast = await self.query_cls(self).build_ast() + return await ast._count() async def check_bool(self): """ @@ -748,7 +760,8 @@ async def check_bool(self): :return: True if the set contains any nodes, False otherwise :rtype: bool """ - _count = await self.query_cls(self).build_ast()._count() + ast = await self.query_cls(self).build_ast() + _count = ast._count() return _count > 0 async def check_nonzero(self): @@ -759,15 +772,16 @@ async def check_nonzero(self): """ return await self.check_bool() - def __contains__(self, obj): + async def check_contains(self, obj): if isinstance(obj, AsyncStructuredNode): if hasattr(obj, "element_id") and obj.element_id is not None: - return self.query_cls(self).build_ast()._contains(obj.element_id) + ast = await self.query_cls(self).build_ast() + return await ast._contains(obj.element_id) raise ValueError("Unsaved node: " + repr(obj)) raise ValueError("Expecting StructuredNode instance") - async def __getitem__(self, key): + async def get_item(self, key): if isinstance(key, slice): if key.stop and key.start: self.limit = key.stop - key.start @@ -783,7 +797,8 @@ async def __getitem__(self, key): self.skip = key self.limit = 1 - _items = await self.query_cls(self).build_ast()._execute() + ast = await self.query_cls(self).build_ast() + _items = ast._execute() return _items[0] return None @@ -831,7 +846,8 @@ async def _get(self, limit=None, lazy=False, **kwargs): self.filter(**kwargs) if limit: self.limit = limit - return await self.query_cls(self).build_ast()._execute(lazy) + ast = await self.query_cls(self).build_ast() + return await ast._execute(lazy) async def get(self, lazy=False, **kwargs): """ @@ -1003,6 +1019,9 @@ class AsyncTraversal(AsyncBaseSet): :type defintion: :class:`dict` """ + def __await__(self): + return self.all().__await__() + def __init__(self, source, name, definition): """ Create a traversal diff --git a/neomodel/async_/relationship.py b/neomodel/async_/relationship.py index 65c51627..cb976bf1 100644 --- a/neomodel/async_/relationship.py +++ b/neomodel/async_/relationship.py @@ -49,27 +49,18 @@ def __init__(self, *args, **kwargs): @property def element_id(self): - return ( - int(self.element_id_property) - if adb.database_version.startswith("4") - else self.element_id_property - ) + if hasattr(self, "element_id_property"): + return self.element_id_property @property def _start_node_element_id(self): - return ( - int(self._start_node_element_id_property) - if adb.database_version.startswith("4") - else self._start_node_element_id_property - ) + if hasattr(self, "_start_node_element_id_property"): + return self._start_node_element_id_property @property def _end_node_element_id(self): - return ( - int(self._end_node_element_id_property) - if adb.database_version.startswith("4") - else self._end_node_element_id_property - ) + if hasattr(self, "_end_node_element_id_property"): + return self._end_node_element_id_property # Version 4.4 support - id is deprecated in version 5.x @property @@ -109,7 +100,7 @@ async def save(self): :return: self """ props = self.deflate(self.__properties__) - query = f"MATCH ()-[r]->() WHERE {adb.get_id_method()}(r)=$self " + query = f"MATCH ()-[r]->() WHERE {await adb.get_id_method()}(r)=$self " query += "".join([f" SET r.{key} = ${key}" for key in props]) props["self"] = self.element_id @@ -126,7 +117,7 @@ async def start_node(self): results = await adb.cypher_query( f""" MATCH (aNode) - WHERE {adb.get_id_method()}(aNode)=$start_node_element_id + WHERE {await adb.get_id_method()}(aNode)=$start_node_element_id RETURN aNode """, {"start_node_element_id": self._start_node_element_id}, @@ -143,7 +134,7 @@ async def end_node(self): results = await adb.cypher_query( f""" MATCH (aNode) - WHERE {adb.get_id_method()}(aNode)=$end_node_element_id + WHERE {await adb.get_id_method()}(aNode)=$end_node_element_id RETURN aNode """, {"end_node_element_id": self._end_node_element_id}, diff --git a/neomodel/async_/relationship_manager.py b/neomodel/async_/relationship_manager.py index 145e4247..e188c15d 100644 --- a/neomodel/async_/relationship_manager.py +++ b/neomodel/async_/relationship_manager.py @@ -69,7 +69,7 @@ def __str__(self): return f"{self.description} in {direction} direction of type {self.definition['relation_type']} on node ({self.source.element_id}) of class '{self.source_class.__name__}'" - async def __await__(self): + def __await__(self): return self.all().__await__() def _check_node(self, obj): @@ -126,7 +126,7 @@ async def connect(self, node, properties=None): **self.definition, ) q = ( - f"MATCH (them), (us) WHERE {adb.get_id_method()}(them)=$them and {adb.get_id_method()}(us)=$self " + f"MATCH (them), (us) WHERE {await adb.get_id_method()}(them)=$them and {await adb.get_id_method()}(us)=$self " "MERGE" + new_rel ) @@ -171,7 +171,7 @@ async def relationship(self, node): q = ( "MATCH " + my_rel - + f" WHERE {adb.get_id_method()}(them)=$them and {adb.get_id_method()}(us)=$self RETURN r LIMIT 1" + + f" WHERE {await adb.get_id_method()}(them)=$them and {await adb.get_id_method()}(us)=$self RETURN r LIMIT 1" ) results = await self.source.cypher(q, {"them": node.element_id}) rels = results[0] @@ -193,7 +193,7 @@ async def all_relationships(self, node): self._check_node(node) my_rel = _rel_helper(lhs="us", rhs="them", ident="r", **self.definition) - q = f"MATCH {my_rel} WHERE {adb.get_id_method()}(them)=$them and {adb.get_id_method()}(us)=$self RETURN r " + q = f"MATCH {my_rel} WHERE {await adb.get_id_method()}(them)=$them and {await adb.get_id_method()}(us)=$self RETURN r " results = await self.source.cypher(q, {"them": node.element_id}) rels = results[0] if not rels: @@ -234,7 +234,7 @@ async def reconnect(self, old_node, new_node): # get list of properties on the existing rel result, _ = await self.source.cypher( f""" - MATCH (us), (old) WHERE {adb.get_id_method()}(us)=$self and {adb.get_id_method()}(old)=$old + MATCH (us), (old) WHERE {await adb.get_id_method()}(us)=$self and {await adb.get_id_method()}(old)=$old MATCH {old_rel} RETURN r """, {"old": old_node.element_id}, @@ -249,7 +249,7 @@ async def reconnect(self, old_node, new_node): new_rel = _rel_merge_helper(lhs="us", rhs="new", ident="r2", **self.definition) q = ( "MATCH (us), (old), (new) " - f"WHERE {adb.get_id_method()}(us)=$self and {adb.get_id_method()}(old)=$old and {adb.get_id_method()}(new)=$new " + f"WHERE {await adb.get_id_method()}(us)=$self and {await adb.get_id_method()}(old)=$old and {await adb.get_id_method()}(new)=$new " "MATCH " + old_rel ) q += " MERGE" + new_rel @@ -272,7 +272,7 @@ async def disconnect(self, node): """ rel = _rel_helper(lhs="a", rhs="b", ident="r", **self.definition) q = f""" - MATCH (a), (b) WHERE {adb.get_id_method()}(a)=$self and {adb.get_id_method()}(b)=$them + MATCH (a), (b) WHERE {await adb.get_id_method()}(a)=$self and {await adb.get_id_method()}(b)=$them MATCH {rel} DELETE r """ await self.source.cypher(q, {"them": node.element_id}) @@ -286,7 +286,11 @@ async def disconnect_all(self): """ rhs = "b:" + self.definition["node_class"].__label__ rel = _rel_helper(lhs="a", rhs=rhs, ident="r", **self.definition) - q = f"MATCH (a) WHERE {adb.get_id_method()}(a)=$self MATCH " + rel + " DELETE r" + q = ( + f"MATCH (a) WHERE {await adb.get_id_method()}(a)=$self MATCH " + + rel + + " DELETE r" + ) await self.source.cypher(q) @check_source @@ -356,7 +360,8 @@ async def single(self): :return: StructuredNode """ try: - return await self[0] + rels = await self + return rels[0] except IndexError: pass @@ -380,20 +385,20 @@ async def all(self): async def __aiter__(self): return self._new_traversal().__aiter__() - def __len__(self): - return self._new_traversal().__len__() + async def get_len(self): + return await self._new_traversal().get_len() - def __bool__(self): - return self._new_traversal().check_bool() + async def check_bool(self): + return await self._new_traversal().check_bool() - def __nonzero__(self): + async def check_nonzero(self): return self._new_traversal().check_nonzero() - def __contains__(self, obj): - return self._new_traversal().__contains__(obj) + async def check_contains(self, obj): + return self._new_traversal().check_contains(obj) - def __getitem__(self, key): - return self._new_traversal().__getitem__(key) + async def get_item(self, key): + return self._new_traversal().get_item(key) class AsyncRelationshipDefinition: diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index d97af6b4..b7fdc2fb 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -52,6 +52,16 @@ # make sure the connection url has been set prior to executing the wrapped function def ensure_connection(func): + """Decorator that ensures a connection is established before executing the decorated function. + + Args: + func (callable): The function to be decorated. + + Returns: + callable: The decorated function. + + """ + def wrapper(self, *args, **kwargs): # Sort out where to find url if hasattr(self, "db"): @@ -60,10 +70,10 @@ def wrapper(self, *args, **kwargs): _db = self if not _db.driver: - if hasattr(config, "DRIVER") and config.DRIVER: - _db.set_connection(driver=config.DRIVER) - elif config.DATABASE_URL: + if hasattr(config, "DATABASE_URL") and config.DATABASE_URL: _db.set_connection(url=config.DATABASE_URL) + elif hasattr(config, "DRIVER") and config.DRIVER: + _db.set_connection(driver=config.DRIVER) return func(self, *args, **kwargs) @@ -194,6 +204,7 @@ def close_connection(self): self.driver.close() self.driver = None + # TODO : Make this async and turn on muck-spreader @property def database_version(self): if self._database_version is None: @@ -232,7 +243,8 @@ def impersonate(self, user: str) -> "ImpersonationHandler": Returns: ImpersonationHandler: Context manager to set/unset the user to impersonate """ - if self.database_edition != "enterprise": + db_edition = self.database_edition + if db_edition != "enterprise": raise FeatureNotSupported( "Impersonation is only available in Neo4j Enterprise edition" ) @@ -504,7 +516,8 @@ def _run_cypher_query( return results, meta def get_id_method(self) -> str: - if self.database_version.startswith("4"): + db_version = self.database_version + if db_version.startswith("4"): return "id" else: return "elementId" @@ -549,9 +562,8 @@ def version_is_higher_than(self, version_tag: str) -> bool: Returns: bool: True if the database version is higher or equal to the given version """ - return version_tag_to_integer(self.database_version) >= version_tag_to_integer( - version_tag - ) + db_version = self.database_version + return version_tag_to_integer(db_version) >= version_tag_to_integer(version_tag) @ensure_connection def edition_is_enterprise(self) -> bool: @@ -560,7 +572,8 @@ def edition_is_enterprise(self) -> bool: Returns: bool: True if the database edition is enterprise """ - return self.database_edition == "enterprise" + edition = self.database_edition + return edition == "enterprise" def change_neo4j_password(self, user, new_password): self.cypher_query(f"ALTER USER {user} SET PASSWORD '{new_password}'") @@ -1121,14 +1134,11 @@ def nodes(cls): return NodeSet(cls) + # TODO : Update places where element_id is expected to be an int (where id(n)=$element_id) @property def element_id(self): if hasattr(self, "element_id_property"): - return ( - int(self.element_id_property) - if db.database_version.startswith("4") - else self.element_id_property - ) + return self.element_id_property return None # Version 4.4 support - id is deprecated in version 5.x @@ -1312,7 +1322,11 @@ def cypher(self, query, params=None): """ self._pre_action_check("cypher") params = params or {} - params.update({"self": self.element_id}) + db_version = db.database_version + element_id = ( + int(self.element_id) if db_version.startswith("4") else self.element_id + ) + params.update({"self": element_id}) return db.cypher_query(query, params) @hooks diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 7c0abb16..3ecd3e43 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -731,14 +731,24 @@ def all(self, lazy=False): :return: list of nodes :rtype: list """ - return self.query_cls(self).build_ast()._execute(lazy) + ast = self.query_cls(self).build_ast() + return ast._execute(lazy) def __iter__(self): - for i in self.query_cls(self).build_ast()._execute(): + ast = self.query_cls(self).build_ast() + for i in ast._execute(): yield i + # TODO : Add tests for sync to check that len(Label.nodes) is still working + # Because async tests will now check for Coffee.nodes.get_len() + # Also add documenation for get_len, check_bool, etc... + # Documentation should explain that in sync, assert node1.extension is more efficient than + # assert node1.extension.check_bool() because it counts using native Cypher + # Same for len(Extension.nodes) vs Extension.nodes.__len__ + # With note that async does not have a choice def __len__(self): - return self.query_cls(self).build_ast()._count() + ast = self.query_cls(self).build_ast() + return ast._count() def __bool__(self): """ @@ -746,7 +756,8 @@ def __bool__(self): :return: True if the set contains any nodes, False otherwise :rtype: bool """ - _count = self.query_cls(self).build_ast()._count() + ast = self.query_cls(self).build_ast() + _count = ast._count() return _count > 0 def __nonzero__(self): @@ -760,7 +771,8 @@ def __nonzero__(self): def __contains__(self, obj): if isinstance(obj, StructuredNode): if hasattr(obj, "element_id") and obj.element_id is not None: - return self.query_cls(self).build_ast()._contains(obj.element_id) + ast = self.query_cls(self).build_ast() + return ast._contains(obj.element_id) raise ValueError("Unsaved node: " + repr(obj)) raise ValueError("Expecting StructuredNode instance") @@ -781,7 +793,8 @@ def __getitem__(self, key): self.skip = key self.limit = 1 - _items = self.query_cls(self).build_ast()._execute() + ast = self.query_cls(self).build_ast() + _items = ast._execute() return _items[0] return None @@ -829,7 +842,8 @@ def _get(self, limit=None, lazy=False, **kwargs): self.filter(**kwargs) if limit: self.limit = limit - return self.query_cls(self).build_ast()._execute(lazy) + ast = self.query_cls(self).build_ast() + return ast._execute(lazy) def get(self, lazy=False, **kwargs): """ @@ -1001,6 +1015,9 @@ class Traversal(BaseSet): :type defintion: :class:`dict` """ + def __await__(self): + return self.all().__await__() + def __init__(self, source, name, definition): """ Create a traversal diff --git a/neomodel/sync_/relationship.py b/neomodel/sync_/relationship.py index 3c6aa523..5f0e3f8f 100644 --- a/neomodel/sync_/relationship.py +++ b/neomodel/sync_/relationship.py @@ -49,27 +49,18 @@ def __init__(self, *args, **kwargs): @property def element_id(self): - return ( - int(self.element_id_property) - if db.database_version.startswith("4") - else self.element_id_property - ) + if hasattr(self, "element_id_property"): + return self.element_id_property @property def _start_node_element_id(self): - return ( - int(self._start_node_element_id_property) - if db.database_version.startswith("4") - else self._start_node_element_id_property - ) + if hasattr(self, "_start_node_element_id_property"): + return self._start_node_element_id_property @property def _end_node_element_id(self): - return ( - int(self._end_node_element_id_property) - if db.database_version.startswith("4") - else self._end_node_element_id_property - ) + if hasattr(self, "_end_node_element_id_property"): + return self._end_node_element_id_property # Version 4.4 support - id is deprecated in version 5.x @property diff --git a/neomodel/sync_/relationship_manager.py b/neomodel/sync_/relationship_manager.py index 683efe6f..f96e7b25 100644 --- a/neomodel/sync_/relationship_manager.py +++ b/neomodel/sync_/relationship_manager.py @@ -349,7 +349,8 @@ def single(self): :return: StructuredNode """ try: - return self[0] + rels = self + return rels[0] except IndexError: pass diff --git a/run-unasync.sh b/run-unasync.sh new file mode 100644 index 00000000..590c2620 --- /dev/null +++ b/run-unasync.sh @@ -0,0 +1,3 @@ +#!/bin/bash +source venv/bin/activate +bin/make-unasync \ No newline at end of file diff --git a/test/async_/conftest.py b/test/async_/conftest.py index 1b9b76a9..302b07f3 100644 --- a/test/async_/conftest.py +++ b/test/async_/conftest.py @@ -40,7 +40,8 @@ async def setup_neo4j_session(request, event_loop): await adb.cypher_query( "CREATE OR REPLACE USER troygreene SET PASSWORD 'foobarbaz' CHANGE NOT REQUIRED" ) - if adb.database_edition == "enterprise": + db_edition = await adb.database_edition + if db_edition == "enterprise": await adb.cypher_query("GRANT ROLE publisher TO troygreene") await adb.cypher_query("GRANT IMPERSONATE (troygreene) ON DBMS TO admin") diff --git a/test/async_/test_cypher.py b/test/async_/test_cypher.py index c13e0b2a..dc37a5b1 100644 --- a/test/async_/test_cypher.py +++ b/test/async_/test_cypher.py @@ -45,14 +45,14 @@ async def test_cypher(): jim = await User2(email="jim1@test.com").save() data, meta = await jim.cypher( - f"MATCH (a) WHERE {adb.get_id_method()}(a)=$self RETURN a.email" + f"MATCH (a) WHERE {await adb.get_id_method()}(a)=$self RETURN a.email" ) assert data[0][0] == "jim1@test.com" assert "a.email" in meta data, meta = await jim.cypher( f""" - MATCH (a) WHERE {adb.get_id_method()}(a)=$self + MATCH (a) WHERE {await adb.get_id_method()}(a)=$self MATCH (a)<-[:USER2]-(b) RETURN a, b, 3 """ @@ -64,7 +64,9 @@ async def test_cypher(): async def test_cypher_syntax_error(): jim = await User2(email="jim1@test.com").save() try: - await jim.cypher(f"MATCH a WHERE {adb.get_id_method()}(a)={{self}} RETURN xx") + await jim.cypher( + f"MATCH a WHERE {await adb.get_id_method()}(a)={{self}} RETURN xx" + ) except CypherError as e: assert hasattr(e, "message") assert hasattr(e, "code") diff --git a/test/async_/test_dbms_awareness.py b/test/async_/test_dbms_awareness.py index 239991c7..1c309a62 100644 --- a/test/async_/test_dbms_awareness.py +++ b/test/async_/test_dbms_awareness.py @@ -1,17 +1,17 @@ from test._async_compat import mark_async_test -from pytest import mark +import pytest from neomodel.async_.core import adb from neomodel.util import version_tag_to_integer -# TODO : This calling database_version should be async -@mark.skipif( - adb.database_version != "5.7.0", reason="Testing a specific database version" -) +@mark_async_test async def test_version_awareness(): - assert adb.database_version == "5.7.0" + db_version = await adb.database_version + if db_version != "5.7.0": + pytest.skip("Testing a specific database version") + assert db_version == "5.7.0" assert await adb.version_is_higher_than("5.7") assert await adb.version_is_higher_than("5.6.0") assert await adb.version_is_higher_than("5") @@ -22,7 +22,8 @@ async def test_version_awareness(): @mark_async_test async def test_edition_awareness(): - if adb.database_edition == "enterprise": + db_edition = await adb.database_edition + if db_edition == "enterprise": assert await adb.edition_is_enterprise() else: assert not await adb.edition_is_enterprise() diff --git a/test/async_/test_driver_options.py b/test/async_/test_driver_options.py index 9408c12b..cd14c78f 100644 --- a/test/async_/test_driver_options.py +++ b/test/async_/test_driver_options.py @@ -12,7 +12,7 @@ async def test_impersonate(): if not await adb.edition_is_enterprise(): pytest.skip("Skipping test for community edition") - with adb.impersonate(user="troygreene"): + with await adb.impersonate(user="troygreene"): results, _ = await adb.cypher_query("RETURN 'Doo Wacko !'") assert results[0][0] == "Doo Wacko !" @@ -21,7 +21,7 @@ async def test_impersonate(): async def test_impersonate_unauthorized(): if not await adb.edition_is_enterprise(): pytest.skip("Skipping test for community edition") - with adb.impersonate(user="unknownuser"): + with await adb.impersonate(user="unknownuser"): with raises(ClientError): _ = await adb.cypher_query("RETURN 'Gabagool'") @@ -30,7 +30,7 @@ async def test_impersonate_unauthorized(): async def test_impersonate_multiple_transactions(): if not await adb.edition_is_enterprise(): pytest.skip("Skipping test for community edition") - with adb.impersonate(user="troygreene"): + with await adb.impersonate(user="troygreene"): async with adb.transaction: results, _ = await adb.cypher_query("RETURN 'Doo Wacko !'") assert results[0][0] == "Doo Wacko !" @@ -48,5 +48,5 @@ async def test_impersonate_community(): if await adb.edition_is_enterprise(): pytest.skip("Skipping test for enterprise edition") with raises(FeatureNotSupported): - with adb.impersonate(user="troygreene"): + with await adb.impersonate(user="troygreene"): _ = await adb.cypher_query("RETURN 'Gabagoogoo'") diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 8d05f09f..88367dff 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -58,7 +58,7 @@ async def test_filter_exclude_via_labels(): await Coffee(name="Java", price=99).save() node_set = AsyncNodeSet(Coffee) - qb = AsyncQueryBuilder(node_set).build_ast() + qb = await AsyncQueryBuilder(node_set).build_ast() results = await qb._execute() @@ -71,7 +71,7 @@ async def test_filter_exclude_via_labels(): # with filter and exclude await Coffee(name="Kenco", price=3).save() node_set = node_set.filter(price__gt=2).exclude(price__gt=6, name="Java") - qb = AsyncQueryBuilder(node_set).build_ast() + qb = await AsyncQueryBuilder(node_set).build_ast() results = await qb._execute() assert "(coffee:Coffee)" in qb._ast.match @@ -87,7 +87,7 @@ async def test_simple_has_via_label(): await nescafe.suppliers.connect(tesco) ns = AsyncNodeSet(Coffee).has(suppliers=True) - qb = AsyncQueryBuilder(ns).build_ast() + qb = await AsyncQueryBuilder(ns).build_ast() results = await qb._execute() assert "COFFEE SUPPLIERS" in qb._ast.where[0] assert len(results) == 1 @@ -95,7 +95,7 @@ async def test_simple_has_via_label(): await Coffee(name="nespresso", price=99).save() ns = AsyncNodeSet(Coffee).has(suppliers=False) - qb = AsyncQueryBuilder(ns).build_ast() + qb = await AsyncQueryBuilder(ns).build_ast() results = await qb._execute() assert len(results) > 0 assert "NOT" in qb._ast.where[0] @@ -125,7 +125,8 @@ async def test_simple_traverse_with_filter(): AsyncNodeSet(source=nescafe).suppliers.match(since__lt=datetime.now()) ) - results = await qb.build_ast()._execute() + _ast = await qb.build_ast() + results = _ast._execute() assert qb._ast.lookup assert qb._ast.match @@ -142,7 +143,7 @@ async def test_double_traverse(): await tesco.coffees.connect(await Coffee(name="Decafe", price=2).save()) ns = AsyncNodeSet(AsyncNodeSet(source=nescafe).suppliers.match()).coffees.match() - qb = AsyncQueryBuilder(ns).build_ast() + qb = await AsyncQueryBuilder(ns).build_ast() results = await qb._execute() assert len(results) == 2 @@ -153,14 +154,16 @@ async def test_double_traverse(): @mark_async_test async def test_count(): await Coffee(name="Nescafe Gold", price=99).save() - count = await AsyncQueryBuilder(AsyncNodeSet(source=Coffee)).build_ast()._count() + ast = await AsyncQueryBuilder(AsyncNodeSet(source=Coffee)).build_ast() + count = await ast._count() assert count > 0 await Coffee(name="Kawa", price=27).save() node_set = AsyncNodeSet(source=Coffee) node_set.skip = 1 node_set.limit = 1 - count = await AsyncQueryBuilder(node_set).build_ast()._count() + ast = await AsyncQueryBuilder(node_set).build_ast() + count = await ast._count() assert count == 1 @@ -190,7 +193,7 @@ async def test_slice(): # TODO : Make slice work with async # Doing await (Coffee.nodes.all())[1:] fetches without slicing - assert len(list(Coffee.nodes.all()[1:])) == 2 + assert len(list((await Coffee.nodes)[1:])) == 2 assert len(list(Coffee.nodes.all()[:1])) == 1 assert isinstance(Coffee.nodes[1], Coffee) assert isinstance(Coffee.nodes[0], Coffee) @@ -208,20 +211,19 @@ async def test_issue_208(): await b.suppliers.connect(l, {"courier": "fedex"}) await b.suppliers.connect(a, {"courier": "dhl"}) - # TODO : Find a way to not need the .all() here - # Note : Check AsyncTraversal match - assert len(await b.suppliers.match(courier="fedex").all()) - assert len(await b.suppliers.match(courier="dhl").all()) + assert len(await b.suppliers.match(courier="fedex")) + assert len(await b.suppliers.match(courier="dhl")) @mark_async_test async def test_issue_589(): node1 = await Extension().save() node2 = await Extension().save() + # TODO : ALso test await node1.extension.check_contains(node2) + # This is the way to pick only a single relationship using async + assert node2 not in await node1.extension await node1.extension.connect(node2) - # TODO : Find a way to not need the .all() here - # Note : Check AsyncRelationshipDefinition (parent of AsyncRelationshipTo / From) - assert node2 in await node1.extension.all() + assert node2 in await node1.extension @mark_async_test @@ -229,17 +231,17 @@ async def test_contains(): expensive = await Coffee(price=1000, name="Pricey").save() asda = await Coffee(name="Asda", price=1).save() - # TODO : Find a way to not need the .all() here - assert expensive in await Coffee.nodes.filter(price__gt=999).all() - assert asda not in await Coffee.nodes.filter(price__gt=999).all() + assert expensive in await Coffee.nodes.filter(price__gt=999) + assert asda not in await Coffee.nodes.filter(price__gt=999) + # TODO : Fix this test # bad value raises with raises(ValueError): - 2 in Coffee.nodes + 2 in await Coffee.nodes # unsaved with raises(ValueError): - Coffee() in Coffee.nodes + Coffee() in await Coffee.nodes @mark_async_test @@ -256,13 +258,13 @@ async def test_order_by(): ns = await Coffee.nodes.order_by("-price") # TODO : Method fails - qb = AsyncQueryBuilder(ns).build_ast() + qb = await AsyncQueryBuilder(ns).build_ast() assert qb._ast.order_by ns = ns.order_by(None) - qb = AsyncQueryBuilder(ns).build_ast() + qb = await AsyncQueryBuilder(ns).build_ast() assert not qb._ast.order_by ns = ns.order_by("?") - qb = AsyncQueryBuilder(ns).build_ast() + qb = await AsyncQueryBuilder(ns).build_ast() assert qb._ast.with_clause == "coffee, rand() as r" assert qb._ast.order_by == "r" @@ -294,23 +296,22 @@ async def test_extra_filters(): c3 = await Coffee(name="Japans finest", price=35, id_=3).save() c4 = await Coffee(name="US extra-fine", price=None, id_=4).save() - # TODO : Remove some .all() when filter is updated - coffees_5_10 = await Coffee.nodes.filter(price__in=[10, 5]).all() + coffees_5_10 = await Coffee.nodes.filter(price__in=[10, 5]) assert len(coffees_5_10) == 2, "unexpected number of results" assert c1 in coffees_5_10, "doesnt contain 5 price coffee" assert c2 in coffees_5_10, "doesnt contain 10 price coffee" - finest_coffees = await Coffee.nodes.filter(name__iendswith=" Finest").all() + finest_coffees = await Coffee.nodes.filter(name__iendswith=" Finest") assert len(finest_coffees) == 3, "unexpected number of results" assert c1 in finest_coffees, "doesnt contain 1st finest coffee" assert c2 in finest_coffees, "doesnt contain 2nd finest coffee" assert c3 in finest_coffees, "doesnt contain 3rd finest coffee" - unpriced_coffees = await Coffee.nodes.filter(price__isnull=True).all() + unpriced_coffees = await Coffee.nodes.filter(price__isnull=True) assert len(unpriced_coffees) == 1, "unexpected number of results" assert c4 in unpriced_coffees, "doesnt contain unpriced coffee" - coffees_with_id_gte_3 = await Coffee.nodes.filter(id___gte=3).all() + coffees_with_id_gte_3 = await Coffee.nodes.filter(id___gte=3) assert len(coffees_with_id_gte_3) == 2, "unexpected number of results" assert c3 in coffees_with_id_gte_3 assert c4 in coffees_with_id_gte_3 @@ -319,7 +320,7 @@ async def test_extra_filters(): ValueError, match=r".*Neo4j internals like id or element_id are not allowed for use in this operation.", ): - await Coffee.nodes.filter(elementId="4:xxx:111").all() + await Coffee.nodes.filter(elementId="4:xxx:111") def test_traversal_definition_keys_are_valid(): diff --git a/test/async_/test_migration_neo4j_5.py b/test/async_/test_migration_neo4j_5.py index ceda47b7..346c694a 100644 --- a/test/async_/test_migration_neo4j_5.py +++ b/test/async_/test_migration_neo4j_5.py @@ -38,7 +38,8 @@ async def test_read_elements_id(): # Validate id properties # Behaviour is dependent on Neo4j version - if adb.database_version.startswith("4"): + db_version = await adb.database_version + if db_version.startswith("4"): # Nodes' ids assert lex_hives.id == int(lex_hives.element_id) assert lex_hives.id == (await the_hives.released.single()).id diff --git a/test/async_/test_models.py b/test/async_/test_models.py index f3c922a3..c79cc152 100644 --- a/test/async_/test_models.py +++ b/test/async_/test_models.py @@ -219,7 +219,8 @@ async def test_refresh(): assert c.age == 20 - if adb.database_version.startswith("4"): + _db_version = await adb.database_version + if _db_version.startswith("4"): c = Customer2.inflate(999) else: c = Customer2.inflate("4:xxxxxx:999") diff --git a/test/async_/test_relationships.py b/test/async_/test_relationships.py index 1bbac2a1..6bcb0364 100644 --- a/test/async_/test_relationships.py +++ b/test/async_/test_relationships.py @@ -97,7 +97,7 @@ async def test_either_direction_connect(): result, _ = await sakis.cypher( f"""MATCH (us), (them) - WHERE {adb.get_id_method()}(us)=$self and {adb.get_id_method()}(them)=$them + WHERE {await adb.get_id_method()}(us)=$self and {await adb.get_id_method()}(them)=$them MATCH (us)-[r:KNOWS]-(them) RETURN COUNT(r)""", {"them": rey.element_id}, ) diff --git a/test/sync_/conftest.py b/test/sync_/conftest.py index 0f3beb8d..456e031a 100644 --- a/test/sync_/conftest.py +++ b/test/sync_/conftest.py @@ -40,7 +40,8 @@ def setup_neo4j_session(request, event_loop): db.cypher_query( "CREATE OR REPLACE USER troygreene SET PASSWORD 'foobarbaz' CHANGE NOT REQUIRED" ) - if db.database_edition == "enterprise": + db_edition = db.database_edition + if db_edition == "enterprise": db.cypher_query("GRANT ROLE publisher TO troygreene") db.cypher_query("GRANT IMPERSONATE (troygreene) ON DBMS TO admin") diff --git a/test/sync_/test_dbms_awareness.py b/test/sync_/test_dbms_awareness.py index 6694ddfe..a9b9c131 100644 --- a/test/sync_/test_dbms_awareness.py +++ b/test/sync_/test_dbms_awareness.py @@ -1,17 +1,17 @@ from test._async_compat import mark_sync_test -from pytest import mark +import pytest from neomodel.sync_.core import db from neomodel.util import version_tag_to_integer -# TODO : This calling database_version should be async -@mark.skipif( - db.database_version != "5.7.0", reason="Testing a specific database version" -) +@mark_sync_test def test_version_awareness(): - assert db.database_version == "5.7.0" + db_version = db.database_version + if db_version != "5.7.0": + pytest.skip("Testing a specific database version") + assert db_version == "5.7.0" assert db.version_is_higher_than("5.7") assert db.version_is_higher_than("5.6.0") assert db.version_is_higher_than("5") @@ -22,7 +22,8 @@ def test_version_awareness(): @mark_sync_test def test_edition_awareness(): - if db.database_edition == "enterprise": + db_edition = db.database_edition + if db_edition == "enterprise": assert db.edition_is_enterprise() else: assert not db.edition_is_enterprise() diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 791a1fa2..8539eedc 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -116,7 +116,8 @@ def test_simple_traverse_with_filter(): qb = QueryBuilder(NodeSet(source=nescafe).suppliers.match(since__lt=datetime.now())) - results = qb.build_ast()._execute() + _ast = qb.build_ast() + results = _ast._execute() assert qb._ast.lookup assert qb._ast.match @@ -144,14 +145,16 @@ def test_double_traverse(): @mark_sync_test def test_count(): Coffee(name="Nescafe Gold", price=99).save() - count = QueryBuilder(NodeSet(source=Coffee)).build_ast()._count() + ast = QueryBuilder(NodeSet(source=Coffee)).build_ast() + count = ast._count() assert count > 0 Coffee(name="Kawa", price=27).save() node_set = NodeSet(source=Coffee) node_set.skip = 1 node_set.limit = 1 - count = QueryBuilder(node_set).build_ast()._count() + ast = QueryBuilder(node_set).build_ast() + count = ast._count() assert count == 1 @@ -181,7 +184,7 @@ def test_slice(): # TODO : Make slice work with async # Doing await (Coffee.nodes.all())[1:] fetches without slicing - assert len(list(Coffee.nodes.all()[1:])) == 2 + assert len(list((Coffee.nodes)[1:])) == 2 assert len(list(Coffee.nodes.all()[:1])) == 1 assert isinstance(Coffee.nodes[1], Coffee) assert isinstance(Coffee.nodes[0], Coffee) @@ -199,20 +202,19 @@ def test_issue_208(): b.suppliers.connect(l, {"courier": "fedex"}) b.suppliers.connect(a, {"courier": "dhl"}) - # TODO : Find a way to not need the .all() here - # Note : Check AsyncTraversal match - assert len(b.suppliers.match(courier="fedex").all()) - assert len(b.suppliers.match(courier="dhl").all()) + assert len(b.suppliers.match(courier="fedex")) + assert len(b.suppliers.match(courier="dhl")) @mark_sync_test def test_issue_589(): node1 = Extension().save() node2 = Extension().save() + # TODO : ALso test await node1.extension.check_contains(node2) + # This is the way to pick only a single relationship using async + assert node2 not in node1.extension node1.extension.connect(node2) - # TODO : Find a way to not need the .all() here - # Note : Check AsyncRelationshipDefinition (parent of AsyncRelationshipTo / From) - assert node2 in node1.extension.all() + assert node2 in node1.extension @mark_sync_test @@ -220,10 +222,10 @@ def test_contains(): expensive = Coffee(price=1000, name="Pricey").save() asda = Coffee(name="Asda", price=1).save() - # TODO : Find a way to not need the .all() here - assert expensive in Coffee.nodes.filter(price__gt=999).all() - assert asda not in Coffee.nodes.filter(price__gt=999).all() + assert expensive in Coffee.nodes.filter(price__gt=999) + assert asda not in Coffee.nodes.filter(price__gt=999) + # TODO : Fix this test # bad value raises with raises(ValueError): 2 in Coffee.nodes @@ -285,23 +287,22 @@ def test_extra_filters(): c3 = Coffee(name="Japans finest", price=35, id_=3).save() c4 = Coffee(name="US extra-fine", price=None, id_=4).save() - # TODO : Remove some .all() when filter is updated - coffees_5_10 = Coffee.nodes.filter(price__in=[10, 5]).all() + coffees_5_10 = Coffee.nodes.filter(price__in=[10, 5]) assert len(coffees_5_10) == 2, "unexpected number of results" assert c1 in coffees_5_10, "doesnt contain 5 price coffee" assert c2 in coffees_5_10, "doesnt contain 10 price coffee" - finest_coffees = Coffee.nodes.filter(name__iendswith=" Finest").all() + finest_coffees = Coffee.nodes.filter(name__iendswith=" Finest") assert len(finest_coffees) == 3, "unexpected number of results" assert c1 in finest_coffees, "doesnt contain 1st finest coffee" assert c2 in finest_coffees, "doesnt contain 2nd finest coffee" assert c3 in finest_coffees, "doesnt contain 3rd finest coffee" - unpriced_coffees = Coffee.nodes.filter(price__isnull=True).all() + unpriced_coffees = Coffee.nodes.filter(price__isnull=True) assert len(unpriced_coffees) == 1, "unexpected number of results" assert c4 in unpriced_coffees, "doesnt contain unpriced coffee" - coffees_with_id_gte_3 = Coffee.nodes.filter(id___gte=3).all() + coffees_with_id_gte_3 = Coffee.nodes.filter(id___gte=3) assert len(coffees_with_id_gte_3) == 2, "unexpected number of results" assert c3 in coffees_with_id_gte_3 assert c4 in coffees_with_id_gte_3 @@ -310,7 +311,7 @@ def test_extra_filters(): ValueError, match=r".*Neo4j internals like id or element_id are not allowed for use in this operation.", ): - Coffee.nodes.filter(elementId="4:xxx:111").all() + Coffee.nodes.filter(elementId="4:xxx:111") def test_traversal_definition_keys_are_valid(): diff --git a/test/sync_/test_migration_neo4j_5.py b/test/sync_/test_migration_neo4j_5.py index 8bc2680d..32de5d9a 100644 --- a/test/sync_/test_migration_neo4j_5.py +++ b/test/sync_/test_migration_neo4j_5.py @@ -38,7 +38,8 @@ def test_read_elements_id(): # Validate id properties # Behaviour is dependent on Neo4j version - if db.database_version.startswith("4"): + db_version = db.database_version + if db_version.startswith("4"): # Nodes' ids assert lex_hives.id == int(lex_hives.element_id) assert lex_hives.id == (the_hives.released.single()).id diff --git a/test/sync_/test_models.py b/test/sync_/test_models.py index 89667f56..b0e9aec3 100644 --- a/test/sync_/test_models.py +++ b/test/sync_/test_models.py @@ -219,7 +219,8 @@ def test_refresh(): assert c.age == 20 - if db.database_version.startswith("4"): + _db_version = db.database_version + if _db_version.startswith("4"): c = Customer2.inflate(999) else: c = Customer2.inflate("4:xxxxxx:999") From bc8847920e961c7d320a965321c01bc364f03f9c Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Sat, 16 Mar 2024 10:00:54 +0100 Subject: [PATCH 49/73] Revert pre-commit --- .pre-commit-config.yaml | 2 +- run-unasync.sh | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) delete mode 100644 run-unasync.sh diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5cabda9d..dd58a3c9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ repos: hooks: - id: unasync name: unasync - entry: bash run-unasync.sh + entry: bin/make-unasync language: system files: "^(neomodel/async_|test/async_)/.*" - repo: https://github.com/psf/black diff --git a/run-unasync.sh b/run-unasync.sh deleted file mode 100644 index 590c2620..00000000 --- a/run-unasync.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash -source venv/bin/activate -bin/make-unasync \ No newline at end of file From aba3a02731563bd36d50238e1a8ddff778f9a3f5 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Sat, 16 Mar 2024 11:00:51 +0100 Subject: [PATCH 50/73] Fix tests --- neomodel/async_/core.py | 8 ---- neomodel/async_/match.py | 2 +- neomodel/async_/relationship_manager.py | 4 +- neomodel/sync_/core.py | 8 ---- test/async_/test_indexing.py | 5 +-- test/async_/test_issue283.py | 12 +++--- test/async_/test_match_api.py | 48 +++++++++++----------- test/async_/test_relationship_models.py | 5 ++- test/async_/test_relationships.py | 10 ++--- test/async_/test_relative_relationships.py | 2 +- test/async_/test_transactions.py | 2 +- test/sync_/test_indexing.py | 3 +- test/sync_/test_issue283.py | 12 +++--- test/sync_/test_match_api.py | 42 +++++++++---------- test/sync_/test_relationship_models.py | 5 ++- test/sync_/test_transactions.py | 2 +- 16 files changed, 75 insertions(+), 95 deletions(-) diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py index 1e555a22..c61671bc 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/core.py @@ -204,7 +204,6 @@ async def close_connection(self): await self.driver.close() self.driver = None - # TODO : Make this async and turn on muck-spreader @property async def database_version(self): if self._database_version is None: @@ -998,8 +997,6 @@ def wrapper(*args, **kwargs): return wrapper -# TODO : Either deprecate auto_install_labels -# Or make it work with async class NodeMeta(type): def __new__(mcs, name, bases, namespace): namespace["DoesNotExist"] = type(name + "DoesNotExist", (DoesNotExist,), {}) @@ -1055,10 +1052,6 @@ def __new__(mcs, name, bases, namespace): cls.__label__ = namespace.get("__label__", name) cls.__optional_labels__ = namespace.get("__optional_labels__", []) - # TODO : See previous TODO comment - # if config.AUTO_INSTALL_LABELS: - # await install_labels(cls, quiet=False) - build_class_registry(cls) return cls @@ -1138,7 +1131,6 @@ def nodes(cls): return AsyncNodeSet(cls) - # TODO : Update places where element_id is expected to be an int (where id(n)=$element_id) @property def element_id(self): if hasattr(self, "element_id_property"): diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 76f058a1..9e7e5852 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -761,7 +761,7 @@ async def check_bool(self): :rtype: bool """ ast = await self.query_cls(self).build_ast() - _count = ast._count() + _count = await ast._count() return _count > 0 async def check_nonzero(self): diff --git a/neomodel/async_/relationship_manager.py b/neomodel/async_/relationship_manager.py index e188c15d..80113f94 100644 --- a/neomodel/async_/relationship_manager.py +++ b/neomodel/async_/relationship_manager.py @@ -345,13 +345,13 @@ def exclude(self, *args, **kwargs): """ return AsyncNodeSet(self._new_traversal()).exclude(*args, **kwargs) - def is_connected(self, node): + async def is_connected(self, node): """ Check if a node is connected with this relationship type :param node: :return: bool """ - return self._new_traversal().__contains__(node) + return await self._new_traversal().check_contains(node) async def single(self): """ diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index b7fdc2fb..8f38d913 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -204,7 +204,6 @@ def close_connection(self): self.driver.close() self.driver = None - # TODO : Make this async and turn on muck-spreader @property def database_version(self): if self._database_version is None: @@ -994,8 +993,6 @@ def wrapper(*args, **kwargs): return wrapper -# TODO : Either deprecate auto_install_labels -# Or make it work with async class NodeMeta(type): def __new__(mcs, name, bases, namespace): namespace["DoesNotExist"] = type(name + "DoesNotExist", (DoesNotExist,), {}) @@ -1051,10 +1048,6 @@ def __new__(mcs, name, bases, namespace): cls.__label__ = namespace.get("__label__", name) cls.__optional_labels__ = namespace.get("__optional_labels__", []) - # TODO : See previous TODO comment - # if config.AUTO_INSTALL_LABELS: - # await install_labels(cls, quiet=False) - build_class_registry(cls) return cls @@ -1134,7 +1127,6 @@ def nodes(cls): return NodeSet(cls) - # TODO : Update places where element_id is expected to be an int (where id(n)=$element_id) @property def element_id(self): if hasattr(self, "element_id_property"): diff --git a/test/async_/test_indexing.py b/test/async_/test_indexing.py index 177ec0a2..aec3efef 100644 --- a/test/async_/test_indexing.py +++ b/test/async_/test_indexing.py @@ -61,9 +61,8 @@ async def test_optional_properties_dont_get_indexed(): async def test_escaped_chars(): _name = "sarah:test" await Human(name=_name, age=3).save() - r = Human.nodes.filter(name=_name) - first_r = await r[0] - assert first_r.name == _name + r = await Human.nodes.filter(name=_name) + assert r[0].name == _name @mark_async_test diff --git a/test/async_/test_issue283.py b/test/async_/test_issue283.py index 1f62d80c..831f5631 100644 --- a/test/async_/test_issue283.py +++ b/test/async_/test_issue283.py @@ -118,7 +118,7 @@ async def test_automatic_result_resolution(): # If A is friends with B, then A's friends_with objects should be # TechnicalPerson (!NOT basePerson!) - assert type(await A.friends_with[0]) is TechnicalPerson + assert type((await A.friends_with)[0]) is TechnicalPerson await A.delete() await B.delete() @@ -230,13 +230,13 @@ async def test_validation_with_inheritance_from_db(): # This now means that friends_with of a TechnicalPerson can # either be TechnicalPerson or Pilot Person (!NOT basePerson!) - assert (type(await A.friends_with[0]) is TechnicalPerson) or ( - type(await A.friends_with[0]) is PilotPerson + assert (type((await A.friends_with)[0]) is TechnicalPerson) or ( + type((await A.friends_with)[0]) is PilotPerson ) - assert (type(await A.friends_with[1]) is TechnicalPerson) or ( - type(await A.friends_with[1]) is PilotPerson + assert (type((await A.friends_with)[1]) is TechnicalPerson) or ( + type((await A.friends_with)[1]) is PilotPerson ) - assert type(await D.friends_with[0]) is PilotPerson + assert type((await D.friends_with)[0]) is PilotPerson await A.delete() await B.delete() diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 88367dff..25beb266 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -126,7 +126,7 @@ async def test_simple_traverse_with_filter(): ) _ast = await qb.build_ast() - results = _ast._execute() + results = await _ast._execute() assert qb._ast.lookup assert qb._ast.match @@ -191,13 +191,12 @@ async def test_slice(): await Coffee(name="Britains finest").save() await Coffee(name="Japans finest").save() - # TODO : Make slice work with async - # Doing await (Coffee.nodes.all())[1:] fetches without slicing + # TODO : Branch this for sync to remove the extra brackets ? assert len(list((await Coffee.nodes)[1:])) == 2 - assert len(list(Coffee.nodes.all()[:1])) == 1 - assert isinstance(Coffee.nodes[1], Coffee) - assert isinstance(Coffee.nodes[0], Coffee) - assert len(list(Coffee.nodes.all()[1:2])) == 1 + assert len(list((await Coffee.nodes)[:1])) == 1 + assert isinstance((await Coffee.nodes)[1], Coffee) + assert isinstance((await Coffee.nodes)[0], Coffee) + assert len(list((await Coffee.nodes)[1:2])) == 1 @mark_async_test @@ -219,8 +218,6 @@ async def test_issue_208(): async def test_issue_589(): node1 = await Extension().save() node2 = await Extension().save() - # TODO : ALso test await node1.extension.check_contains(node2) - # This is the way to pick only a single relationship using async assert node2 not in await node1.extension await node1.extension.connect(node2) assert node2 in await node1.extension @@ -234,14 +231,15 @@ async def test_contains(): assert expensive in await Coffee.nodes.filter(price__gt=999) assert asda not in await Coffee.nodes.filter(price__gt=999) - # TODO : Fix this test + # TODO : Branch this for async => should be "2 in Coffee.nodes" + # Good example for documentation # bad value raises - with raises(ValueError): - 2 in await Coffee.nodes + with raises(ValueError, match=r"Expecting StructuredNode instance"): + await Coffee.nodes.check_contains(2) # unsaved - with raises(ValueError): - Coffee() in await Coffee.nodes + with raises(ValueError, match=r"Unsaved node"): + await Coffee.nodes.check_contains(Coffee()) @mark_async_test @@ -253,11 +251,11 @@ async def test_order_by(): c2 = await Coffee(name="Britains finest", price=10).save() c3 = await Coffee(name="Japans finest", price=35).save() - assert (await Coffee.nodes.order_by("price")[0]).price == 5 - assert (await Coffee.nodes.order_by("-price")[0]).price == 35 + # TODO : Branch this for sync to remove the extra brackets ? + assert ((await Coffee.nodes.order_by("price"))[0]).price == 5 + assert ((await Coffee.nodes.order_by("-price"))[0]).price == 35 - ns = await Coffee.nodes.order_by("-price") - # TODO : Method fails + ns = Coffee.nodes.order_by("-price") qb = await AsyncQueryBuilder(ns).build_ast() assert qb._ast.order_by ns = ns.order_by(None) @@ -272,7 +270,7 @@ async def test_order_by(): ValueError, match=r".*Neo4j internals like id or element_id are not allowed for use in this operation.", ): - Coffee.nodes.order_by("id") + await Coffee.nodes.order_by("id") # Test order by on a relationship l = await Supplier(name="lidl2").save() @@ -280,7 +278,7 @@ async def test_order_by(): await l.coffees.connect(c2) await l.coffees.connect(c3) - ordered_n = [n for n in await l.coffees.order_by("name").all()] + ordered_n = [n for n in await l.coffees.order_by("name")] assert ordered_n[0] == c2 assert ordered_n[1] == c1 assert ordered_n[2] == c3 @@ -530,17 +528,17 @@ async def test_fetch_relations(): ) assert result[0][0] is None + # TODO : Branch the following two for sync to use len() and in instead of the dunder overrides # len() should only consider Suppliers - count = len( + count = ( await Supplier.nodes.filter(name="Sainsburys") .fetch_relations("coffees__species") - .all() + .get_len() ) assert count == 1 assert ( - tesco - in await Supplier.nodes.fetch_relations("coffees__species") + await Supplier.nodes.fetch_relations("coffees__species") .filter(name="Sainsburys") - .all() + .check_contains(tesco) ) diff --git a/test/async_/test_relationship_models.py b/test/async_/test_relationship_models.py index b360a881..1e8f632f 100644 --- a/test/async_/test_relationship_models.py +++ b/test/async_/test_relationship_models.py @@ -132,8 +132,9 @@ async def test_multiple_rels_exist_issue_223(): rel_b = await phill.hates.connect(ian, {"reason": "b"}) assert rel_a.element_id != rel_b.element_id - ian_a = await phill.hates.match(reason="a")[0] - ian_b = await phill.hates.match(reason="b")[0] + # TODO : Branch this for sync to remove extra brackets + ian_a = (await phill.hates.match(reason="a"))[0] + ian_b = (await phill.hates.match(reason="b"))[0] assert ian_a.element_id == ian_b.element_id diff --git a/test/async_/test_relationships.py b/test/async_/test_relationships.py index 6bcb0364..35d2b95c 100644 --- a/test/async_/test_relationships.py +++ b/test/async_/test_relationships.py @@ -119,19 +119,19 @@ async def test_search_and_filter_and_exclude(): await fred.is_from.connect(zz) await fred.is_from.connect(zx) await fred.is_from.connect(zt) - result = fred.is_from.filter(code="ZX") + result = await fred.is_from.filter(code="ZX") assert result[0].code == "ZX" - result = fred.is_from.filter(code="ZY") + result = await fred.is_from.filter(code="ZY") assert result[0].code == "ZY" - result = fred.is_from.exclude(code="ZZ").exclude(code="ZY") + result = await fred.is_from.exclude(code="ZZ").exclude(code="ZY") assert result[0].code == "ZX" and len(result) == 1 - result = fred.is_from.exclude(Q(code__contains="Y")) + result = await fred.is_from.exclude(Q(code__contains="Y")) assert len(result) == 2 - result = fred.is_from.filter(Q(code__contains="Z")) + result = await fred.is_from.filter(Q(code__contains="Z")) assert len(result) == 3 diff --git a/test/async_/test_relative_relationships.py b/test/async_/test_relative_relationships.py index 371be944..7b283f84 100644 --- a/test/async_/test_relative_relationships.py +++ b/test/async_/test_relative_relationships.py @@ -21,4 +21,4 @@ async def test_relative_relationship(): # connecting an instance of the class defined above # the next statement will fail if there's a type mismatch await a.is_from.connect(c) - assert a.is_from.is_connected(c) + assert await a.is_from.is_connected(c) diff --git a/test/async_/test_transactions.py b/test/async_/test_transactions.py index 139b49f6..29a44893 100644 --- a/test/async_/test_transactions.py +++ b/test/async_/test_transactions.py @@ -121,7 +121,7 @@ async def in_a_tx(*names): await APerson(name=n).save() -# TODO : FIx this once decorator is fixed +# TODO : FIx this once in_a_tx is fixed @mark_async_test async def test_bookmark_transaction_decorator(): for p in await APerson.nodes: diff --git a/test/sync_/test_indexing.py b/test/sync_/test_indexing.py index c50a53f6..d253ffcd 100644 --- a/test/sync_/test_indexing.py +++ b/test/sync_/test_indexing.py @@ -57,8 +57,7 @@ def test_escaped_chars(): _name = "sarah:test" Human(name=_name, age=3).save() r = Human.nodes.filter(name=_name) - first_r = r[0] - assert first_r.name == _name + assert r[0].name == _name @mark_sync_test diff --git a/test/sync_/test_issue283.py b/test/sync_/test_issue283.py index 6fcc9b99..226e5ac7 100644 --- a/test/sync_/test_issue283.py +++ b/test/sync_/test_issue283.py @@ -112,7 +112,7 @@ def test_automatic_result_resolution(): # If A is friends with B, then A's friends_with objects should be # TechnicalPerson (!NOT basePerson!) - assert type(A.friends_with[0]) is TechnicalPerson + assert type((A.friends_with)[0]) is TechnicalPerson A.delete() B.delete() @@ -206,13 +206,13 @@ def test_validation_with_inheritance_from_db(): # This now means that friends_with of a TechnicalPerson can # either be TechnicalPerson or Pilot Person (!NOT basePerson!) - assert (type(A.friends_with[0]) is TechnicalPerson) or ( - type(A.friends_with[0]) is PilotPerson + assert (type((A.friends_with)[0]) is TechnicalPerson) or ( + type((A.friends_with)[0]) is PilotPerson ) - assert (type(A.friends_with[1]) is TechnicalPerson) or ( - type(A.friends_with[1]) is PilotPerson + assert (type((A.friends_with)[1]) is TechnicalPerson) or ( + type((A.friends_with)[1]) is PilotPerson ) - assert type(D.friends_with[0]) is PilotPerson + assert type((D.friends_with)[0]) is PilotPerson A.delete() B.delete() diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 8539eedc..85f431bd 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -182,13 +182,12 @@ def test_slice(): Coffee(name="Britains finest").save() Coffee(name="Japans finest").save() - # TODO : Make slice work with async - # Doing await (Coffee.nodes.all())[1:] fetches without slicing + # TODO : Branch this for sync to remove the extra brackets ? assert len(list((Coffee.nodes)[1:])) == 2 - assert len(list(Coffee.nodes.all()[:1])) == 1 - assert isinstance(Coffee.nodes[1], Coffee) - assert isinstance(Coffee.nodes[0], Coffee) - assert len(list(Coffee.nodes.all()[1:2])) == 1 + assert len(list((Coffee.nodes)[:1])) == 1 + assert isinstance((Coffee.nodes)[1], Coffee) + assert isinstance((Coffee.nodes)[0], Coffee) + assert len(list((Coffee.nodes)[1:2])) == 1 @mark_sync_test @@ -210,8 +209,6 @@ def test_issue_208(): def test_issue_589(): node1 = Extension().save() node2 = Extension().save() - # TODO : ALso test await node1.extension.check_contains(node2) - # This is the way to pick only a single relationship using async assert node2 not in node1.extension node1.extension.connect(node2) assert node2 in node1.extension @@ -225,14 +222,15 @@ def test_contains(): assert expensive in Coffee.nodes.filter(price__gt=999) assert asda not in Coffee.nodes.filter(price__gt=999) - # TODO : Fix this test + # TODO : Branch this for async => should be "2 in Coffee.nodes" + # Good example for documentation # bad value raises - with raises(ValueError): - 2 in Coffee.nodes + with raises(ValueError, match=r"Expecting StructuredNode instance"): + Coffee.nodes.__contains__(2) # unsaved - with raises(ValueError): - Coffee() in Coffee.nodes + with raises(ValueError, match=r"Unsaved node"): + Coffee.nodes.__contains__(Coffee()) @mark_sync_test @@ -244,11 +242,11 @@ def test_order_by(): c2 = Coffee(name="Britains finest", price=10).save() c3 = Coffee(name="Japans finest", price=35).save() - assert (Coffee.nodes.order_by("price")[0]).price == 5 - assert (Coffee.nodes.order_by("-price")[0]).price == 35 + # TODO : Branch this for sync to remove the extra brackets ? + assert ((Coffee.nodes.order_by("price"))[0]).price == 5 + assert ((Coffee.nodes.order_by("-price"))[0]).price == 35 ns = Coffee.nodes.order_by("-price") - # TODO : Method fails qb = QueryBuilder(ns).build_ast() assert qb._ast.order_by ns = ns.order_by(None) @@ -271,7 +269,7 @@ def test_order_by(): l.coffees.connect(c2) l.coffees.connect(c3) - ordered_n = [n for n in l.coffees.order_by("name").all()] + ordered_n = [n for n in l.coffees.order_by("name")] assert ordered_n[0] == c2 assert ordered_n[1] == c1 assert ordered_n[2] == c3 @@ -519,17 +517,17 @@ def test_fetch_relations(): ) assert result[0][0] is None + # TODO : Branch the following two for sync to use len() and in instead of the dunder overrides # len() should only consider Suppliers - count = len( + count = ( Supplier.nodes.filter(name="Sainsburys") .fetch_relations("coffees__species") - .all() + .__len__() ) assert count == 1 assert ( - tesco - in Supplier.nodes.fetch_relations("coffees__species") + Supplier.nodes.fetch_relations("coffees__species") .filter(name="Sainsburys") - .all() + .__contains__(tesco) ) diff --git a/test/sync_/test_relationship_models.py b/test/sync_/test_relationship_models.py index 5b2e75d7..ffc3512f 100644 --- a/test/sync_/test_relationship_models.py +++ b/test/sync_/test_relationship_models.py @@ -130,8 +130,9 @@ def test_multiple_rels_exist_issue_223(): rel_b = phill.hates.connect(ian, {"reason": "b"}) assert rel_a.element_id != rel_b.element_id - ian_a = phill.hates.match(reason="a")[0] - ian_b = phill.hates.match(reason="b")[0] + # TODO : Branch this for sync to remove extra brackets + ian_a = (phill.hates.match(reason="a"))[0] + ian_b = (phill.hates.match(reason="b"))[0] assert ian_a.element_id == ian_b.element_id diff --git a/test/sync_/test_transactions.py b/test/sync_/test_transactions.py index 78ad4c58..4c5771f8 100644 --- a/test/sync_/test_transactions.py +++ b/test/sync_/test_transactions.py @@ -121,7 +121,7 @@ def in_a_tx(*names): APerson(name=n).save() -# TODO : FIx this once decorator is fixed +# TODO : FIx this once in_a_tx is fixed @mark_sync_test def test_bookmark_transaction_decorator(): for p in APerson.nodes: From a83ae8642d218cd8fc01ac1ef1be5b54c025a981 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Sat, 16 Mar 2024 11:23:32 +0100 Subject: [PATCH 51/73] Move singleton import into root --- neomodel/__init__.py | 5 +++-- neomodel/async_/__init__.py | 1 - neomodel/sync_/__init__.py | 1 - test/async_/conftest.py | 3 +-- test/async_/test_cardinality.py | 2 +- test/async_/test_connection.py | 3 +-- test/async_/test_cypher.py | 3 +-- test/async_/test_database_management.py | 3 ++- test/async_/test_dbms_awareness.py | 2 +- test/async_/test_driver_options.py | 2 +- test/async_/test_indexing.py | 2 +- test/async_/test_issue283.py | 2 +- test/async_/test_label_drop.py | 3 +-- test/async_/test_label_install.py | 5 ++++- test/async_/test_match_api.py | 2 +- test/async_/test_migration_neo4j_5.py | 2 +- test/async_/test_models.py | 2 +- test/async_/test_multiprocessing.py | 3 +-- test/async_/test_paths.py | 2 +- test/async_/test_properties.py | 3 +-- test/async_/test_relationships.py | 2 +- test/async_/test_transactions.py | 3 +-- test/sync_/conftest.py | 3 +-- test/sync_/test_cardinality.py | 2 +- test/sync_/test_connection.py | 3 +-- test/sync_/test_cypher.py | 3 +-- test/sync_/test_database_management.py | 3 ++- test/sync_/test_dbms_awareness.py | 2 +- test/sync_/test_driver_options.py | 2 +- test/sync_/test_indexing.py | 3 +-- test/sync_/test_issue283.py | 2 +- test/sync_/test_label_drop.py | 3 +-- test/sync_/test_label_install.py | 5 ++++- test/sync_/test_match_api.py | 2 +- test/sync_/test_migration_neo4j_5.py | 2 +- test/sync_/test_models.py | 2 +- test/sync_/test_multiprocessing.py | 3 +-- test/sync_/test_paths.py | 2 +- test/sync_/test_properties.py | 3 +-- test/sync_/test_relationships.py | 2 +- test/sync_/test_transactions.py | 3 +-- 41 files changed, 49 insertions(+), 57 deletions(-) diff --git a/neomodel/__init__.py b/neomodel/__init__.py index 399fb0f9..58b39178 100644 --- a/neomodel/__init__.py +++ b/neomodel/__init__.py @@ -1,5 +1,5 @@ # pep8: noqa -# TODO : Check imports here +# TODO : Check imports sync + async from neomodel.async_.cardinality import ( AsyncOne, AsyncOneOrMore, @@ -8,6 +8,7 @@ ) from neomodel.async_.core import ( AsyncStructuredNode, + adb, change_neo4j_password, clear_neo4j_database, drop_constraints, @@ -45,7 +46,7 @@ UniqueIdProperty, ) from neomodel.sync_.cardinality import One, OneOrMore, ZeroOrMore, ZeroOrOne -from neomodel.sync_.core import StructuredNode +from neomodel.sync_.core import StructuredNode, db from neomodel.sync_.match import NodeSet, Traversal from neomodel.sync_.path import NeomodelPath from neomodel.sync_.property_manager import PropertyManager diff --git a/neomodel/async_/__init__.py b/neomodel/async_/__init__.py index f1d519e8..e69de29b 100644 --- a/neomodel/async_/__init__.py +++ b/neomodel/async_/__init__.py @@ -1 +0,0 @@ -# from neomodel.async_.core import adb diff --git a/neomodel/sync_/__init__.py b/neomodel/sync_/__init__.py index f1d519e8..e69de29b 100644 --- a/neomodel/sync_/__init__.py +++ b/neomodel/sync_/__init__.py @@ -1 +0,0 @@ -# from neomodel.async_.core import adb diff --git a/test/async_/conftest.py b/test/async_/conftest.py index 302b07f3..493ff12c 100644 --- a/test/async_/conftest.py +++ b/test/async_/conftest.py @@ -5,8 +5,7 @@ import pytest -from neomodel import config -from neomodel.async_.core import adb +from neomodel import adb, config @mark_async_session_auto_fixture diff --git a/test/async_/test_cardinality.py b/test/async_/test_cardinality.py index e72fa912..4ce02ad4 100644 --- a/test/async_/test_cardinality.py +++ b/test/async_/test_cardinality.py @@ -13,8 +13,8 @@ CardinalityViolation, IntegerProperty, StringProperty, + adb, ) -from neomodel.async_.core import adb class HairDryer(AsyncStructuredNode): diff --git a/test/async_/test_connection.py b/test/async_/test_connection.py index 36c5922b..a2eded7d 100644 --- a/test/async_/test_connection.py +++ b/test/async_/test_connection.py @@ -6,8 +6,7 @@ from neo4j import AsyncDriver, AsyncGraphDatabase from neo4j.debug import watch -from neomodel import AsyncStructuredNode, StringProperty, config -from neomodel.async_.core import adb +from neomodel import AsyncStructuredNode, StringProperty, adb, config @mark_async_test diff --git a/test/async_/test_cypher.py b/test/async_/test_cypher.py index dc37a5b1..ac310307 100644 --- a/test/async_/test_cypher.py +++ b/test/async_/test_cypher.py @@ -6,8 +6,7 @@ from numpy import ndarray from pandas import DataFrame, Series -from neomodel import AsyncStructuredNode, StringProperty -from neomodel.async_.core import adb +from neomodel import AsyncStructuredNode, StringProperty, adb class User2(AsyncStructuredNode): diff --git a/test/async_/test_database_management.py b/test/async_/test_database_management.py index 68dcccc9..5159642a 100644 --- a/test/async_/test_database_management.py +++ b/test/async_/test_database_management.py @@ -9,8 +9,8 @@ AsyncStructuredRel, IntegerProperty, StringProperty, + adb, ) -from neomodel.async_.core import adb class City(AsyncStructuredNode): @@ -41,6 +41,7 @@ async def test_clear_database(): assert database_is_populated[0][0] is False + await adb.install_all_labels() indexes = await adb.list_indexes(exclude_token_lookup=True) constraints = await adb.list_constraints() assert len(indexes) > 0 diff --git a/test/async_/test_dbms_awareness.py b/test/async_/test_dbms_awareness.py index 1c309a62..66a8fc5f 100644 --- a/test/async_/test_dbms_awareness.py +++ b/test/async_/test_dbms_awareness.py @@ -2,7 +2,7 @@ import pytest -from neomodel.async_.core import adb +from neomodel import adb from neomodel.util import version_tag_to_integer diff --git a/test/async_/test_driver_options.py b/test/async_/test_driver_options.py index cd14c78f..df378f98 100644 --- a/test/async_/test_driver_options.py +++ b/test/async_/test_driver_options.py @@ -4,7 +4,7 @@ from neo4j.exceptions import ClientError from pytest import raises -from neomodel.async_.core import adb +from neomodel import adb from neomodel.exceptions import FeatureNotSupported diff --git a/test/async_/test_indexing.py b/test/async_/test_indexing.py index aec3efef..9e0d8f37 100644 --- a/test/async_/test_indexing.py +++ b/test/async_/test_indexing.py @@ -8,8 +8,8 @@ IntegerProperty, StringProperty, UniqueProperty, + adb, ) -from neomodel.async_.core import adb from neomodel.exceptions import ConstraintValidationFailed diff --git a/test/async_/test_issue283.py b/test/async_/test_issue283.py index 831f5631..097f0d57 100644 --- a/test/async_/test_issue283.py +++ b/test/async_/test_issue283.py @@ -24,8 +24,8 @@ RelationshipClassRedefined, StringProperty, UniqueIdProperty, + adb, ) -from neomodel.async_.core import adb from neomodel.exceptions import NodeClassAlreadyDefined, NodeClassNotDefined try: diff --git a/test/async_/test_label_drop.py b/test/async_/test_label_drop.py index fa4b6106..3d64050b 100644 --- a/test/async_/test_label_drop.py +++ b/test/async_/test_label_drop.py @@ -2,8 +2,7 @@ from neo4j.exceptions import ClientError -from neomodel import AsyncStructuredNode, StringProperty -from neomodel.async_.core import adb +from neomodel import AsyncStructuredNode, StringProperty, adb class ConstraintAndIndex(AsyncStructuredNode): diff --git a/test/async_/test_label_install.py b/test/async_/test_label_install.py index 6078c752..dc3e961d 100644 --- a/test/async_/test_label_install.py +++ b/test/async_/test_label_install.py @@ -8,8 +8,8 @@ AsyncStructuredRel, StringProperty, UniqueIdProperty, + adb, ) -from neomodel.async_.core import adb from neomodel.exceptions import ConstraintValidationFailed, FeatureNotSupported @@ -47,6 +47,7 @@ class SomeNotUniqueNode(AsyncStructuredNode): @mark_async_test async def test_install_all(): await adb.drop_constraints() + await adb.drop_indexes() await adb.install_labels(AbstractNode) # run install all labels await adb.install_all_labels() @@ -66,6 +67,8 @@ async def test_install_all(): @mark_async_test async def test_install_label_twice(capsys): + await adb.drop_constraints() + await adb.drop_indexes() expected_std_out = ( "{code: Neo.ClientError.Schema.EquivalentSchemaRuleAlreadyExists}" ) diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 25beb266..3a259b0f 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -318,7 +318,7 @@ async def test_extra_filters(): ValueError, match=r".*Neo4j internals like id or element_id are not allowed for use in this operation.", ): - await Coffee.nodes.filter(elementId="4:xxx:111") + await Coffee.nodes.filter(elementId="4:xxx:111").all() def test_traversal_definition_keys_are_valid(): diff --git a/test/async_/test_migration_neo4j_5.py b/test/async_/test_migration_neo4j_5.py index 346c694a..7b3397dd 100644 --- a/test/async_/test_migration_neo4j_5.py +++ b/test/async_/test_migration_neo4j_5.py @@ -8,8 +8,8 @@ AsyncStructuredRel, IntegerProperty, StringProperty, + adb, ) -from neomodel.async_.core import adb class Album(AsyncStructuredNode): diff --git a/test/async_/test_models.py b/test/async_/test_models.py index c79cc152..93bee6ab 100644 --- a/test/async_/test_models.py +++ b/test/async_/test_models.py @@ -11,8 +11,8 @@ DateProperty, IntegerProperty, StringProperty, + adb, ) -from neomodel.async_.core import adb from neomodel.exceptions import RequiredProperty, UniqueProperty diff --git a/test/async_/test_multiprocessing.py b/test/async_/test_multiprocessing.py index 101126e7..9bf46598 100644 --- a/test/async_/test_multiprocessing.py +++ b/test/async_/test_multiprocessing.py @@ -1,8 +1,7 @@ from multiprocessing.pool import ThreadPool as Pool from test._async_compat import mark_async_test -from neomodel import AsyncStructuredNode, StringProperty -from neomodel.async_.core import adb +from neomodel import AsyncStructuredNode, StringProperty, adb class ThingyMaBob(AsyncStructuredNode): diff --git a/test/async_/test_paths.py b/test/async_/test_paths.py index 91120675..59a5e385 100644 --- a/test/async_/test_paths.py +++ b/test/async_/test_paths.py @@ -8,8 +8,8 @@ IntegerProperty, StringProperty, UniqueIdProperty, + adb, ) -from neomodel.async_.core import adb class PersonLivesInCity(AsyncStructuredRel): diff --git a/test/async_/test_properties.py b/test/async_/test_properties.py index 9156250b..58db3047 100644 --- a/test/async_/test_properties.py +++ b/test/async_/test_properties.py @@ -4,8 +4,7 @@ from pytest import mark, raises from pytz import timezone -from neomodel import AsyncStructuredNode -from neomodel.async_.core import adb +from neomodel import AsyncStructuredNode, adb from neomodel.exceptions import ( DeflateError, InflateError, diff --git a/test/async_/test_relationships.py b/test/async_/test_relationships.py index 35d2b95c..df50e701 100644 --- a/test/async_/test_relationships.py +++ b/test/async_/test_relationships.py @@ -12,8 +12,8 @@ IntegerProperty, Q, StringProperty, + adb, ) -from neomodel.async_.core import adb class PersonWithRels(AsyncStructuredNode): diff --git a/test/async_/test_transactions.py b/test/async_/test_transactions.py index 29a44893..d115f186 100644 --- a/test/async_/test_transactions.py +++ b/test/async_/test_transactions.py @@ -5,8 +5,7 @@ from neo4j.exceptions import ClientError, TransactionError from pytest import raises -from neomodel import AsyncStructuredNode, StringProperty, UniqueProperty -from neomodel.async_.core import adb +from neomodel import AsyncStructuredNode, StringProperty, UniqueProperty, adb class APerson(AsyncStructuredNode): diff --git a/test/sync_/conftest.py b/test/sync_/conftest.py index 456e031a..735c9840 100644 --- a/test/sync_/conftest.py +++ b/test/sync_/conftest.py @@ -5,8 +5,7 @@ import pytest -from neomodel import config -from neomodel.sync_.core import db +from neomodel import config, db @mark_sync_session_auto_fixture diff --git a/test/sync_/test_cardinality.py b/test/sync_/test_cardinality.py index 2971fb33..9e83762c 100644 --- a/test/sync_/test_cardinality.py +++ b/test/sync_/test_cardinality.py @@ -13,8 +13,8 @@ StructuredNode, ZeroOrMore, ZeroOrOne, + db, ) -from neomodel.sync_.core import db class HairDryer(StructuredNode): diff --git a/test/sync_/test_connection.py b/test/sync_/test_connection.py index cc77df18..e7c0d7ce 100644 --- a/test/sync_/test_connection.py +++ b/test/sync_/test_connection.py @@ -6,8 +6,7 @@ from neo4j import Driver, GraphDatabase from neo4j.debug import watch -from neomodel import StringProperty, StructuredNode, config -from neomodel.sync_.core import db +from neomodel import StringProperty, StructuredNode, config, db @mark_sync_test diff --git a/test/sync_/test_cypher.py b/test/sync_/test_cypher.py index 46beed7e..ab8b6d65 100644 --- a/test/sync_/test_cypher.py +++ b/test/sync_/test_cypher.py @@ -6,8 +6,7 @@ from numpy import ndarray from pandas import DataFrame, Series -from neomodel import StringProperty, StructuredNode -from neomodel.sync_.core import db +from neomodel import StringProperty, StructuredNode, db class User2(StructuredNode): diff --git a/test/sync_/test_database_management.py b/test/sync_/test_database_management.py index 82811b93..9f663994 100644 --- a/test/sync_/test_database_management.py +++ b/test/sync_/test_database_management.py @@ -9,8 +9,8 @@ StringProperty, StructuredNode, StructuredRel, + db, ) -from neomodel.sync_.core import db class City(StructuredNode): @@ -41,6 +41,7 @@ def test_clear_database(): assert database_is_populated[0][0] is False + db.install_all_labels() indexes = db.list_indexes(exclude_token_lookup=True) constraints = db.list_constraints() assert len(indexes) > 0 diff --git a/test/sync_/test_dbms_awareness.py b/test/sync_/test_dbms_awareness.py index a9b9c131..f0f7fb68 100644 --- a/test/sync_/test_dbms_awareness.py +++ b/test/sync_/test_dbms_awareness.py @@ -2,7 +2,7 @@ import pytest -from neomodel.sync_.core import db +from neomodel import db from neomodel.util import version_tag_to_integer diff --git a/test/sync_/test_driver_options.py b/test/sync_/test_driver_options.py index c4deb59c..f244d174 100644 --- a/test/sync_/test_driver_options.py +++ b/test/sync_/test_driver_options.py @@ -4,8 +4,8 @@ from neo4j.exceptions import ClientError from pytest import raises +from neomodel import db from neomodel.exceptions import FeatureNotSupported -from neomodel.sync_.core import db @mark_sync_test diff --git a/test/sync_/test_indexing.py b/test/sync_/test_indexing.py index d253ffcd..f39c22ef 100644 --- a/test/sync_/test_indexing.py +++ b/test/sync_/test_indexing.py @@ -3,9 +3,8 @@ import pytest from pytest import raises -from neomodel import IntegerProperty, StringProperty, StructuredNode, UniqueProperty +from neomodel import IntegerProperty, StringProperty, StructuredNode, UniqueProperty, db from neomodel.exceptions import ConstraintValidationFailed -from neomodel.sync_.core import db class Human(StructuredNode): diff --git a/test/sync_/test_issue283.py b/test/sync_/test_issue283.py index 226e5ac7..877efe0f 100644 --- a/test/sync_/test_issue283.py +++ b/test/sync_/test_issue283.py @@ -24,9 +24,9 @@ StructuredNode, StructuredRel, UniqueIdProperty, + db, ) from neomodel.exceptions import NodeClassAlreadyDefined, NodeClassNotDefined -from neomodel.sync_.core import db try: basestring diff --git a/test/sync_/test_label_drop.py b/test/sync_/test_label_drop.py index 55db5fec..e4834817 100644 --- a/test/sync_/test_label_drop.py +++ b/test/sync_/test_label_drop.py @@ -2,8 +2,7 @@ from neo4j.exceptions import ClientError -from neomodel import StringProperty, StructuredNode -from neomodel.sync_.core import db +from neomodel import StringProperty, StructuredNode, db class ConstraintAndIndex(StructuredNode): diff --git a/test/sync_/test_label_install.py b/test/sync_/test_label_install.py index 74235b74..e1d60636 100644 --- a/test/sync_/test_label_install.py +++ b/test/sync_/test_label_install.py @@ -8,9 +8,9 @@ StructuredNode, StructuredRel, UniqueIdProperty, + db, ) from neomodel.exceptions import ConstraintValidationFailed, FeatureNotSupported -from neomodel.sync_.core import db class NodeWithIndex(StructuredNode): @@ -47,6 +47,7 @@ class SomeNotUniqueNode(StructuredNode): @mark_sync_test def test_install_all(): db.drop_constraints() + db.drop_indexes() db.install_labels(AbstractNode) # run install all labels db.install_all_labels() @@ -66,6 +67,8 @@ def test_install_all(): @mark_sync_test def test_install_label_twice(capsys): + db.drop_constraints() + db.drop_indexes() expected_std_out = ( "{code: Neo.ClientError.Schema.EquivalentSchemaRuleAlreadyExists}" ) diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 85f431bd..8a92304e 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -309,7 +309,7 @@ def test_extra_filters(): ValueError, match=r".*Neo4j internals like id or element_id are not allowed for use in this operation.", ): - Coffee.nodes.filter(elementId="4:xxx:111") + Coffee.nodes.filter(elementId="4:xxx:111").all() def test_traversal_definition_keys_are_valid(): diff --git a/test/sync_/test_migration_neo4j_5.py b/test/sync_/test_migration_neo4j_5.py index 32de5d9a..10b2bac0 100644 --- a/test/sync_/test_migration_neo4j_5.py +++ b/test/sync_/test_migration_neo4j_5.py @@ -8,8 +8,8 @@ StringProperty, StructuredNode, StructuredRel, + db, ) -from neomodel.sync_.core import db class Album(StructuredNode): diff --git a/test/sync_/test_models.py b/test/sync_/test_models.py index b0e9aec3..13065c3b 100644 --- a/test/sync_/test_models.py +++ b/test/sync_/test_models.py @@ -11,9 +11,9 @@ StringProperty, StructuredNode, StructuredRel, + db, ) from neomodel.exceptions import RequiredProperty, UniqueProperty -from neomodel.sync_.core import db class User(StructuredNode): diff --git a/test/sync_/test_multiprocessing.py b/test/sync_/test_multiprocessing.py index 861b0af2..2d9167f9 100644 --- a/test/sync_/test_multiprocessing.py +++ b/test/sync_/test_multiprocessing.py @@ -1,8 +1,7 @@ from multiprocessing.pool import ThreadPool as Pool from test._async_compat import mark_sync_test -from neomodel import StringProperty, StructuredNode -from neomodel.sync_.core import db +from neomodel import StringProperty, StructuredNode, db class ThingyMaBob(StructuredNode): diff --git a/test/sync_/test_paths.py b/test/sync_/test_paths.py index b8f325f8..8e0ccf90 100644 --- a/test/sync_/test_paths.py +++ b/test/sync_/test_paths.py @@ -8,8 +8,8 @@ StructuredNode, StructuredRel, UniqueIdProperty, + db, ) -from neomodel.sync_.core import db class PersonLivesInCity(StructuredRel): diff --git a/test/sync_/test_properties.py b/test/sync_/test_properties.py index 581924aa..2f5a444e 100644 --- a/test/sync_/test_properties.py +++ b/test/sync_/test_properties.py @@ -4,7 +4,7 @@ from pytest import mark, raises from pytz import timezone -from neomodel import StructuredNode +from neomodel import StructuredNode, db from neomodel.exceptions import ( DeflateError, InflateError, @@ -24,7 +24,6 @@ StringProperty, UniqueIdProperty, ) -from neomodel.sync_.core import db from neomodel.util import _get_node_properties diff --git a/test/sync_/test_relationships.py b/test/sync_/test_relationships.py index 13ca9295..8374935f 100644 --- a/test/sync_/test_relationships.py +++ b/test/sync_/test_relationships.py @@ -12,8 +12,8 @@ StringProperty, StructuredNode, StructuredRel, + db, ) -from neomodel.sync_.core import db class PersonWithRels(StructuredNode): diff --git a/test/sync_/test_transactions.py b/test/sync_/test_transactions.py index 4c5771f8..f9c1b2b6 100644 --- a/test/sync_/test_transactions.py +++ b/test/sync_/test_transactions.py @@ -5,8 +5,7 @@ from neo4j.exceptions import ClientError, TransactionError from pytest import raises -from neomodel import StringProperty, StructuredNode, UniqueProperty -from neomodel.sync_.core import db +from neomodel import StringProperty, StructuredNode, UniqueProperty, db class APerson(StructuredNode): From a62c054a859ff8283d62e1debabd408b9ce689dd Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Sat, 16 Mar 2024 11:38:59 +0100 Subject: [PATCH 52/73] Branch tests async/sync --- neomodel/_async_compat/util.py | 9 +++ test/async_/test_match_api.py | 78 ++++++++++++++++--------- test/async_/test_relationship_models.py | 10 +++- test/sync_/test_match_api.py | 78 ++++++++++++++++--------- test/sync_/test_relationship_models.py | 10 +++- 5 files changed, 127 insertions(+), 58 deletions(-) create mode 100644 neomodel/_async_compat/util.py diff --git a/neomodel/_async_compat/util.py b/neomodel/_async_compat/util.py new file mode 100644 index 00000000..4868c3ba --- /dev/null +++ b/neomodel/_async_compat/util.py @@ -0,0 +1,9 @@ +import typing as t + + +class AsyncUtil: + is_async_code: t.ClassVar = True + + +class Util: + is_async_code: t.ClassVar = False diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 3a259b0f..eb2e40c8 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -14,6 +14,7 @@ Q, StringProperty, ) +from neomodel._async_compat.util import AsyncUtil from neomodel.async_.match import ( AsyncNodeSet, AsyncQueryBuilder, @@ -191,12 +192,19 @@ async def test_slice(): await Coffee(name="Britains finest").save() await Coffee(name="Japans finest").save() - # TODO : Branch this for sync to remove the extra brackets ? - assert len(list((await Coffee.nodes)[1:])) == 2 - assert len(list((await Coffee.nodes)[:1])) == 1 - assert isinstance((await Coffee.nodes)[1], Coffee) - assert isinstance((await Coffee.nodes)[0], Coffee) - assert len(list((await Coffee.nodes)[1:2])) == 1 + # Branching tests because async needs extra brackets + if AsyncUtil.is_async_code: + assert len(list((await Coffee.nodes)[1:])) == 2 + assert len(list((await Coffee.nodes)[:1])) == 1 + assert isinstance((await Coffee.nodes)[1], Coffee) + assert isinstance((await Coffee.nodes)[0], Coffee) + assert len(list((await Coffee.nodes)[1:2])) == 1 + else: + assert len(list(Coffee.nodes[1:])) == 2 + assert len(list(Coffee.nodes[:1])) == 1 + assert isinstance(Coffee.nodes[1], Coffee) + assert isinstance(Coffee.nodes[0], Coffee) + assert len(list(Coffee.nodes[1:2])) == 1 @mark_async_test @@ -231,15 +239,20 @@ async def test_contains(): assert expensive in await Coffee.nodes.filter(price__gt=999) assert asda not in await Coffee.nodes.filter(price__gt=999) - # TODO : Branch this for async => should be "2 in Coffee.nodes" - # Good example for documentation + # TODO : Good example for documentation # bad value raises with raises(ValueError, match=r"Expecting StructuredNode instance"): - await Coffee.nodes.check_contains(2) + if AsyncUtil.is_async_code: + assert await Coffee.nodes.check_contains(2) + else: + assert 2 in Coffee.nodes # unsaved with raises(ValueError, match=r"Unsaved node"): - await Coffee.nodes.check_contains(Coffee()) + if AsyncUtil.is_async_code: + assert await Coffee.nodes.check_contains(Coffee()) + else: + assert Coffee() in Coffee.nodes @mark_async_test @@ -251,9 +264,12 @@ async def test_order_by(): c2 = await Coffee(name="Britains finest", price=10).save() c3 = await Coffee(name="Japans finest", price=35).save() - # TODO : Branch this for sync to remove the extra brackets ? - assert ((await Coffee.nodes.order_by("price"))[0]).price == 5 - assert ((await Coffee.nodes.order_by("-price"))[0]).price == 35 + if AsyncUtil.is_async_code: + assert ((await Coffee.nodes.order_by("price"))[0]).price == 5 + assert ((await Coffee.nodes.order_by("-price"))[0]).price == 35 + else: + assert (Coffee.nodes.order_by("price")[0]).price == 5 + assert (Coffee.nodes.order_by("-price")[0]).price == 35 ns = Coffee.nodes.order_by("-price") qb = await AsyncQueryBuilder(ns).build_ast() @@ -528,17 +544,27 @@ async def test_fetch_relations(): ) assert result[0][0] is None - # TODO : Branch the following two for sync to use len() and in instead of the dunder overrides - # len() should only consider Suppliers - count = ( - await Supplier.nodes.filter(name="Sainsburys") - .fetch_relations("coffees__species") - .get_len() - ) - assert count == 1 + if AsyncUtil.is_async_code: + count = ( + await Supplier.nodes.filter(name="Sainsburys") + .fetch_relations("coffees__species") + .get_len() + ) + assert count == 1 - assert ( - await Supplier.nodes.fetch_relations("coffees__species") - .filter(name="Sainsburys") - .check_contains(tesco) - ) + assert ( + await Supplier.nodes.fetch_relations("coffees__species") + .filter(name="Sainsburys") + .check_contains(tesco) + ) + else: + count = len( + Supplier.nodes.filter(name="Sainsburys") + .fetch_relations("coffees__species") + .all() + ) + assert count == 1 + + assert tesco in Supplier.nodes.fetch_relations("coffees__species").filter( + name="Sainsburys" + ) diff --git a/test/async_/test_relationship_models.py b/test/async_/test_relationship_models.py index 1e8f632f..95fe4714 100644 --- a/test/async_/test_relationship_models.py +++ b/test/async_/test_relationship_models.py @@ -13,6 +13,7 @@ DeflateError, StringProperty, ) +from neomodel._async_compat.util import AsyncUtil HOOKS_CALLED = {"pre_save": 0, "post_save": 0} @@ -132,9 +133,12 @@ async def test_multiple_rels_exist_issue_223(): rel_b = await phill.hates.connect(ian, {"reason": "b"}) assert rel_a.element_id != rel_b.element_id - # TODO : Branch this for sync to remove extra brackets - ian_a = (await phill.hates.match(reason="a"))[0] - ian_b = (await phill.hates.match(reason="b"))[0] + if AsyncUtil.is_async_code: + ian_a = (await phill.hates.match(reason="a"))[0] + ian_b = (await phill.hates.match(reason="b"))[0] + else: + ian_a = phill.hates.match(reason="a")[0] + ian_b = phill.hates.match(reason="b")[0] assert ian_a.element_id == ian_b.element_id diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 8a92304e..bea54d06 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -14,6 +14,7 @@ StructuredNode, StructuredRel, ) +from neomodel._async_compat.util import Util from neomodel.exceptions import MultipleNodesReturned from neomodel.sync_.match import NodeSet, Optional, QueryBuilder, Traversal @@ -182,12 +183,19 @@ def test_slice(): Coffee(name="Britains finest").save() Coffee(name="Japans finest").save() - # TODO : Branch this for sync to remove the extra brackets ? - assert len(list((Coffee.nodes)[1:])) == 2 - assert len(list((Coffee.nodes)[:1])) == 1 - assert isinstance((Coffee.nodes)[1], Coffee) - assert isinstance((Coffee.nodes)[0], Coffee) - assert len(list((Coffee.nodes)[1:2])) == 1 + # Branching tests because async needs extra brackets + if Util.is_async_code: + assert len(list((Coffee.nodes)[1:])) == 2 + assert len(list((Coffee.nodes)[:1])) == 1 + assert isinstance((Coffee.nodes)[1], Coffee) + assert isinstance((Coffee.nodes)[0], Coffee) + assert len(list((Coffee.nodes)[1:2])) == 1 + else: + assert len(list(Coffee.nodes[1:])) == 2 + assert len(list(Coffee.nodes[:1])) == 1 + assert isinstance(Coffee.nodes[1], Coffee) + assert isinstance(Coffee.nodes[0], Coffee) + assert len(list(Coffee.nodes[1:2])) == 1 @mark_sync_test @@ -222,15 +230,20 @@ def test_contains(): assert expensive in Coffee.nodes.filter(price__gt=999) assert asda not in Coffee.nodes.filter(price__gt=999) - # TODO : Branch this for async => should be "2 in Coffee.nodes" - # Good example for documentation + # TODO : Good example for documentation # bad value raises with raises(ValueError, match=r"Expecting StructuredNode instance"): - Coffee.nodes.__contains__(2) + if Util.is_async_code: + assert Coffee.nodes.__contains__(2) + else: + assert 2 in Coffee.nodes # unsaved with raises(ValueError, match=r"Unsaved node"): - Coffee.nodes.__contains__(Coffee()) + if Util.is_async_code: + assert Coffee.nodes.__contains__(Coffee()) + else: + assert Coffee() in Coffee.nodes @mark_sync_test @@ -242,9 +255,12 @@ def test_order_by(): c2 = Coffee(name="Britains finest", price=10).save() c3 = Coffee(name="Japans finest", price=35).save() - # TODO : Branch this for sync to remove the extra brackets ? - assert ((Coffee.nodes.order_by("price"))[0]).price == 5 - assert ((Coffee.nodes.order_by("-price"))[0]).price == 35 + if Util.is_async_code: + assert ((Coffee.nodes.order_by("price"))[0]).price == 5 + assert ((Coffee.nodes.order_by("-price"))[0]).price == 35 + else: + assert (Coffee.nodes.order_by("price")[0]).price == 5 + assert (Coffee.nodes.order_by("-price")[0]).price == 35 ns = Coffee.nodes.order_by("-price") qb = QueryBuilder(ns).build_ast() @@ -517,17 +533,27 @@ def test_fetch_relations(): ) assert result[0][0] is None - # TODO : Branch the following two for sync to use len() and in instead of the dunder overrides - # len() should only consider Suppliers - count = ( - Supplier.nodes.filter(name="Sainsburys") - .fetch_relations("coffees__species") - .__len__() - ) - assert count == 1 + if Util.is_async_code: + count = ( + Supplier.nodes.filter(name="Sainsburys") + .fetch_relations("coffees__species") + .__len__() + ) + assert count == 1 - assert ( - Supplier.nodes.fetch_relations("coffees__species") - .filter(name="Sainsburys") - .__contains__(tesco) - ) + assert ( + Supplier.nodes.fetch_relations("coffees__species") + .filter(name="Sainsburys") + .__contains__(tesco) + ) + else: + count = len( + Supplier.nodes.filter(name="Sainsburys") + .fetch_relations("coffees__species") + .all() + ) + assert count == 1 + + assert tesco in Supplier.nodes.fetch_relations("coffees__species").filter( + name="Sainsburys" + ) diff --git a/test/sync_/test_relationship_models.py b/test/sync_/test_relationship_models.py index ffc3512f..837f53d6 100644 --- a/test/sync_/test_relationship_models.py +++ b/test/sync_/test_relationship_models.py @@ -13,6 +13,7 @@ StructuredNode, StructuredRel, ) +from neomodel._async_compat.util import Util HOOKS_CALLED = {"pre_save": 0, "post_save": 0} @@ -130,9 +131,12 @@ def test_multiple_rels_exist_issue_223(): rel_b = phill.hates.connect(ian, {"reason": "b"}) assert rel_a.element_id != rel_b.element_id - # TODO : Branch this for sync to remove extra brackets - ian_a = (phill.hates.match(reason="a"))[0] - ian_b = (phill.hates.match(reason="b"))[0] + if Util.is_async_code: + ian_a = (phill.hates.match(reason="a"))[0] + ian_b = (phill.hates.match(reason="b"))[0] + else: + ian_a = phill.hates.match(reason="a")[0] + ian_b = phill.hates.match(reason="b")[0] assert ian_a.element_id == ian_b.element_id From dd5c3735fe9de78e81b1ba0023711a8b909e47cd Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Sat, 16 Mar 2024 13:38:37 +0100 Subject: [PATCH 53/73] Update doc --- README.md | 43 ++++++++++++++++++++++++++++++++-- doc/source/configuration.rst | 8 +++---- doc/source/cypher.rst | 6 ++--- doc/source/getting_started.rst | 31 +++++++++++++++++++++--- doc/source/index.rst | 15 ++++++++++-- doc/source/transactions.rst | 6 ++--- neomodel/__init__.py | 24 +++++++++---------- pyproject.toml | 1 - test/async_/test_batch.py | 6 ++++- test/async_/test_match_api.py | 1 - test/sync_/test_batch.py | 6 ++++- test/sync_/test_match_api.py | 1 - 12 files changed, 114 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index 5d4566a5..9fe5e06b 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ GitHub repo found at . # Documentation -(Needs an update, but) Available on +Available on [readthedocs](http://neomodel.readthedocs.org). # Upcoming breaking changes notice - \>=5.3 @@ -47,7 +47,7 @@ support for Python 3.12. Another source of upcoming breaking changes is the addition async support to neomodel. No date is set yet, but the work has progressed a lot in the past weeks ; -and it will be part of a major release (potentially 6.0 to avoid misunderstandings). +and it will be part of a major release. You can see the progress in [this branch](https://github.com/neo4j-contrib/neomodel/tree/task/async). Finally, we are looking at refactoring some standalone methods into the @@ -112,3 +112,42 @@ against all supported Python interpreters and neo4j versions: : # in the project's root folder: $ sh ./tests-with-docker-compose.sh + +## Developing with async + +### Transpiling async -> sync + +We use [this great library](https://github.com/python-trio/unasync) to automatically transpile async code into its sync version. + +In other words, when contributing to neomodel, only update the `async` code in `neomodel/async_`, then run : : + + bin/make-unasync + isort . + black . + +Note that you can also use the pre-commit hooks for this. + +### Specific async/sync code +This transpiling script mainly does two things : + +- It removes the await keywords, and the Async prefixes in class names +- It does some specific replacements, like `adb`->`db`, `mark_async_test`->`mark_sync_test` + +It might be that your code should only be run for `async`, or `sync` ; or you want different stubs to be run for `async` vs `sync`. +You can use the following utility function for this - taken from the official [Neo4j python driver code](https://github.com/neo4j/neo4j-python-driver) : + + # neomodel/async_/core.py + from neomodel._async_compat.util import AsyncUtil + + # AsyncUtil.is_async_code is always True + if AsyncUtil.is_async_code: + # Specific async code + # This one gets run when in async mode + assert await Coffee.nodes.check_contains(2) + else: + # Specific sync code + # This one gest run when in sync mode + assert 2 in Coffee.nodes + +You can check [test_match_api](test/async_/test_match_api.py) for some good examples, and how it's transpiled into sync. + diff --git a/doc/source/configuration.rst b/doc/source/configuration.rst index ba590be2..3946ec98 100644 --- a/doc/source/configuration.rst +++ b/doc/source/configuration.rst @@ -59,14 +59,14 @@ Note that you have to manage the driver's lifecycle yourself. However, everything else is still handled by neomodel : sessions, transactions, etc... -NB : Only the synchronous driver will work in this way. The asynchronous driver is not supported yet. +NB : Only the synchronous driver will work in this way. See the next section for the preferred method, and how to pass an async driver instance. Change/Close the connection --------------------------- Optionally, you can change the connection at any time by calling ``set_connection``:: - from neomodel.sync_.core import db + from neomodel import db # Using URL - auto-managed db.set_connection(url='bolt://neo4j:neo4j@localhost:7687') @@ -78,7 +78,7 @@ The new connection url will be applied to the current thread or process. Since Neo4j version 5, driver auto-close is deprecated. Make sure to close the connection anytime you want to replace it, as well as at the end of your application's lifecycle by calling ``close_connection``:: - from neomodel.sync_.core import db + from neomodel import db db.close_connection() # If you then want a new connection @@ -140,7 +140,7 @@ Or for an entire 'schema' :: # ... .. note:: - config.AUTO_INSTALL_LABELS has been removed from neomodel in version 6.0 + config.AUTO_INSTALL_LABELS has been removed from neomodel in version 5.3 Require timezones on DateTimeProperty ------------------------------------- diff --git a/doc/source/cypher.rst b/doc/source/cypher.rst index ed8f422c..f8c7ccaf 100644 --- a/doc/source/cypher.rst +++ b/doc/source/cypher.rst @@ -19,7 +19,7 @@ Stand alone Outside of a `StructuredNode`:: # for standalone queries - from neomodel.sync_.core import db + from neomodel import db results, meta = db.cypher_query(query, params, resolve_objects=True) The ``resolve_objects`` parameter automatically inflates the returned nodes to their defined classes (this is turned **off** by default). See :ref:`automatic_class_resolution` for details and possible pitfalls. @@ -40,7 +40,7 @@ First, you need to install pandas by yourself. We do not include it by default t You can use the `pandas` integration to return a `DataFrame` or `Series` object:: - from neomodel.sync_.core import db + from neomodel import db from neomodel.integration.pandas import to_dataframe, to_series df = to_dataframe(db.cypher_query("MATCH (a:Person) RETURN a.name AS name, a.born AS born")) @@ -59,7 +59,7 @@ First, you need to install numpy by yourself. We do not include it by default to You can use the `numpy` integration to return a `ndarray` object:: - from neomodel.sync_.core import db + from neomodel import db from neomodel.integration.numpy import to_ndarray array = to_ndarray(db.cypher_query("MATCH (a:Person) RETURN a.name AS name, a.born AS born")) diff --git a/doc/source/getting_started.rst b/doc/source/getting_started.rst index 6aa8a421..7be18428 100644 --- a/doc/source/getting_started.rst +++ b/doc/source/getting_started.rst @@ -22,7 +22,7 @@ Querying the graph neomodel is mainly used as an OGM (see next section), but you can also use it for direct Cypher queries : :: - from neomodel.sync_.core import db + from neomodel import db results, meta = db.cypher_query("RETURN 'Hello World' as message") @@ -264,7 +264,7 @@ Async neomodel neomodel supports asynchronous operations using the async support of neo4j driver. The examples below take a few of the above examples, but rewritten for async:: - from neomodel.async_.core import adb + from neomodel import adb results, meta = await adb.cypher_query("RETURN 'Hello World' as message") OGM with async :: @@ -282,9 +282,34 @@ OGM with async :: # Operations that interact with the database are now async # Return all nodes - all_nodes = await Country.nodes.all() + # Note that the nodes object is awaitable as is + all_nodes = await Country.nodes # Relationships germany = await Country(code='DE').save() await jim.country.connect(germany) +Most _dunder_ methods for nodes and relationships had to be overriden to support async operations. The following methods are supported :: + + # Examples below are taken from the various tests. Please check them for more examples. + # Length + dogs_bonanza = await Dog.nodes.get_len() + # Sync equivalent - __len__ + dogs_bonanza = len(Dog.nodes) + + # Existence + assert not await Customer.nodes.filter(email="jim7@aol.com").check_bool() + # Sync equivalent - __bool__ + assert not Customer.nodes.filter(email="jim7@aol.com") + # Also works for check_nonzero => __nonzero__ + + # Contains + assert await Coffee.nodes.check_contains(aCoffeeNode) + # Sync equivalent - __contains__ + assert aCoffeeNode in Coffee.nodes + + # Get item + assert len(list((await Coffee.nodes)[1:])) == 2 + # Sync equivalent - __getitem__ + assert len(list(Coffee.nodes[1:])) == 2 + diff --git a/doc/source/index.rst b/doc/source/index.rst index b4e6f1f7..1338ce27 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -43,13 +43,24 @@ To install from github:: .. note:: - **Breaking changes in 6.0** + **Breaking changes in 5.3** Introducing support for asynchronous programming to neomodel required to introduce some breaking changes: - - Replace `from neomodel import db` with `from neomodel.sync_.core import db` - config.AUTO_INSTALL_LABELS has been removed. Please use the `neomodel_install_labels` (:ref:`neomodel_install_labels`) command instead. + **Deprecations in 5.3** + + - Some standalone methods are moved into the Database() class and will be removed in a future release : + - change_neo4j_password + - clear_neo4j_database + - drop_constraints + - drop_indexes + - remove_all_labels + - install_labels + - install_all_labels + - Additionally, to call these methods with async, use the ones in the AsyncDatabase() _adb_ singleton. + Contents ======== diff --git a/doc/source/transactions.rst b/doc/source/transactions.rst index 1ca80b08..dfa97ee6 100644 --- a/doc/source/transactions.rst +++ b/doc/source/transactions.rst @@ -13,7 +13,7 @@ Basic usage Transactions can be used via a context manager:: - from neomodel.sync_.core import db + from neomodel import db with db.transaction: Person(name='Bob').save() @@ -171,7 +171,7 @@ Impersonation Impersonation (`see Neo4j driver documentation ``) can be enabled via a context manager:: - from neomodel.sync_.core import db + from neomodel import db with db.impersonate(user="writeuser"): Person(name='Bob').save() @@ -186,7 +186,7 @@ or as a function decorator:: This can be mixed with other context manager like transactions:: - from neomodel.sync_.core import db + from neomodel import db @db.impersonate(user="tempuser") # Both transactions will be run as the same impersonated user diff --git a/neomodel/__init__.py b/neomodel/__init__.py index 58b39178..e10f571f 100644 --- a/neomodel/__init__.py +++ b/neomodel/__init__.py @@ -6,17 +6,7 @@ AsyncZeroOrMore, AsyncZeroOrOne, ) -from neomodel.async_.core import ( - AsyncStructuredNode, - adb, - change_neo4j_password, - clear_neo4j_database, - drop_constraints, - drop_indexes, - install_all_labels, - install_labels, - remove_all_labels, -) +from neomodel.async_.core import AsyncStructuredNode, adb from neomodel.async_.match import AsyncNodeSet, AsyncTraversal from neomodel.async_.path import AsyncNeomodelPath from neomodel.async_.relationship import AsyncStructuredRel @@ -46,7 +36,17 @@ UniqueIdProperty, ) from neomodel.sync_.cardinality import One, OneOrMore, ZeroOrMore, ZeroOrOne -from neomodel.sync_.core import StructuredNode, db +from neomodel.sync_.core import ( + StructuredNode, + change_neo4j_password, + clear_neo4j_database, + db, + drop_constraints, + drop_indexes, + install_all_labels, + install_labels, + remove_all_labels, +) from neomodel.sync_.match import NodeSet, Traversal from neomodel.sync_.path import NeomodelPath from neomodel.sync_.property_manager import PropertyManager diff --git a/pyproject.toml b/pyproject.toml index 1d9dcc67..1ac483bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,6 @@ authors = [ maintainers = [ {name = "Marius Conjeaud", email = "marius.conjeaud@outlook.com"}, {name = "Athanasios Anastasiou", email = "athanastasiou@gmail.com"}, - {name = "Cristina Escalante"}, ] description = "An object mapper for the neo4j graph database." readme = "README.md" diff --git a/test/async_/test_batch.py b/test/async_/test_batch.py index a1d86e21..653dce0d 100644 --- a/test/async_/test_batch.py +++ b/test/async_/test_batch.py @@ -11,6 +11,7 @@ UniqueIdProperty, config, ) +from neomodel._async_compat.util import AsyncUtil from neomodel.exceptions import DeflateError, UniqueProperty config.AUTO_INSTALL_LABELS = True @@ -107,7 +108,10 @@ async def test_batch_index_violation(): ) # not found - assert not await Customer.nodes.filter(email="jim7@aol.com").check_bool() + if AsyncUtil.is_async_code: + assert not await Customer.nodes.filter(email="jim7@aol.com").check_bool() + else: + assert not Customer.nodes.filter(email="jim7@aol.com") class Dog(AsyncStructuredNode): diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index eb2e40c8..87c1ca8e 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -239,7 +239,6 @@ async def test_contains(): assert expensive in await Coffee.nodes.filter(price__gt=999) assert asda not in await Coffee.nodes.filter(price__gt=999) - # TODO : Good example for documentation # bad value raises with raises(ValueError, match=r"Expecting StructuredNode instance"): if AsyncUtil.is_async_code: diff --git a/test/sync_/test_batch.py b/test/sync_/test_batch.py index 3eabe65e..80812d31 100644 --- a/test/sync_/test_batch.py +++ b/test/sync_/test_batch.py @@ -11,6 +11,7 @@ UniqueIdProperty, config, ) +from neomodel._async_compat.util import Util from neomodel.exceptions import DeflateError, UniqueProperty config.AUTO_INSTALL_LABELS = True @@ -103,7 +104,10 @@ def test_batch_index_violation(): ) # not found - assert not Customer.nodes.filter(email="jim7@aol.com").__bool__() + if Util.is_async_code: + assert not Customer.nodes.filter(email="jim7@aol.com").__bool__() + else: + assert not Customer.nodes.filter(email="jim7@aol.com") class Dog(StructuredNode): diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index bea54d06..fe63badd 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -230,7 +230,6 @@ def test_contains(): assert expensive in Coffee.nodes.filter(price__gt=999) assert asda not in Coffee.nodes.filter(price__gt=999) - # TODO : Good example for documentation # bad value raises with raises(ValueError, match=r"Expecting StructuredNode instance"): if Util.is_async_code: From 2d7028b313cdedec57386c5937591d73b74024b2 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Sat, 16 Mar 2024 14:09:17 +0100 Subject: [PATCH 54/73] Remove flay type hint --- neomodel/async_/core.py | 4 ++-- neomodel/sync_/core.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py index c61671bc..f5fe628e 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/core.py @@ -418,7 +418,7 @@ async def cypher_query( handle_unique=True, retry_on_session_expire=False, resolve_objects=False, - ) -> Tuple[list[list], Tuple[str, ...]]: + ): """ Runs a query on the database and returns a list of results and their headers. @@ -471,7 +471,7 @@ async def _run_cypher_query( handle_unique, retry_on_session_expire, resolve_objects, - ) -> Tuple[list[list], Tuple[str, ...]]: + ): try: # Retrieve the data start = time.time() diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index 8f38d913..1b9d97e7 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -416,7 +416,7 @@ def cypher_query( handle_unique=True, retry_on_session_expire=False, resolve_objects=False, - ) -> Tuple[list[list], Tuple[str, ...]]: + ): """ Runs a query on the database and returns a list of results and their headers. @@ -469,7 +469,7 @@ def _run_cypher_query( handle_unique, retry_on_session_expire, resolve_objects, - ) -> Tuple[list[list], Tuple[str, ...]]: + ): try: # Retrieve the data start = time.time() From df0efa06049d26b5b8eed9bc94cf96cf31d8f507 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Mon, 18 Mar 2024 11:04:57 +0100 Subject: [PATCH 55/73] FIx element id parsing --- neomodel/async_/core.py | 13 ++++++++----- neomodel/async_/match.py | 5 +++-- neomodel/async_/relationship.py | 14 +++++++++++--- neomodel/async_/relationship_manager.py | 20 ++++++++++++++------ neomodel/sync_/core.py | 13 ++++++++----- neomodel/sync_/match.py | 5 +++-- neomodel/sync_/relationship.py | 6 +++--- neomodel/sync_/relationship_manager.py | 14 ++++++++------ test/async_/test_issue283.py | 2 ++ test/async_/test_label_install.py | 9 ++++++--- test/async_/test_migration_neo4j_5.py | 4 ++-- test/async_/test_relationships.py | 2 +- test/sync_/test_issue283.py | 2 ++ test/sync_/test_label_install.py | 3 +++ test/sync_/test_migration_neo4j_5.py | 4 ++-- test/sync_/test_relationships.py | 2 +- 16 files changed, 77 insertions(+), 41 deletions(-) diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py index f5fe628e..9f4e5e6c 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/core.py @@ -523,6 +523,10 @@ async def get_id_method(self) -> str: else: return "elementId" + async def parse_element_id(self, element_id: str): + db_version = await self.database_version + return int(element_id) if db_version.startswith("4") else element_id + async def list_indexes(self, exclude_token_lookup=False) -> Sequence[dict]: """Returns all indexes existing in the database @@ -1188,7 +1192,9 @@ async def _build_merge_query( from neomodel.async_.match import _rel_helper - query_params["source_id"] = relationship.source.element_id + query_params["source_id"] = await adb.parse_element_id( + relationship.source.element_id + ) query = f"MATCH (source:{relationship.source.__label__}) WHERE {await adb.get_id_method()}(source) = $source_id\n " query += "WITH source\n UNWIND $merge_params as params \n " query += "MERGE " @@ -1318,10 +1324,7 @@ async def cypher(self, query, params=None): """ self._pre_action_check("cypher") params = params or {} - db_version = await adb.database_version - element_id = ( - int(self.element_id) if db_version.startswith("4") else self.element_id - ) + element_id = await adb.parse_element_id(self.element_id) params.update({"self": element_id}) return await adb.cypher_query(query, params) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 9e7e5852..4827b36c 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -504,7 +504,7 @@ async def build_node(self, node): _node_lookup = f"MATCH ({ident}) WHERE {await adb.get_id_method()}({ident})=${place_holder} WITH {ident}" self._ast.lookup = _node_lookup - self._query_params[place_holder] = node.element_id + self._query_params[place_holder] = await adb.parse_element_id(node.element_id) self._ast.return_clause = ident self._ast.result_class = node.__class__ @@ -776,7 +776,8 @@ async def check_contains(self, obj): if isinstance(obj, AsyncStructuredNode): if hasattr(obj, "element_id") and obj.element_id is not None: ast = await self.query_cls(self).build_ast() - return await ast._contains(obj.element_id) + obj_element_id = await adb.parse_element_id(obj.element_id) + return await ast._contains(obj_element_id) raise ValueError("Unsaved node: " + repr(obj)) raise ValueError("Expecting StructuredNode instance") diff --git a/neomodel/async_/relationship.py b/neomodel/async_/relationship.py index cb976bf1..f684b0c0 100644 --- a/neomodel/async_/relationship.py +++ b/neomodel/async_/relationship.py @@ -102,7 +102,7 @@ async def save(self): props = self.deflate(self.__properties__) query = f"MATCH ()-[r]->() WHERE {await adb.get_id_method()}(r)=$self " query += "".join([f" SET r.{key} = ${key}" for key in props]) - props["self"] = self.element_id + props["self"] = await adb.parse_element_id(self.element_id) await adb.cypher_query(query, props) @@ -120,7 +120,11 @@ async def start_node(self): WHERE {await adb.get_id_method()}(aNode)=$start_node_element_id RETURN aNode """, - {"start_node_element_id": self._start_node_element_id}, + { + "start_node_element_id": await adb.parse_element_id( + self._start_node_element_id + ) + }, resolve_objects=True, ) return results[0][0][0] @@ -137,7 +141,11 @@ async def end_node(self): WHERE {await adb.get_id_method()}(aNode)=$end_node_element_id RETURN aNode """, - {"end_node_element_id": self._end_node_element_id}, + { + "end_node_element_id": await adb.parse_element_id( + self._end_node_element_id + ) + }, resolve_objects=True, ) return results[0][0][0] diff --git a/neomodel/async_/relationship_manager.py b/neomodel/async_/relationship_manager.py index 80113f94..9c5e7398 100644 --- a/neomodel/async_/relationship_manager.py +++ b/neomodel/async_/relationship_manager.py @@ -130,7 +130,7 @@ async def connect(self, node, properties=None): "MERGE" + new_rel ) - params["them"] = node.element_id + params["them"] = await adb.parse_element_id(node.element_id) if not rel_model: await self.source.cypher(q, params) @@ -173,7 +173,9 @@ async def relationship(self, node): + my_rel + f" WHERE {await adb.get_id_method()}(them)=$them and {await adb.get_id_method()}(us)=$self RETURN r LIMIT 1" ) - results = await self.source.cypher(q, {"them": node.element_id}) + results = await self.source.cypher( + q, {"them": await adb.parse_element_id(node.element_id)} + ) rels = results[0] if not rels: return @@ -194,7 +196,9 @@ async def all_relationships(self, node): my_rel = _rel_helper(lhs="us", rhs="them", ident="r", **self.definition) q = f"MATCH {my_rel} WHERE {await adb.get_id_method()}(them)=$them and {await adb.get_id_method()}(us)=$self RETURN r " - results = await self.source.cypher(q, {"them": node.element_id}) + results = await self.source.cypher( + q, {"them": await adb.parse_element_id(node.element_id)} + ) rels = results[0] if not rels: return [] @@ -232,12 +236,14 @@ async def reconnect(self, old_node, new_node): old_rel = _rel_helper(lhs="us", rhs="old", ident="r", **self.definition) # get list of properties on the existing rel + old_node_element_id = await adb.parse_element_id(old_node.element_id) + new_node_element_id = await adb.parse_element_id(new_node.element_id) result, _ = await self.source.cypher( f""" MATCH (us), (old) WHERE {await adb.get_id_method()}(us)=$self and {await adb.get_id_method()}(old)=$old MATCH {old_rel} RETURN r """, - {"old": old_node.element_id}, + {"old": old_node_element_id}, ) if result: node_properties = _get_node_properties(result[0][0]) @@ -259,7 +265,7 @@ async def reconnect(self, old_node, new_node): q += " WITH r DELETE r" await self.source.cypher( - q, {"old": old_node.element_id, "new": new_node.element_id} + q, {"old": old_node_element_id, "new": new_node_element_id} ) @check_source @@ -275,7 +281,9 @@ async def disconnect(self, node): MATCH (a), (b) WHERE {await adb.get_id_method()}(a)=$self and {await adb.get_id_method()}(b)=$them MATCH {rel} DELETE r """ - await self.source.cypher(q, {"them": node.element_id}) + await self.source.cypher( + q, {"them": await adb.parse_element_id(node.element_id)} + ) @check_source async def disconnect_all(self): diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index 1b9d97e7..1034da56 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -521,6 +521,10 @@ def get_id_method(self) -> str: else: return "elementId" + def parse_element_id(self, element_id: str): + db_version = self.database_version + return int(element_id) if db_version.startswith("4") else element_id + def list_indexes(self, exclude_token_lookup=False) -> Sequence[dict]: """Returns all indexes existing in the database @@ -1184,7 +1188,9 @@ def _build_merge_query( from neomodel.sync_.match import _rel_helper - query_params["source_id"] = relationship.source.element_id + query_params["source_id"] = db.parse_element_id( + relationship.source.element_id + ) query = f"MATCH (source:{relationship.source.__label__}) WHERE {db.get_id_method()}(source) = $source_id\n " query += "WITH source\n UNWIND $merge_params as params \n " query += "MERGE " @@ -1314,10 +1320,7 @@ def cypher(self, query, params=None): """ self._pre_action_check("cypher") params = params or {} - db_version = db.database_version - element_id = ( - int(self.element_id) if db_version.startswith("4") else self.element_id - ) + element_id = db.parse_element_id(self.element_id) params.update({"self": element_id}) return db.cypher_query(query, params) diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 3ecd3e43..abc5f54d 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -504,7 +504,7 @@ def build_node(self, node): _node_lookup = f"MATCH ({ident}) WHERE {db.get_id_method()}({ident})=${place_holder} WITH {ident}" self._ast.lookup = _node_lookup - self._query_params[place_holder] = node.element_id + self._query_params[place_holder] = db.parse_element_id(node.element_id) self._ast.return_clause = ident self._ast.result_class = node.__class__ @@ -772,7 +772,8 @@ def __contains__(self, obj): if isinstance(obj, StructuredNode): if hasattr(obj, "element_id") and obj.element_id is not None: ast = self.query_cls(self).build_ast() - return ast._contains(obj.element_id) + obj_element_id = db.parse_element_id(obj.element_id) + return ast._contains(obj_element_id) raise ValueError("Unsaved node: " + repr(obj)) raise ValueError("Expecting StructuredNode instance") diff --git a/neomodel/sync_/relationship.py b/neomodel/sync_/relationship.py index 5f0e3f8f..1dbca709 100644 --- a/neomodel/sync_/relationship.py +++ b/neomodel/sync_/relationship.py @@ -102,7 +102,7 @@ def save(self): props = self.deflate(self.__properties__) query = f"MATCH ()-[r]->() WHERE {db.get_id_method()}(r)=$self " query += "".join([f" SET r.{key} = ${key}" for key in props]) - props["self"] = self.element_id + props["self"] = db.parse_element_id(self.element_id) db.cypher_query(query, props) @@ -120,7 +120,7 @@ def start_node(self): WHERE {db.get_id_method()}(aNode)=$start_node_element_id RETURN aNode """, - {"start_node_element_id": self._start_node_element_id}, + {"start_node_element_id": db.parse_element_id(self._start_node_element_id)}, resolve_objects=True, ) return results[0][0][0] @@ -137,7 +137,7 @@ def end_node(self): WHERE {db.get_id_method()}(aNode)=$end_node_element_id RETURN aNode """, - {"end_node_element_id": self._end_node_element_id}, + {"end_node_element_id": db.parse_element_id(self._end_node_element_id)}, resolve_objects=True, ) return results[0][0][0] diff --git a/neomodel/sync_/relationship_manager.py b/neomodel/sync_/relationship_manager.py index f96e7b25..8b8ae96a 100644 --- a/neomodel/sync_/relationship_manager.py +++ b/neomodel/sync_/relationship_manager.py @@ -125,7 +125,7 @@ def connect(self, node, properties=None): "MERGE" + new_rel ) - params["them"] = node.element_id + params["them"] = db.parse_element_id(node.element_id) if not rel_model: self.source.cypher(q, params) @@ -168,7 +168,7 @@ def relationship(self, node): + my_rel + f" WHERE {db.get_id_method()}(them)=$them and {db.get_id_method()}(us)=$self RETURN r LIMIT 1" ) - results = self.source.cypher(q, {"them": node.element_id}) + results = self.source.cypher(q, {"them": db.parse_element_id(node.element_id)}) rels = results[0] if not rels: return @@ -189,7 +189,7 @@ def all_relationships(self, node): my_rel = _rel_helper(lhs="us", rhs="them", ident="r", **self.definition) q = f"MATCH {my_rel} WHERE {db.get_id_method()}(them)=$them and {db.get_id_method()}(us)=$self RETURN r " - results = self.source.cypher(q, {"them": node.element_id}) + results = self.source.cypher(q, {"them": db.parse_element_id(node.element_id)}) rels = results[0] if not rels: return [] @@ -227,12 +227,14 @@ def reconnect(self, old_node, new_node): old_rel = _rel_helper(lhs="us", rhs="old", ident="r", **self.definition) # get list of properties on the existing rel + old_node_element_id = db.parse_element_id(old_node.element_id) + new_node_element_id = db.parse_element_id(new_node.element_id) result, _ = self.source.cypher( f""" MATCH (us), (old) WHERE {db.get_id_method()}(us)=$self and {db.get_id_method()}(old)=$old MATCH {old_rel} RETURN r """, - {"old": old_node.element_id}, + {"old": old_node_element_id}, ) if result: node_properties = _get_node_properties(result[0][0]) @@ -253,7 +255,7 @@ def reconnect(self, old_node, new_node): q += "".join([f" SET r2.{prop} = r.{prop}" for prop in existing_properties]) q += " WITH r DELETE r" - self.source.cypher(q, {"old": old_node.element_id, "new": new_node.element_id}) + self.source.cypher(q, {"old": old_node_element_id, "new": new_node_element_id}) @check_source def disconnect(self, node): @@ -268,7 +270,7 @@ def disconnect(self, node): MATCH (a), (b) WHERE {db.get_id_method()}(a)=$self and {db.get_id_method()}(b)=$them MATCH {rel} DELETE r """ - self.source.cypher(q, {"them": node.element_id}) + self.source.cypher(q, {"them": db.parse_element_id(node.element_id)}) @check_source def disconnect_all(self): diff --git a/test/async_/test_issue283.py b/test/async_/test_issue283.py index 097f0d57..8106a796 100644 --- a/test/async_/test_issue283.py +++ b/test/async_/test_issue283.py @@ -116,6 +116,8 @@ async def test_automatic_result_resolution(): await B.friends_with.connect(C) await C.friends_with.connect(A) + test = await A.friends_with + # If A is friends with B, then A's friends_with objects should be # TechnicalPerson (!NOT basePerson!) assert type((await A.friends_with)[0]) is TechnicalPerson diff --git a/test/async_/test_label_install.py b/test/async_/test_label_install.py index dc3e961d..2d710a19 100644 --- a/test/async_/test_label_install.py +++ b/test/async_/test_label_install.py @@ -121,8 +121,9 @@ async def test_install_labels_db_property(capsys): await _drop_constraints_for_label_and_property("SomeNotUniqueNode", "id") -def test_relationship_unique_index_not_supported(): - if adb.version_is_higher_than("5.7"): +@mark_async_test +async def test_relationship_unique_index_not_supported(): + if await adb.version_is_higher_than("5.7"): pytest.skip("Not supported before 5.7") class UniqueIndexRelationship(AsyncStructuredRel): @@ -142,10 +143,12 @@ class NodeWithUniqueIndexRelationship(AsyncStructuredNode): model=UniqueIndexRelationship, ) + await adb.install_labels(NodeWithUniqueIndexRelationship) + @mark_async_test async def test_relationship_unique_index(): - if not adb.version_is_higher_than("5.7"): + if not await adb.version_is_higher_than("5.7"): pytest.skip("Not supported before 5.7") class UniqueIndexRelationshipBis(AsyncStructuredRel): diff --git a/test/async_/test_migration_neo4j_5.py b/test/async_/test_migration_neo4j_5.py index 7b3397dd..48c0e8c4 100644 --- a/test/async_/test_migration_neo4j_5.py +++ b/test/async_/test_migration_neo4j_5.py @@ -44,8 +44,8 @@ async def test_read_elements_id(): assert lex_hives.id == int(lex_hives.element_id) assert lex_hives.id == (await the_hives.released.single()).id # Relationships' ids - assert isinstance(released_rel.element_id, int) - assert released_rel.element_id == released_rel.id + assert isinstance(released_rel.element_id, str) + assert int(released_rel.element_id) == released_rel.id assert released_rel._start_node_id == int(the_hives.element_id) assert released_rel._end_node_id == int(lex_hives.element_id) else: diff --git a/test/async_/test_relationships.py b/test/async_/test_relationships.py index df50e701..9835e8b7 100644 --- a/test/async_/test_relationships.py +++ b/test/async_/test_relationships.py @@ -99,7 +99,7 @@ async def test_either_direction_connect(): f"""MATCH (us), (them) WHERE {await adb.get_id_method()}(us)=$self and {await adb.get_id_method()}(them)=$them MATCH (us)-[r:KNOWS]-(them) RETURN COUNT(r)""", - {"them": rey.element_id}, + {"them": await adb.parse_element_id(rey.element_id)}, ) assert int(result[0][0]) == 1 diff --git a/test/sync_/test_issue283.py b/test/sync_/test_issue283.py index 877efe0f..a059f7f2 100644 --- a/test/sync_/test_issue283.py +++ b/test/sync_/test_issue283.py @@ -110,6 +110,8 @@ def test_automatic_result_resolution(): B.friends_with.connect(C) C.friends_with.connect(A) + test = A.friends_with + # If A is friends with B, then A's friends_with objects should be # TechnicalPerson (!NOT basePerson!) assert type((A.friends_with)[0]) is TechnicalPerson diff --git a/test/sync_/test_label_install.py b/test/sync_/test_label_install.py index e1d60636..14bfe107 100644 --- a/test/sync_/test_label_install.py +++ b/test/sync_/test_label_install.py @@ -121,6 +121,7 @@ def test_install_labels_db_property(capsys): _drop_constraints_for_label_and_property("SomeNotUniqueNode", "id") +@mark_sync_test def test_relationship_unique_index_not_supported(): if db.version_is_higher_than("5.7"): pytest.skip("Not supported before 5.7") @@ -142,6 +143,8 @@ class NodeWithUniqueIndexRelationship(StructuredNode): model=UniqueIndexRelationship, ) + db.install_labels(NodeWithUniqueIndexRelationship) + @mark_sync_test def test_relationship_unique_index(): diff --git a/test/sync_/test_migration_neo4j_5.py b/test/sync_/test_migration_neo4j_5.py index 10b2bac0..83e090cb 100644 --- a/test/sync_/test_migration_neo4j_5.py +++ b/test/sync_/test_migration_neo4j_5.py @@ -44,8 +44,8 @@ def test_read_elements_id(): assert lex_hives.id == int(lex_hives.element_id) assert lex_hives.id == (the_hives.released.single()).id # Relationships' ids - assert isinstance(released_rel.element_id, int) - assert released_rel.element_id == released_rel.id + assert isinstance(released_rel.element_id, str) + assert int(released_rel.element_id) == released_rel.id assert released_rel._start_node_id == int(the_hives.element_id) assert released_rel._end_node_id == int(lex_hives.element_id) else: diff --git a/test/sync_/test_relationships.py b/test/sync_/test_relationships.py index 8374935f..39057a39 100644 --- a/test/sync_/test_relationships.py +++ b/test/sync_/test_relationships.py @@ -99,7 +99,7 @@ def test_either_direction_connect(): f"""MATCH (us), (them) WHERE {db.get_id_method()}(us)=$self and {db.get_id_method()}(them)=$them MATCH (us)-[r:KNOWS]-(them) RETURN COUNT(r)""", - {"them": rey.element_id}, + {"them": db.parse_element_id(rey.element_id)}, ) assert int(result[0][0]) == 1 From 3219d0838c8eba6bea2d9e33d10214a091dd2d66 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Tue, 19 Mar 2024 10:29:12 +0100 Subject: [PATCH 56/73] Fix pre-commit --- .pre-commit-config.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dd58a3c9..9bcb97ad 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,8 +1,4 @@ repos: - - repo: https://github.com/PyCQA/isort - rev: 5.11.5 - hooks: - - id: isort - repo: local hooks: - id: unasync @@ -10,6 +6,10 @@ repos: entry: bin/make-unasync language: system files: "^(neomodel/async_|test/async_)/.*" + - repo: https://github.com/PyCQA/isort + rev: 5.11.5 + hooks: + - id: isort - repo: https://github.com/psf/black rev: 23.3.0 hooks: From bb03faaf2692625601ec0ff1511e2af04d3d98dc Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Mon, 25 Mar 2024 14:08:56 +0100 Subject: [PATCH 57/73] Apply isort in unasync --- .pre-commit-config.yaml | 12 +++++++----- bin/make-unasync | 17 +++++++++++++++++ pyproject.toml | 2 +- requirements-dev.txt | 1 + test/async_/test_alias.py | 4 ++-- test/sync_/test_alias.py | 4 ++-- 6 files changed, 30 insertions(+), 10 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9bcb97ad..5123f628 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,15 +1,17 @@ repos: + - repo: https://github.com/PyCQA/isort + rev: 5.11.5 + hooks: + - id: isort + args: ["--profile", "black"] - repo: local hooks: - id: unasync name: unasync entry: bin/make-unasync - language: system + language: python files: "^(neomodel/async_|test/async_)/.*" - - repo: https://github.com/PyCQA/isort - rev: 5.11.5 - hooks: - - id: isort + additional_dependencies: [unasync, isort] - repo: https://github.com/psf/black rev: 23.3.0 hooks: diff --git a/bin/make-unasync b/bin/make-unasync index a9d1d6b7..12241091 100755 --- a/bin/make-unasync +++ b/bin/make-unasync @@ -8,6 +8,8 @@ import sys import tokenize as std_tokenize from pathlib import Path +import isort +import isort.files import unasync ROOT_DIR = Path(__file__).parents[1].absolute() @@ -258,6 +260,20 @@ def apply_unasync(files): return [Path(path) for rule in rules for path in rule.out_files] +def apply_isort(paths): + """Sort imports in generated sync code. + + Since classes in imports are renamed from AsyncXyz to Xyz, the alphabetical + order of the import can change. + """ + isort_config = isort.Config(settings_path=str(ROOT_DIR), quiet=True) + + for path in paths: + isort.file(str(path), config=isort_config) + + return paths + + def apply_changes(paths): def files_equal(path1, path2): with open(path1, "rb") as f1: @@ -295,6 +311,7 @@ def main(): if len(sys.argv) >= 1: files = sys.argv[1:] paths = apply_unasync(files) + paths = apply_isort(paths) changed_paths = apply_changes(paths) if changed_paths: diff --git a/pyproject.toml b/pyproject.toml index 1ac483bf..acc36850 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,7 @@ testpaths = "test" [tool.isort] profile = 'black' -src_paths = ['neomodel'] +src_paths = ['neomodel','test'] [tool.pylint.'MESSAGES CONTROL'] disable = 'missing-module-docstring,redefined-builtin,missing-class-docstring,missing-function-docstring,consider-using-f-string,line-too-long' diff --git a/requirements-dev.txt b/requirements-dev.txt index 2e7a31f0..bf3fa116 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,7 @@ # neomodel -e .[pandas,numpy] +unasync>=0.5.0 pytest>=7.1 pytest-cov>=4.0 pre-commit diff --git a/test/async_/test_alias.py b/test/async_/test_alias.py index 3b0f6529..b2b9a3a8 100644 --- a/test/async_/test_alias.py +++ b/test/async_/test_alias.py @@ -16,9 +16,9 @@ class AliasTestNode(AsyncStructuredNode): @mark_async_test async def test_property_setup_hook(): - tim = await AliasTestNode(long_name="tim").save() + timmy = await AliasTestNode(long_name="timmy").save() assert AliasTestNode.setup_hook_called - assert tim.name == "tim" + assert timmy.name == "timmy" @mark_async_test diff --git a/test/sync_/test_alias.py b/test/sync_/test_alias.py index f266eb82..420e62a0 100644 --- a/test/sync_/test_alias.py +++ b/test/sync_/test_alias.py @@ -16,9 +16,9 @@ class AliasTestNode(StructuredNode): @mark_sync_test def test_property_setup_hook(): - tim = AliasTestNode(long_name="tim").save() + timmy = AliasTestNode(long_name="timmy").save() assert AliasTestNode.setup_hook_called - assert tim.name == "tim" + assert timmy.name == "timmy" @mark_sync_test From 0dc41b3f0bc76ba7e881b5693d2a4a4189fd7fad Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Mon, 25 Mar 2024 16:33:45 +0100 Subject: [PATCH 58/73] Fix transactions ; Improve make-unasync --- .pre-commit-config.yaml | 10 ++++----- bin/make-unasync | 24 ++++++++++++++++++++- neomodel/async_/core.py | 36 +++++++++++++++++++------------- neomodel/sync_/core.py | 36 +++++++++++++++++++------------- test/async_/test_cypher.py | 9 ++++++++ test/async_/test_transactions.py | 18 +++++++--------- test/sync_/test_cypher.py | 9 ++++++++ test/sync_/test_transactions.py | 12 ++++------- 8 files changed, 101 insertions(+), 53 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5123f628..22db6a97 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,6 +4,10 @@ repos: hooks: - id: isort args: ["--profile", "black"] + - repo: https://github.com/psf/black + rev: 23.3.0 + hooks: + - id: black - repo: local hooks: - id: unasync @@ -11,8 +15,4 @@ repos: entry: bin/make-unasync language: python files: "^(neomodel/async_|test/async_)/.*" - additional_dependencies: [unasync, isort] - - repo: https://github.com/psf/black - rev: 23.3.0 - hooks: - - id: black \ No newline at end of file + additional_dependencies: [unasync, isort, black] \ No newline at end of file diff --git a/bin/make-unasync b/bin/make-unasync index 12241091..917353f8 100755 --- a/bin/make-unasync +++ b/bin/make-unasync @@ -8,6 +8,7 @@ import sys import tokenize as std_tokenize from pathlib import Path +import black import isort import isort.files import unasync @@ -260,13 +261,33 @@ def apply_unasync(files): return [Path(path) for rule in rules for path in rule.out_files] +def apply_black(paths): + """Prettify generated sync code. + + Since keywords are removed, black might expect a different result, + especially line breaks. + """ + for path in paths: + with open(path, "r") as file: + code = file.read() + + formatted_code = black.format_str(code, mode=black.FileMode()) + + with open(path, "w") as file: + file.write(formatted_code) + + return paths + + def apply_isort(paths): """Sort imports in generated sync code. Since classes in imports are renamed from AsyncXyz to Xyz, the alphabetical order of the import can change. """ - isort_config = isort.Config(settings_path=str(ROOT_DIR), quiet=True) + isort_config = isort.Config( + settings_path=str(ROOT_DIR), quiet=True, profile="black" + ) for path in paths: isort.file(str(path), config=isort_config) @@ -312,6 +333,7 @@ def main(): files = sys.argv[1:] paths = apply_unasync(files) paths = apply_isort(paths) + paths = apply_black(paths) changed_paths = apply_changes(paths) if changed_paths: diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py index 9f4e5e6c..d92c8dd1 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/core.py @@ -3,6 +3,7 @@ import sys import time import warnings +from asyncio import iscoroutinefunction from itertools import combinations from threading import local from typing import Optional, Sequence, Tuple @@ -48,6 +49,7 @@ INDEX_ALREADY_EXISTS = "Neo.ClientError.Schema.IndexAlreadyExists" CONSTRAINT_ALREADY_EXISTS = "Neo.ClientError.Schema.ConstraintAlreadyExists" STREAMING_WARNING = "streaming is not supported by bolt, please remove the kwarg" +NOT_COROUTINE_ERROR = "The decorated function must be a coroutine" # make sure the connection url has been set prior to executing the wrapped function @@ -951,6 +953,9 @@ async def __aexit__(self, exc_type, exc_value, traceback): self.last_bookmark = await self.db.commit() def __call__(self, func): + if not iscoroutinefunction(func): + raise TypeError(NOT_COROUTINE_ERROR) + async def wrapper(*args, **kwargs): async with self: print("call called") @@ -963,6 +968,23 @@ def with_bookmark(self): return BookmarkingAsyncTransactionProxy(self.db, self.access_mode) +class BookmarkingAsyncTransactionProxy(AsyncTransactionProxy): + def __call__(self, func): + if not iscoroutinefunction(func): + raise TypeError(NOT_COROUTINE_ERROR) + + async def wrapper(*args, **kwargs): + self.bookmarks = kwargs.pop("bookmarks", None) + + async with self: + result = await func(*args, **kwargs) + self.last_bookmark = None + + return result, self.last_bookmark + + return wrapper + + class ImpersonationHandler: def __init__(self, db: AsyncDatabase, impersonated_user: str): self.db = db @@ -987,20 +1009,6 @@ def wrapper(*args, **kwargs): return wrapper -class BookmarkingAsyncTransactionProxy(AsyncTransactionProxy): - def __call__(self, func): - def wrapper(*args, **kwargs): - self.bookmarks = kwargs.pop("bookmarks", None) - - with self: - result = func(*args, **kwargs) - self.last_bookmark = None - - return result, self.last_bookmark - - return wrapper - - class NodeMeta(type): def __new__(mcs, name, bases, namespace): namespace["DoesNotExist"] = type(name + "DoesNotExist", (DoesNotExist,), {}) diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index 1034da56..2fbec35b 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -3,6 +3,7 @@ import sys import time import warnings +from asyncio import iscoroutinefunction from itertools import combinations from threading import local from typing import Optional, Sequence, Tuple @@ -48,6 +49,7 @@ INDEX_ALREADY_EXISTS = "Neo.ClientError.Schema.IndexAlreadyExists" CONSTRAINT_ALREADY_EXISTS = "Neo.ClientError.Schema.ConstraintAlreadyExists" STREAMING_WARNING = "streaming is not supported by bolt, please remove the kwarg" +NOT_COROUTINE_ERROR = "The decorated function must be a coroutine" # make sure the connection url has been set prior to executing the wrapped function @@ -947,6 +949,9 @@ def __exit__(self, exc_type, exc_value, traceback): self.last_bookmark = self.db.commit() def __call__(self, func): + if not iscoroutinefunction(func): + raise TypeError(NOT_COROUTINE_ERROR) + def wrapper(*args, **kwargs): with self: print("call called") @@ -959,6 +964,23 @@ def with_bookmark(self): return BookmarkingAsyncTransactionProxy(self.db, self.access_mode) +class BookmarkingAsyncTransactionProxy(TransactionProxy): + def __call__(self, func): + if not iscoroutinefunction(func): + raise TypeError(NOT_COROUTINE_ERROR) + + def wrapper(*args, **kwargs): + self.bookmarks = kwargs.pop("bookmarks", None) + + with self: + result = func(*args, **kwargs) + self.last_bookmark = None + + return result, self.last_bookmark + + return wrapper + + class ImpersonationHandler: def __init__(self, db: Database, impersonated_user: str): self.db = db @@ -983,20 +1005,6 @@ def wrapper(*args, **kwargs): return wrapper -class BookmarkingAsyncTransactionProxy(TransactionProxy): - def __call__(self, func): - def wrapper(*args, **kwargs): - self.bookmarks = kwargs.pop("bookmarks", None) - - with self: - result = func(*args, **kwargs) - self.last_bookmark = None - - return result, self.last_bookmark - - return wrapper - - class NodeMeta(type): def __new__(mcs, name, bases, namespace): namespace["DoesNotExist"] = type(name + "DoesNotExist", (DoesNotExist,), {}) diff --git a/test/async_/test_cypher.py b/test/async_/test_cypher.py index ac310307..cbdf229c 100644 --- a/test/async_/test_cypher.py +++ b/test/async_/test_cypher.py @@ -7,6 +7,7 @@ from pandas import DataFrame, Series from neomodel import AsyncStructuredNode, StringProperty, adb +from neomodel._async_compat.util import AsyncUtil class User2(AsyncStructuredNode): @@ -76,6 +77,10 @@ async def test_cypher_syntax_error(): @mark_async_test @pytest.mark.parametrize("hide_available_pkg", ["pandas"], indirect=True) async def test_pandas_not_installed(hide_available_pkg): + # We run only the async version, because this fails on second run + # because import error is thrown only when pandas.py is imported + if not AsyncUtil.is_async_code: + pytest.skip("This test is async only") with pytest.raises(ImportError): with pytest.warns( UserWarning, @@ -128,6 +133,10 @@ async def test_pandas_integration(): @mark_async_test @pytest.mark.parametrize("hide_available_pkg", ["numpy"], indirect=True) async def test_numpy_not_installed(hide_available_pkg): + # We run only the async version, because this fails on second run + # because import error is thrown only when pandas.py is imported + if not AsyncUtil.is_async_code: + pytest.skip("This test is async only") with pytest.raises(ImportError): with pytest.warns( UserWarning, diff --git a/test/async_/test_transactions.py b/test/async_/test_transactions.py index d115f186..7add7265 100644 --- a/test/async_/test_transactions.py +++ b/test/async_/test_transactions.py @@ -38,8 +38,6 @@ async def in_a_tx(*names): await APerson(name=n).save() -# TODO : This fails with no support for context manager protocol -# Possibly the transaction decorator is the issue @mark_async_test async def test_transaction_decorator(): await adb.install_labels(APerson) @@ -54,7 +52,7 @@ async def test_transaction_decorator(): with raises(UniqueProperty): await in_a_tx("Jim", "Roger") - assert "Jim" not in [p.name async for p in await APerson.nodes] + assert "Jim" not in [p.name for p in await APerson.nodes] @mark_async_test @@ -115,25 +113,24 @@ async def double_transaction(): @adb.transaction.with_bookmark -async def in_a_tx(*names): +async def in_a_tx_with_bookmark(*names): for n in names: await APerson(name=n).save() -# TODO : FIx this once in_a_tx is fixed @mark_async_test async def test_bookmark_transaction_decorator(): for p in await APerson.nodes: await p.delete() # should work - result, bookmarks = await in_a_tx("Ruth", bookmarks=None) + result, bookmarks = await in_a_tx_with_bookmark("Ruth", bookmarks=None) assert result is None assert isinstance(bookmarks, Bookmarks) # should bail but raise correct error with raises(UniqueProperty): - await in_a_tx("Jane", "Ruth") + await in_a_tx_with_bookmark("Jane", "Ruth") assert "Jane" not in [p.name for p in await APerson.nodes] @@ -153,9 +150,9 @@ async def test_bookmark_transaction_as_a_context(): @pytest.fixture -async def spy_on_db_begin(monkeypatch): +def spy_on_db_begin(monkeypatch): spy_calls = [] - original_begin = await adb.begin() + original_begin = adb.begin def begin_spy(*args, **kwargs): spy_calls.append((args, kwargs)) @@ -165,14 +162,13 @@ def begin_spy(*args, **kwargs): return spy_calls -# TODO : Fix this test @mark_async_test async def test_bookmark_passed_in_to_context(spy_on_db_begin): transaction = adb.transaction async with transaction: pass - assert (await spy_on_db_begin)[-1] == ((), {"access_mode": None, "bookmarks": None}) + assert (spy_on_db_begin)[-1] == ((), {"access_mode": None, "bookmarks": None}) last_bookmark = transaction.last_bookmark transaction.bookmarks = last_bookmark diff --git a/test/sync_/test_cypher.py b/test/sync_/test_cypher.py index ab8b6d65..35681355 100644 --- a/test/sync_/test_cypher.py +++ b/test/sync_/test_cypher.py @@ -7,6 +7,7 @@ from pandas import DataFrame, Series from neomodel import StringProperty, StructuredNode, db +from neomodel._async_compat.util import Util class User2(StructuredNode): @@ -74,6 +75,10 @@ def test_cypher_syntax_error(): @mark_sync_test @pytest.mark.parametrize("hide_available_pkg", ["pandas"], indirect=True) def test_pandas_not_installed(hide_available_pkg): + # We run only the async version, because this fails on second run + # because import error is thrown only when pandas.py is imported + if not Util.is_async_code: + pytest.skip("This test is only") with pytest.raises(ImportError): with pytest.warns( UserWarning, @@ -120,6 +125,10 @@ def test_pandas_integration(): @mark_sync_test @pytest.mark.parametrize("hide_available_pkg", ["numpy"], indirect=True) def test_numpy_not_installed(hide_available_pkg): + # We run only the async version, because this fails on second run + # because import error is thrown only when pandas.py is imported + if not Util.is_async_code: + pytest.skip("This test is only") with pytest.raises(ImportError): with pytest.warns( UserWarning, diff --git a/test/sync_/test_transactions.py b/test/sync_/test_transactions.py index f9c1b2b6..de2ce150 100644 --- a/test/sync_/test_transactions.py +++ b/test/sync_/test_transactions.py @@ -38,8 +38,6 @@ def in_a_tx(*names): APerson(name=n).save() -# TODO : This fails with no support for context manager protocol -# Possibly the transaction decorator is the issue @mark_sync_test def test_transaction_decorator(): db.install_labels(APerson) @@ -115,25 +113,24 @@ def double_transaction(): @db.transaction.with_bookmark -def in_a_tx(*names): +def in_a_tx_with_bookmark(*names): for n in names: APerson(name=n).save() -# TODO : FIx this once in_a_tx is fixed @mark_sync_test def test_bookmark_transaction_decorator(): for p in APerson.nodes: p.delete() # should work - result, bookmarks = in_a_tx("Ruth", bookmarks=None) + result, bookmarks = in_a_tx_with_bookmark("Ruth", bookmarks=None) assert result is None assert isinstance(bookmarks, Bookmarks) # should bail but raise correct error with raises(UniqueProperty): - in_a_tx("Jane", "Ruth") + in_a_tx_with_bookmark("Jane", "Ruth") assert "Jane" not in [p.name for p in APerson.nodes] @@ -155,7 +152,7 @@ def test_bookmark_transaction_as_a_context(): @pytest.fixture def spy_on_db_begin(monkeypatch): spy_calls = [] - original_begin = db.begin() + original_begin = db.begin def begin_spy(*args, **kwargs): spy_calls.append((args, kwargs)) @@ -165,7 +162,6 @@ def begin_spy(*args, **kwargs): return spy_calls -# TODO : Fix this test @mark_sync_test def test_bookmark_passed_in_to_context(spy_on_db_begin): transaction = db.transaction From da1c0d800a62edb83bee9808446749f01a86ef89 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Mon, 25 Mar 2024 16:36:45 +0100 Subject: [PATCH 59/73] Fix check coroutine --- neomodel/async_/core.py | 7 ++++--- neomodel/sync_/core.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py index d92c8dd1..21228cb9 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/core.py @@ -6,7 +6,7 @@ from asyncio import iscoroutinefunction from itertools import combinations from threading import local -from typing import Optional, Sequence, Tuple +from typing import Optional, Sequence from urllib.parse import quote, unquote, urlparse from neo4j import ( @@ -23,6 +23,7 @@ from neo4j.graph import Node, Path, Relationship from neomodel import config +from neomodel._async_compat.util import AsyncUtil from neomodel.async_.property_manager import AsyncPropertyManager from neomodel.exceptions import ( ConstraintValidationFailed, @@ -953,7 +954,7 @@ async def __aexit__(self, exc_type, exc_value, traceback): self.last_bookmark = await self.db.commit() def __call__(self, func): - if not iscoroutinefunction(func): + if AsyncUtil.is_async_code and not iscoroutinefunction(func): raise TypeError(NOT_COROUTINE_ERROR) async def wrapper(*args, **kwargs): @@ -970,7 +971,7 @@ def with_bookmark(self): class BookmarkingAsyncTransactionProxy(AsyncTransactionProxy): def __call__(self, func): - if not iscoroutinefunction(func): + if AsyncUtil.is_async_code and not iscoroutinefunction(func): raise TypeError(NOT_COROUTINE_ERROR) async def wrapper(*args, **kwargs): diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index 2fbec35b..8778adb0 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -6,7 +6,7 @@ from asyncio import iscoroutinefunction from itertools import combinations from threading import local -from typing import Optional, Sequence, Tuple +from typing import Optional, Sequence from urllib.parse import quote, unquote, urlparse from neo4j import ( @@ -23,6 +23,7 @@ from neo4j.graph import Node, Path, Relationship from neomodel import config +from neomodel._async_compat.util import Util from neomodel.exceptions import ( ConstraintValidationFailed, DoesNotExist, @@ -949,7 +950,7 @@ def __exit__(self, exc_type, exc_value, traceback): self.last_bookmark = self.db.commit() def __call__(self, func): - if not iscoroutinefunction(func): + if Util.is_async_code and not iscoroutinefunction(func): raise TypeError(NOT_COROUTINE_ERROR) def wrapper(*args, **kwargs): @@ -966,7 +967,7 @@ def with_bookmark(self): class BookmarkingAsyncTransactionProxy(TransactionProxy): def __call__(self, func): - if not iscoroutinefunction(func): + if Util.is_async_code and not iscoroutinefunction(func): raise TypeError(NOT_COROUTINE_ERROR) def wrapper(*args, **kwargs): From ff058d9de1a4ebd00ecbb55c63bd044c3646c177 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Tue, 26 Mar 2024 17:42:30 +0100 Subject: [PATCH 60/73] Fix test order --- test/conftest.py | 35 ++++++++++++++++++++++------------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/test/conftest.py b/test/conftest.py index 48d3088b..291eedf5 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -28,23 +28,32 @@ def pytest_addoption(parser): @pytest.hookimpl def pytest_collection_modifyitems(items): - connect_to_aura_items = [] - normal_items = [] + async_items = [] + sync_items = [] + async_connect_to_aura_items = [] + sync_connect_to_aura_items = [] - # Separate all tests into two groups: those with "connect_to_aura" in their name, and all others for item in items: + # Check the directory of the item + directory = item.fspath.dirname.split("/")[-1] + if "connect_to_aura" in item.name: - connect_to_aura_items.append(item) + if directory == "async_": + async_connect_to_aura_items.append(item) + elif directory == "sync_": + sync_connect_to_aura_items.append(item) else: - normal_items.append(item) - - # Add all normal tests back to the front of the list - new_order = normal_items - - # Add all connect_to_aura tests to the end of the list - new_order.extend(connect_to_aura_items) - - # Replace the original items list with the new order + if directory == "async_": + async_items.append(item) + elif directory == "sync_": + sync_items.append(item) + + new_order = ( + async_items + + async_connect_to_aura_items + + sync_items + + sync_connect_to_aura_items + ) items[:] = new_order From d4b48baa7f915aef203e546c2f3d8861019280d5 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Thu, 28 Mar 2024 09:41:35 +0100 Subject: [PATCH 61/73] Fix sync conftest --- bin/make-unasync | 7 ++++++- test/sync_/conftest.py | 16 ++-------------- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/bin/make-unasync b/bin/make-unasync index 917353f8..96375526 100755 --- a/bin/make-unasync +++ b/bin/make-unasync @@ -20,6 +20,7 @@ ASYNC_CONTRIB_DIR = ROOT_DIR / "neomodel" / "contrib" / "async_" SYNC_CONTRIB_DIR = ROOT_DIR / "neomodel" / "contrib" / "sync_" ASYNC_INTEGRATION_TEST_DIR = ROOT_DIR / "test" / "async_" SYNC_INTEGRATION_TEST_DIR = ROOT_DIR / "test" / "sync_" +INTEGRATION_TEST_EXCLUSION_LIST = ["conftest.py"] UNASYNC_SUFFIX = ".unasync" PY_FILE_EXTENSIONS = {".py"} @@ -248,7 +249,11 @@ def apply_unasync(files): if not files: paths = list(ASYNC_DIR.rglob("*")) paths += list(ASYNC_CONTRIB_DIR.rglob("*")) - paths += list(ASYNC_INTEGRATION_TEST_DIR.rglob("*")) + paths += [ + path + for path in ASYNC_INTEGRATION_TEST_DIR.rglob("*") + if path.name not in INTEGRATION_TEST_EXCLUSION_LIST + ] else: paths = [ROOT_DIR / Path(f) for f in files] filtered_paths = [] diff --git a/test/sync_/conftest.py b/test/sync_/conftest.py index 735c9840..d2cd787e 100644 --- a/test/sync_/conftest.py +++ b/test/sync_/conftest.py @@ -1,15 +1,12 @@ -import asyncio import os import warnings from test._async_compat import mark_sync_session_auto_fixture -import pytest - from neomodel import config, db @mark_sync_session_auto_fixture -def setup_neo4j_session(request, event_loop): +def setup_neo4j_session(request): """ Provides initial connection to the database and sets up the rest of the test suite @@ -46,15 +43,6 @@ def setup_neo4j_session(request, event_loop): @mark_sync_session_auto_fixture -def cleanup(event_loop): +def cleanup(): yield db.close_connection() - - -@pytest.fixture(scope="session") -def event_loop(): - """Overrides pytest default function scoped event loop""" - policy = asyncio.get_event_loop_policy() - loop = policy.new_event_loop() - yield loop - loop.close() From a229418559f8ad5de260e82504199fba71addeac Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Thu, 28 Mar 2024 09:57:20 +0100 Subject: [PATCH 62/73] Fix numpy test --- test/async_/test_cypher.py | 4 ++-- test/sync_/test_cypher.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/async_/test_cypher.py b/test/async_/test_cypher.py index cbdf229c..c078c8d5 100644 --- a/test/async_/test_cypher.py +++ b/test/async_/test_cypher.py @@ -134,13 +134,13 @@ async def test_pandas_integration(): @pytest.mark.parametrize("hide_available_pkg", ["numpy"], indirect=True) async def test_numpy_not_installed(hide_available_pkg): # We run only the async version, because this fails on second run - # because import error is thrown only when pandas.py is imported + # because import error is thrown only when numpy.py is imported if not AsyncUtil.is_async_code: pytest.skip("This test is async only") with pytest.raises(ImportError): with pytest.warns( UserWarning, - match="The neomodel.integration.numpy module expects pandas to be installed", + match="The neomodel.integration.numpy module expects numpy to be installed", ): from neomodel.integration.numpy import to_ndarray diff --git a/test/sync_/test_cypher.py b/test/sync_/test_cypher.py index 35681355..944959c4 100644 --- a/test/sync_/test_cypher.py +++ b/test/sync_/test_cypher.py @@ -126,13 +126,13 @@ def test_pandas_integration(): @pytest.mark.parametrize("hide_available_pkg", ["numpy"], indirect=True) def test_numpy_not_installed(hide_available_pkg): # We run only the async version, because this fails on second run - # because import error is thrown only when pandas.py is imported + # because import error is thrown only when numpy.py is imported if not Util.is_async_code: pytest.skip("This test is only") with pytest.raises(ImportError): with pytest.warns( UserWarning, - match="The neomodel.integration.numpy module expects pandas to be installed", + match="The neomodel.integration.numpy module expects numpy to be installed", ): from neomodel.integration.numpy import to_ndarray From 8bd7aa02bbebc5be7558e8c3ed27ea37dc1b226d Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Thu, 28 Mar 2024 17:25:19 +0100 Subject: [PATCH 63/73] Fix some code smells --- doc/source/getting_started.rst | 1 + neomodel/__init__.py | 2 +- neomodel/async_/match.py | 7 ------- neomodel/async_/relationship.py | 20 ++++++++------------ neomodel/config.py | 3 +++ neomodel/sync_/match.py | 7 ------- neomodel/sync_/relationship.py | 20 ++++++++------------ test/async_/test_models.py | 5 ++--- test/async_/test_properties.py | 1 - test/async_/test_transactions.py | 1 - test/sync_/test_models.py | 5 ++--- test/sync_/test_properties.py | 1 - test/sync_/test_transactions.py | 1 - 13 files changed, 25 insertions(+), 49 deletions(-) diff --git a/doc/source/getting_started.rst b/doc/source/getting_started.rst index 7be18428..c756a6f6 100644 --- a/doc/source/getting_started.rst +++ b/doc/source/getting_started.rst @@ -296,6 +296,7 @@ Most _dunder_ methods for nodes and relationships had to be overriden to support dogs_bonanza = await Dog.nodes.get_len() # Sync equivalent - __len__ dogs_bonanza = len(Dog.nodes) + # Note that len(Dog.nodes) is more efficient than Dog.nodes.__len__ # Existence assert not await Customer.nodes.filter(email="jim7@aol.com").check_bool() diff --git a/neomodel/__init__.py b/neomodel/__init__.py index e10f571f..d7d0febb 100644 --- a/neomodel/__init__.py +++ b/neomodel/__init__.py @@ -1,5 +1,4 @@ # pep8: noqa -# TODO : Check imports sync + async from neomodel.async_.cardinality import ( AsyncOne, AsyncOneOrMore, @@ -9,6 +8,7 @@ from neomodel.async_.core import AsyncStructuredNode, adb from neomodel.async_.match import AsyncNodeSet, AsyncTraversal from neomodel.async_.path import AsyncNeomodelPath +from neomodel.async_.property_manager import AsyncPropertyManager from neomodel.async_.relationship import AsyncStructuredRel from neomodel.async_.relationship_manager import ( AsyncRelationship, diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 4827b36c..e7e82adb 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -743,13 +743,6 @@ async def __aiter__(self): async for i in await ast._execute(): yield i - # TODO : Add tests for sync to check that len(Label.nodes) is still working - # Because async tests will now check for Coffee.nodes.get_len() - # Also add documenation for get_len, check_bool, etc... - # Documentation should explain that in sync, assert node1.extension is more efficient than - # assert node1.extension.check_bool() because it counts using native Cypher - # Same for len(Extension.nodes) vs Extension.nodes.__len__ - # With note that async does not have a choice async def get_len(self): ast = await self.query_cls(self).build_ast() return await ast._count() diff --git a/neomodel/async_/relationship.py b/neomodel/async_/relationship.py index f684b0c0..5653137a 100644 --- a/neomodel/async_/relationship.py +++ b/neomodel/async_/relationship.py @@ -3,6 +3,8 @@ from neomodel.hooks import hooks from neomodel.properties import Property +ELEMENT_ID_MIGRATION_NOTICE = "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()." + class RelationshipMeta(type): def __new__(mcs, name, bases, dct): @@ -67,30 +69,24 @@ def _end_node_element_id(self): def id(self): try: return int(self.element_id_property) - except (TypeError, ValueError): - raise ValueError( - "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()." - ) + except (TypeError, ValueError) as exc: + raise ValueError(ELEMENT_ID_MIGRATION_NOTICE) from exc # Version 4.4 support - id is deprecated in version 5.x @property def _start_node_id(self): try: return int(self._start_node_element_id_property) - except (TypeError, ValueError): - raise ValueError( - "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()." - ) + except (TypeError, ValueError) as exc: + raise ValueError(ELEMENT_ID_MIGRATION_NOTICE) from exc # Version 4.4 support - id is deprecated in version 5.x @property def _end_node_id(self): try: return int(self._end_node_element_id_property) - except (TypeError, ValueError): - raise ValueError( - "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()." - ) + except (TypeError, ValueError) as exc: + raise ValueError(ELEMENT_ID_MIGRATION_NOTICE) from exc @hooks async def save(self): diff --git a/neomodel/config.py b/neomodel/config.py index 2e527782..85c9ed8a 100644 --- a/neomodel/config.py +++ b/neomodel/config.py @@ -22,3 +22,6 @@ # DRIVER = neo4j.GraphDatabase().driver( # "bolt://localhost:7687", auth=("neo4j", "foobarbaz") # ) +DRIVER = None +# Use this to connect to a specific database when using the self-managed driver +DATABASE_NAME = None diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index abc5f54d..fad8b05f 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -739,13 +739,6 @@ def __iter__(self): for i in ast._execute(): yield i - # TODO : Add tests for sync to check that len(Label.nodes) is still working - # Because async tests will now check for Coffee.nodes.get_len() - # Also add documenation for get_len, check_bool, etc... - # Documentation should explain that in sync, assert node1.extension is more efficient than - # assert node1.extension.check_bool() because it counts using native Cypher - # Same for len(Extension.nodes) vs Extension.nodes.__len__ - # With note that async does not have a choice def __len__(self): ast = self.query_cls(self).build_ast() return ast._count() diff --git a/neomodel/sync_/relationship.py b/neomodel/sync_/relationship.py index 1dbca709..0a199575 100644 --- a/neomodel/sync_/relationship.py +++ b/neomodel/sync_/relationship.py @@ -3,6 +3,8 @@ from neomodel.sync_.core import db from neomodel.sync_.property_manager import PropertyManager +ELEMENT_ID_MIGRATION_NOTICE = "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()." + class RelationshipMeta(type): def __new__(mcs, name, bases, dct): @@ -67,30 +69,24 @@ def _end_node_element_id(self): def id(self): try: return int(self.element_id_property) - except (TypeError, ValueError): - raise ValueError( - "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()." - ) + except (TypeError, ValueError) as exc: + raise ValueError(ELEMENT_ID_MIGRATION_NOTICE) from exc # Version 4.4 support - id is deprecated in version 5.x @property def _start_node_id(self): try: return int(self._start_node_element_id_property) - except (TypeError, ValueError): - raise ValueError( - "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()." - ) + except (TypeError, ValueError) as exc: + raise ValueError(ELEMENT_ID_MIGRATION_NOTICE) from exc # Version 4.4 support - id is deprecated in version 5.x @property def _end_node_id(self): try: return int(self._end_node_element_id_property) - except (TypeError, ValueError): - raise ValueError( - "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()." - ) + except (TypeError, ValueError) as exc: + raise ValueError(ELEMENT_ID_MIGRATION_NOTICE) from exc @hooks def save(self): diff --git a/test/async_/test_models.py b/test/async_/test_models.py index 93bee6ab..b9bb2e44 100644 --- a/test/async_/test_models.py +++ b/test/async_/test_models.py @@ -61,9 +61,8 @@ async def test_required(): def test_repr_and_str(): u = User(email="robin@test.com", age=3) - print(repr(u)) - print(str(u)) - assert True + assert repr(u) == "" + assert str(u) == "{'email': 'robin@test.com', 'age': 3}" @mark_async_test diff --git a/test/async_/test_properties.py b/test/async_/test_properties.py index 58db3047..0679cf89 100644 --- a/test/async_/test_properties.py +++ b/test/async_/test_properties.py @@ -84,7 +84,6 @@ def test_deflate_inflate(): try: prop.inflate("six") except InflateError as e: - assert True assert "inflate property" in str(e) else: assert False, "DeflateError not raised." diff --git a/test/async_/test_transactions.py b/test/async_/test_transactions.py index 7add7265..59d523c5 100644 --- a/test/async_/test_transactions.py +++ b/test/async_/test_transactions.py @@ -46,7 +46,6 @@ async def test_transaction_decorator(): # should work await in_a_tx("Roger") - assert True # should bail but raise correct error with raises(UniqueProperty): diff --git a/test/sync_/test_models.py b/test/sync_/test_models.py index 13065c3b..3698b612 100644 --- a/test/sync_/test_models.py +++ b/test/sync_/test_models.py @@ -61,9 +61,8 @@ def test_required(): def test_repr_and_str(): u = User(email="robin@test.com", age=3) - print(repr(u)) - print(str(u)) - assert True + assert repr(u) == "" + assert str(u) == "{'email': 'robin@test.com', 'age': 3}" @mark_sync_test diff --git a/test/sync_/test_properties.py b/test/sync_/test_properties.py index 2f5a444e..28866738 100644 --- a/test/sync_/test_properties.py +++ b/test/sync_/test_properties.py @@ -84,7 +84,6 @@ def test_deflate_inflate(): try: prop.inflate("six") except InflateError as e: - assert True assert "inflate property" in str(e) else: assert False, "DeflateError not raised." diff --git a/test/sync_/test_transactions.py b/test/sync_/test_transactions.py index de2ce150..834b538e 100644 --- a/test/sync_/test_transactions.py +++ b/test/sync_/test_transactions.py @@ -46,7 +46,6 @@ def test_transaction_decorator(): # should work in_a_tx("Roger") - assert True # should bail but raise correct error with raises(UniqueProperty): From ce3a1cf98ae3c2a4ef74a13eb2d24388f2a96050 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Tue, 9 Apr 2024 10:17:29 +0200 Subject: [PATCH 64/73] Update README with performance tests --- README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/README.md b/README.md index 9fe5e06b..0bd3478e 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,15 @@ To install from github: $ pip install git+git://github.com/neo4j-contrib/neomodel.git@HEAD#egg=neomodel-dev +# Performance comparison + +You can find some performance tests made using Locust [in this repo](https://github.com/mariusconjeaud/neomodel-locust). + +Two learnings from this : + +* The wrapping of the driver made by neomodel is very thin performance-wise : it does not add a lot of overhead ; +* When used in a concurrent fashion, async neomodel is faster than concurrent sync neomodel, and a lot of faster than serial queries. + # Contributing Ideas, bugs, tests and pull requests always welcome. Please use From 5d760f0a1a83849e867c294ac71501244f045efe Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Tue, 9 Apr 2024 10:36:37 +0200 Subject: [PATCH 65/73] Bump neo4j ; update changelog and version --- Changelog | 4 ++++ doc/source/configuration.rst | 2 +- neomodel/_version.py | 2 +- pyproject.toml | 2 +- 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/Changelog b/Changelog index 69be5da2..a71d5052 100644 --- a/Changelog +++ b/Changelog @@ -1,3 +1,7 @@ +Version 5.3.0 2024-04 +* Add async support +* Bumps neo4j (driver) to 5.19.0 + Version 5.2.1 2023-12 * Add options to inspection script to skip heavy operations - rel props or cardinality inspection #767 * Fixes database version parsing issues diff --git a/doc/source/configuration.rst b/doc/source/configuration.rst index 3946ec98..95fbb992 100644 --- a/doc/source/configuration.rst +++ b/doc/source/configuration.rst @@ -32,7 +32,7 @@ Adjust driver configuration - these options are only available for this connecti config.MAX_TRANSACTION_RETRY_TIME = 30.0 # default config.RESOLVER = None # default config.TRUST = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES # default - config.USER_AGENT = neomodel/v5.2.1 # default + config.USER_AGENT = neomodel/v5.3.0 # default Setting the database name, if different from the default one:: diff --git a/neomodel/_version.py b/neomodel/_version.py index 98886d26..f5752882 100644 --- a/neomodel/_version.py +++ b/neomodel/_version.py @@ -1 +1 @@ -__version__ = "5.2.1" +__version__ = "5.3.0" diff --git a/pyproject.toml b/pyproject.toml index acc36850..95d8e719 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ "Topic :: Database", ] dependencies = [ - "neo4j~=5.15.0", + "neo4j~=5.19.0", ] requires-python = ">=3.7" dynamic = ["version"] From c1057c2bf5baae723ef9017b007317658ec912bf Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Tue, 9 Apr 2024 10:38:16 +0200 Subject: [PATCH 66/73] Add breaking change to Changelog --- Changelog | 1 + 1 file changed, 1 insertion(+) diff --git a/Changelog b/Changelog index a71d5052..dd11f1eb 100644 --- a/Changelog +++ b/Changelog @@ -1,6 +1,7 @@ Version 5.3.0 2024-04 * Add async support * Bumps neo4j (driver) to 5.19.0 +* Breaking change : config.AUTO_INSTALL_LABELS has been removed. Please use the neomodel_install_labels script instead Version 5.2.1 2023-12 * Add options to inspection script to skip heavy operations - rel props or cardinality inspection #767 From 24c549dad672aa7cf899d541acfcb71ed434033e Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Tue, 9 Apr 2024 10:50:01 +0200 Subject: [PATCH 67/73] Fix flaky test --- test/async_/test_cypher.py | 12 ++++++++---- test/sync_/test_cypher.py | 16 ++++++++++++---- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/test/async_/test_cypher.py b/test/async_/test_cypher.py index c078c8d5..88447482 100644 --- a/test/async_/test_cypher.py +++ b/test/async_/test_cypher.py @@ -101,7 +101,7 @@ async def test_pandas_integration(): # Test to_dataframe df = to_dataframe( await adb.cypher_query( - "MATCH (a:UserPandas) RETURN a.name AS name, a.email AS email" + "MATCH (a:UserPandas) RETURN a.name AS name, a.email AS email ORDER BY name" ) ) @@ -112,7 +112,7 @@ async def test_pandas_integration(): # Also test passing an index and dtype to to_dataframe df = to_dataframe( await adb.cypher_query( - "MATCH (a:UserPandas) RETURN a.name AS name, a.email AS email" + "MATCH (a:UserPandas) RETURN a.name AS name, a.email AS email ORDER BY name" ), index=df["email"], dtype=str, @@ -122,7 +122,9 @@ async def test_pandas_integration(): # Next test to_series series = to_series( - await adb.cypher_query("MATCH (a:UserPandas) RETURN a.name AS name") + await adb.cypher_query( + "MATCH (a:UserPandas) RETURN a.name AS name ORDER BY name" + ) ) assert isinstance(series, Series) @@ -144,7 +146,9 @@ async def test_numpy_not_installed(hide_available_pkg): ): from neomodel.integration.numpy import to_ndarray - _ = to_ndarray(await adb.cypher_query("MATCH (a) RETURN a.name AS name")) + _ = to_ndarray( + await adb.cypher_query("MATCH (a) RETURN a.name AS name ORDER BY name") + ) @mark_async_test diff --git a/test/sync_/test_cypher.py b/test/sync_/test_cypher.py index 944959c4..07da5a25 100644 --- a/test/sync_/test_cypher.py +++ b/test/sync_/test_cypher.py @@ -98,7 +98,9 @@ def test_pandas_integration(): # Test to_dataframe df = to_dataframe( - db.cypher_query("MATCH (a:UserPandas) RETURN a.name AS name, a.email AS email") + db.cypher_query( + "MATCH (a:UserPandas) RETURN a.name AS name, a.email AS email ORDER BY name" + ) ) assert isinstance(df, DataFrame) @@ -107,7 +109,9 @@ def test_pandas_integration(): # Also test passing an index and dtype to to_dataframe df = to_dataframe( - db.cypher_query("MATCH (a:UserPandas) RETURN a.name AS name, a.email AS email"), + db.cypher_query( + "MATCH (a:UserPandas) RETURN a.name AS name, a.email AS email ORDER BY name" + ), index=df["email"], dtype=str, ) @@ -115,7 +119,9 @@ def test_pandas_integration(): assert df.index.inferred_type == "string" # Next test to_series - series = to_series(db.cypher_query("MATCH (a:UserPandas) RETURN a.name AS name")) + series = to_series( + db.cypher_query("MATCH (a:UserPandas) RETURN a.name AS name ORDER BY name") + ) assert isinstance(series, Series) assert series.shape == (2,) @@ -136,7 +142,9 @@ def test_numpy_not_installed(hide_available_pkg): ): from neomodel.integration.numpy import to_ndarray - _ = to_ndarray(db.cypher_query("MATCH (a) RETURN a.name AS name")) + _ = to_ndarray( + db.cypher_query("MATCH (a) RETURN a.name AS name ORDER BY name") + ) @mark_sync_test From 238212586f619f152b75c539113e69913e8bb678 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Thu, 11 Apr 2024 14:52:29 +0200 Subject: [PATCH 68/73] Update README --- README.md | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 0bd3478e..ba57297b 100644 --- a/README.md +++ b/README.md @@ -37,21 +37,18 @@ GitHub repo found at . Available on [readthedocs](http://neomodel.readthedocs.org). -# Upcoming breaking changes notice - \>=5.3 +# New in 5.3.0 -Based on Python version [status](https://devguide.python.org/versions/), -neomodel will be dropping support for Python 3.7 in an upcoming release -(5.3 or later). This does not mean neomodel will stop working on Python 3.7, but -it will no longer be tested against it. Instead, we will try to add -support for Python 3.12. +neomodel now supports asynchronous programming, thanks to the [Neo4j driver async API](https://neo4j.com/docs/api/python-driver/current/async_api.html). The [documentation](http://neomodel.readthedocs.org) has been updated accordingly, with an updated getting started section, and some specific documentation for the async API. -Another source of upcoming breaking changes is the addition async support to -neomodel. No date is set yet, but the work has progressed a lot in the past weeks ; -and it will be part of a major release. -You can see the progress in [this branch](https://github.com/neo4j-contrib/neomodel/tree/task/async). +# Breaking changes in 5.3.0 -Finally, we are looking at refactoring some standalone methods into the -Database() class. More to come on that later. +- config.AUTO_INSTALL_LABELS has been removed. Please use the `neomodel_install_labels` script instead. _Note : this is because of the addition of async, but also because it might lead to uncontrolled creation of indexes/constraints. The script makes you more in control of said creation._ +- Based on Python version [status](https://devguide.python.org/versions/), +neomodel will be dropping support for Python 3.7 in an upcoming release +(5.3 or later). _This does not mean neomodel will stop working on Python 3.7, but +it will no longer be tested against it_ +- Some standalone methods have been refactored into the Database() class. Check the [documentation](http://neomodel.readthedocs.org) for a full list. # Installation From 510a81e7f27a690f852ef1266882d03500d5160a Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Thu, 11 Apr 2024 14:52:54 +0200 Subject: [PATCH 69/73] Update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ba57297b..e3243651 100644 --- a/README.md +++ b/README.md @@ -41,7 +41,7 @@ Available on neomodel now supports asynchronous programming, thanks to the [Neo4j driver async API](https://neo4j.com/docs/api/python-driver/current/async_api.html). The [documentation](http://neomodel.readthedocs.org) has been updated accordingly, with an updated getting started section, and some specific documentation for the async API. -# Breaking changes in 5.3.0 +# Breaking change in 5.3.0 - config.AUTO_INSTALL_LABELS has been removed. Please use the `neomodel_install_labels` script instead. _Note : this is because of the addition of async, but also because it might lead to uncontrolled creation of indexes/constraints. The script makes you more in control of said creation._ - Based on Python version [status](https://devguide.python.org/versions/), From 720a5f7b8773d234f9cbf4611003729fc4641f6a Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Fri, 12 Apr 2024 14:35:42 +0200 Subject: [PATCH 70/73] #784 Allow filtering by IN in ArrayProperty --- neomodel/async_/match.py | 110 +++++++++++++++++++++++++++------- neomodel/sync_/match.py | 18 +----- test/async_/test_match_api.py | 22 +++++++ test/sync_/test_match_api.py | 22 ++++++- 4 files changed, 134 insertions(+), 38 deletions(-) diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index e7e82adb..12bf6ed2 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -7,7 +7,7 @@ from neomodel.async_.core import AsyncStructuredNode, adb from neomodel.exceptions import MultipleNodesReturned from neomodel.match_q import Q, QBase -from neomodel.properties import AliasProperty +from neomodel.properties import AliasProperty, ArrayProperty from neomodel.util import INCOMING, OUTGOING @@ -149,6 +149,7 @@ def _rel_merge_helper( # special operators _SPECIAL_OPERATOR_IN = "IN" +_SPECIAL_OPERATOR_ARRAY_IN = "any(x IN {ident}.{prop} WHERE x IN {val})" _SPECIAL_OPERATOR_INSENSITIVE = "(?i)" _SPECIAL_OPERATOR_ISNULL = "IS NULL" _SPECIAL_OPERATOR_ISNOTNULL = "IS NOT NULL" @@ -253,29 +254,89 @@ def process_filter_args(cls, kwargs): return output +def transform_in_operator_to_filter(operator, filter_key, filter_value, property_obj): + """ + Transform in operator to a cypher filter + Args: + operator (str): operator to transform + filter_key (str): filter key + filter_value (str): filter value + property_obj (object): property object + Returns: + tuple: operator, deflated_value + """ + if not isinstance(filter_value, tuple) and not isinstance(filter_value, list): + raise ValueError( + f"Value must be a tuple or list for IN operation {filter_key}={filter_value}" + ) + if isinstance(property_obj, ArrayProperty): + deflated_value = property_obj.deflate(filter_value) + operator = _SPECIAL_OPERATOR_ARRAY_IN + else: + deflated_value = [property_obj.deflate(v) for v in filter_value] + + return operator, deflated_value + + +def transform_null_operator_to_filter(filter_key, filter_value): + """ + Transform null operator to a cypher filter + Args: + filter_key (str): filter key + filter_value (str): filter value + Returns: + tuple: operator, deflated_value + """ + if not isinstance(filter_value, bool): + raise ValueError(f"Value must be a bool for isnull operation on {filter_key}") + operator = "IS NULL" if filter_value else "IS NOT NULL" + deflated_value = None + return operator, deflated_value + + +def transform_regex_operator_to_filter( + operator, filter_key, filter_value, property_obj +): + """ + Transform regex operator to a cypher filter + Args: + operator (str): operator to transform + filter_key (str): filter key + filter_value (str): filter value + property_obj (object): property object + Returns: + tuple: operator, deflated_value + """ + + deflated_value = property_obj.deflate(filter_value) + if not isinstance(deflated_value, str): + raise ValueError(f"Must be a string value for {filter_key}") + if operator in _STRING_REGEX_OPERATOR_TABLE.values(): + deflated_value = re.escape(deflated_value) + deflated_value = operator.format(deflated_value) + operator = _SPECIAL_OPERATOR_REGEX + return operator, deflated_value + + def transform_operator_to_filter(operator, filter_key, filter_value, property_obj): - # handle special operators if operator == _SPECIAL_OPERATOR_IN: - if not isinstance(filter_value, tuple) and not isinstance(filter_value, list): - raise ValueError( - f"Value must be a tuple or list for IN operation {filter_key}={filter_value}" - ) - deflated_value = [property_obj.deflate(v) for v in filter_value] + operator, deflated_value = transform_in_operator_to_filter( + operator=operator, + filter_key=filter_key, + filter_value=filter_value, + property_obj=property_obj, + ) elif operator == _SPECIAL_OPERATOR_ISNULL: - if not isinstance(filter_value, bool): - raise ValueError( - f"Value must be a bool for isnull operation on {filter_key}" - ) - operator = "IS NULL" if filter_value else "IS NOT NULL" - deflated_value = None + operator, deflated_value = transform_null_operator_to_filter( + filter_key=filter_key, filter_value=filter_value + ) elif operator in _REGEX_OPERATOR_TABLE.values(): - deflated_value = property_obj.deflate(filter_value) - if not isinstance(deflated_value, str): - raise ValueError(f"Must be a string value for {filter_key}") - if operator in _STRING_REGEX_OPERATOR_TABLE.values(): - deflated_value = re.escape(deflated_value) - deflated_value = operator.format(deflated_value) - operator = _SPECIAL_OPERATOR_REGEX + operator, deflated_value = transform_regex_operator_to_filter( + operator=operator, + filter_key=filter_key, + filter_value=filter_value, + property_obj=property_obj, + ) else: deflated_value = property_obj.deflate(filter_value) @@ -571,7 +632,14 @@ def _parse_q_filters(self, ident, q, source_class): statement = f"{ident}.{prop} {operator}" else: place_holder = self._register_place_holder(ident + "_" + prop) - statement = f"{ident}.{prop} {operator} ${place_holder}" + if operator == _SPECIAL_OPERATOR_ARRAY_IN: + statement = operator.format( + ident=ident, + prop=prop, + val=f"${place_holder}", + ) + else: + statement = f"{ident}.{prop} {operator} ${place_holder}" self._query_params[place_holder] = val target.append(statement) ret = f" {q.connector} ".join(target) diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index c132cf9a..5efe39c4 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -6,7 +6,7 @@ from neomodel.exceptions import MultipleNodesReturned from neomodel.match_q import Q, QBase -from neomodel.properties import AliasProperty +from neomodel.properties import AliasProperty, ArrayProperty from neomodel.sync_.core import StructuredNode, db from neomodel.util import INCOMING, OUTGOING @@ -257,13 +257,11 @@ def process_filter_args(cls, kwargs): def transform_in_operator_to_filter(operator, filter_key, filter_value, property_obj): """ Transform in operator to a cypher filter - Args: operator (str): operator to transform filter_key (str): filter key filter_value (str): filter value property_obj (object): property object - Returns: tuple: operator, deflated_value """ @@ -283,11 +281,9 @@ def transform_in_operator_to_filter(operator, filter_key, filter_value, property def transform_null_operator_to_filter(filter_key, filter_value): """ Transform null operator to a cypher filter - Args: filter_key (str): filter key filter_value (str): filter value - Returns: tuple: operator, deflated_value """ @@ -303,13 +299,11 @@ def transform_regex_operator_to_filter( ): """ Transform regex operator to a cypher filter - Args: operator (str): operator to transform filter_key (str): filter key filter_value (str): filter value property_obj (object): property object - Returns: tuple: operator, deflated_value """ @@ -680,15 +674,7 @@ def build_where_stmt(self, ident, filters, q_filters=None, source_class=None): ) else: place_holder = self._register_place_holder(ident + "_" + prop) - if operator == _SPECIAL_OPERATOR_ARRAY_IN: - statement = operator.format( - ident=ident, - prop=prop, - val=f"${place_holder}", - ) - statement = f"{'NOT' if negate else ''} {statement}" - else: - statement = f"{'NOT' if negate else ''} {ident}.{prop} {operator} ${place_holder}" + statement = f"{'NOT' if negate else ''} {ident}.{prop} {operator} ${place_holder}" self._query_params[place_holder] = val stmts.append(statement) diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 87c1ca8e..97f4b044 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -5,6 +5,7 @@ from neomodel import ( INCOMING, + ArrayProperty, AsyncRelationshipFrom, AsyncRelationshipTo, AsyncStructuredNode, @@ -37,6 +38,7 @@ class Supplier(AsyncStructuredNode): class Species(AsyncStructuredNode): name = StringProperty() + tags = ArrayProperty(StringProperty(), default=list) coffees = AsyncRelationshipFrom( "Coffee", "COFFEE SPECIES", model=AsyncStructuredRel ) @@ -567,3 +569,23 @@ async def test_fetch_relations(): assert tesco in Supplier.nodes.fetch_relations("coffees__species").filter( name="Sainsburys" ) + + +@mark_async_test +async def test_in_filter_with_array_property(): + tags = ["smoother", "sweeter", "chocolate", "sugar"] + no_match = ["organic"] + arabica = await Species(name="Arabica", tags=tags).save() + + assert arabica in await Species.nodes.filter( + tags__in=tags + ), "Species not found by tags given" + assert arabica in await Species.nodes.filter( + Q(tags__in=tags) + ), "Species not found with Q by tags given" + assert arabica not in await Species.nodes.filter( + ~Q(tags__in=tags) + ), "Species found by tags given in negated query" + assert arabica not in await Species.nodes.filter( + tags__in=no_match + ), "Species found by tags with not match tags given" diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index cc080b1f..399c15fe 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -33,8 +33,8 @@ class Supplier(StructuredNode): class Species(StructuredNode): name = StringProperty() - coffees = RelationshipFrom("Coffee", "COFFEE SPECIES", model=StructuredRel) tags = ArrayProperty(StringProperty(), default=list) + coffees = RelationshipFrom("Coffee", "COFFEE SPECIES", model=StructuredRel) class Coffee(StructuredNode): @@ -558,3 +558,23 @@ def test_fetch_relations(): assert tesco in Supplier.nodes.fetch_relations("coffees__species").filter( name="Sainsburys" ) + + +@mark_sync_test +def test_in_filter_with_array_property(): + tags = ["smoother", "sweeter", "chocolate", "sugar"] + no_match = ["organic"] + arabica = Species(name="Arabica", tags=tags).save() + + assert arabica in Species.nodes.filter( + tags__in=tags + ), "Species not found by tags given" + assert arabica in Species.nodes.filter( + Q(tags__in=tags) + ), "Species not found with Q by tags given" + assert arabica not in Species.nodes.filter( + ~Q(tags__in=tags) + ), "Species found by tags given in negated query" + assert arabica not in Species.nodes.filter( + tags__in=no_match + ), "Species found by tags with not match tags given" From 8cf88bcf4293106f1ffd28f11f1ef7c360e4a743 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Fri, 12 Apr 2024 14:35:52 +0200 Subject: [PATCH 71/73] Update changelog --- Changelog | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Changelog b/Changelog index dd11f1eb..290f939e 100644 --- a/Changelog +++ b/Changelog @@ -1,7 +1,8 @@ Version 5.3.0 2024-04 * Add async support -* Bumps neo4j (driver) to 5.19.0 * Breaking change : config.AUTO_INSTALL_LABELS has been removed. Please use the neomodel_install_labels script instead +* Bumps neo4j (driver) to 5.19.0 +* Various improvement : functools wrap to TransactionProxy, fix node equality check, q filter for IN in arrays. Thanks to @giosava94, @OlehChyhyryn, @icapora Version 5.2.1 2023-12 * Add options to inspection script to skip heavy operations - rel props or cardinality inspection #767 From 5b2b84cd5720866d5fef8223c124e5090913dcab Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Mon, 15 Apr 2024 11:33:00 +0200 Subject: [PATCH 72/73] #776 Fix inflate on db_property --- neomodel/async_/core.py | 22 +---- neomodel/async_/match.py | 4 +- neomodel/async_/property_manager.py | 33 ++++++- neomodel/async_/relationship.py | 10 +-- neomodel/async_/relationship_manager.py | 4 +- neomodel/contrib/async_/semi_structured.py | 64 ++++++------- neomodel/contrib/sync_/semi_structured.py | 5 +- neomodel/sync_/core.py | 22 +---- neomodel/sync_/property_manager.py | 33 ++++++- neomodel/sync_/relationship_manager.py | 2 +- .../test_contrib/test_semi_structured.py | 49 +++++++++- test/async_/test_properties.py | 89 ++++++++++++++++--- .../test_contrib/test_semi_structured.py | 12 ++- test/sync_/test_properties.py | 55 +++++++++++- 14 files changed, 301 insertions(+), 103 deletions(-) diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py index a67b2280..70b619f3 100644 --- a/neomodel/async_/core.py +++ b/neomodel/async_/core.py @@ -40,7 +40,6 @@ from neomodel.hooks import hooks from neomodel.properties import Property from neomodel.util import ( - _get_node_properties, _UnsavedNode, classproperty, deprecated, @@ -794,7 +793,7 @@ async def _create_relationship_constraint( async def _install_node(self, cls, name, property, quiet, stdout): # Create indexes and constraints for node property - db_property = property.db_property or name + db_property = property.get_db_property_name(name) if property.index: if not quiet: stdout.write( @@ -821,7 +820,7 @@ async def _install_relationship(self, cls, relationship, quiet, stdout): for prop_name, property in relationship_cls.defined_properties( aliases=False, rels=False ).items(): - db_property = property.db_property or prop_name + db_property = property.get_db_property_name(prop_name) if property.index: if not quiet: stdout.write( @@ -1196,7 +1195,7 @@ async def _build_merge_query( n_merge_labels = ":".join(cls.inherited_labels()) n_merge_prm = ", ".join( ( - f"{getattr(cls, p).db_property or p}: params.create.{getattr(cls, p).db_property or p}" + f"{getattr(cls, p).get_db_property_name(p)}: params.create.{getattr(cls, p).get_db_property_name(p)}" for p in cls.__required_properties__ ) ) @@ -1419,20 +1418,7 @@ def inflate(cls, node): snode = cls() snode.element_id_property = node else: - node_properties = _get_node_properties(node) - props = {} - for key, prop in cls.__all_properties__: - # map property name from database to object property - db_property = prop.db_property or key - - if db_property in node_properties: - props[key] = prop.inflate(node_properties[db_property], node) - elif prop.has_default: - props[key] = prop.default_value() - else: - props[key] = None - - snode = cls(**props) + snode = super().inflate(node) snode.element_id_property = node.element_id return snode diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py index 12bf6ed2..f28d69be 100644 --- a/neomodel/async_/match.py +++ b/neomodel/async_/match.py @@ -247,7 +247,9 @@ def process_filter_args(cls, kwargs): ) # map property to correct property name in the database - db_property = cls.defined_properties(rels=False)[prop].db_property or prop + db_property = cls.defined_properties(rels=False)[prop].get_db_property_name( + prop + ) output[db_property] = (operator, deflated_value) diff --git a/neomodel/async_/property_manager.py b/neomodel/async_/property_manager.py index b9401dab..6eb18b3f 100644 --- a/neomodel/async_/property_manager.py +++ b/neomodel/async_/property_manager.py @@ -73,10 +73,18 @@ def __properties__(self): @classmethod def deflate(cls, properties, obj=None, skip_empty=False): - # deflate dict ready to be stored + """ + Deflate the properties of a PropertyManager subclass (a user-defined StructuredNode or StructuredRel) so that it + can be put into a neo4j.graph.Entity (a neo4j.graph.Node or neo4j.graph.Relationship) for storage. properties + can be constructed manually, or fetched from a PropertyManager subclass using __properties__. + + Includes mapping from python class attribute name -> database property name (see Property.db_property). + + Ignores any properties that are not defined as python attributes in the class definition. + """ deflated = {} for name, property in cls.defined_properties(aliases=False, rels=False).items(): - db_property = property.db_property or name + db_property = property.get_db_property_name(name) if properties.get(name) is not None: deflated[db_property] = property.deflate(properties[name], obj) elif property.has_default: @@ -87,6 +95,27 @@ def deflate(cls, properties, obj=None, skip_empty=False): deflated[db_property] = None return deflated + @classmethod + def inflate(cls, graph_entity): + """ + Inflate the properties of a neo4j.graph.Entity (a neo4j.graph.Node or neo4j.graph.Relationship) into an instance + of cls. + Includes mapping from database property name (see Property.db_property) -> python class attribute name. + Ignores any properties that are not defined as python attributes in the class definition. + """ + inflated = {} + for name, property in cls.defined_properties(aliases=False, rels=False).items(): + db_property = property.get_db_property_name(name) + if db_property in graph_entity: + inflated[name] = property.inflate( + graph_entity[db_property], graph_entity + ) + elif property.has_default: + inflated[name] = property.default_value() + else: + inflated[name] = None + return cls(**inflated) + @classmethod def defined_properties(cls, aliases=True, properties=True, rels=True): from neomodel.async_.relationship_manager import AsyncRelationshipDefinition diff --git a/neomodel/async_/relationship.py b/neomodel/async_/relationship.py index 5653137a..818bc3a2 100644 --- a/neomodel/async_/relationship.py +++ b/neomodel/async_/relationship.py @@ -153,15 +153,7 @@ def inflate(cls, rel): :param rel: :return: StructuredRel """ - props = {} - for key, prop in cls.defined_properties(aliases=False, rels=False).items(): - if key in rel: - props[key] = prop.inflate(rel[key], obj=rel) - elif prop.has_default: - props[key] = prop.default_value() - else: - props[key] = None - srel = cls(**props) + srel = super().inflate(rel) srel._start_node_element_id_property = rel.start_node.element_id srel._end_node_element_id_property = rel.end_node.element_id srel.element_id_property = rel.element_id diff --git a/neomodel/async_/relationship_manager.py b/neomodel/async_/relationship_manager.py index 9c5e7398..4182d143 100644 --- a/neomodel/async_/relationship_manager.py +++ b/neomodel/async_/relationship_manager.py @@ -16,8 +16,8 @@ EITHER, INCOMING, OUTGOING, - _get_node_properties, enumerate_traceback, + get_graph_entity_properties, ) # basestring python 3.x fallback @@ -246,7 +246,7 @@ async def reconnect(self, old_node, new_node): {"old": old_node_element_id}, ) if result: - node_properties = _get_node_properties(result[0][0]) + node_properties = get_graph_entity_properties(result[0][0]) existing_properties = node_properties.keys() else: raise NotConnected("reconnect", self.source, old_node) diff --git a/neomodel/contrib/async_/semi_structured.py b/neomodel/contrib/async_/semi_structured.py index c333ae0e..810ea10f 100644 --- a/neomodel/contrib/async_/semi_structured.py +++ b/neomodel/contrib/async_/semi_structured.py @@ -1,6 +1,6 @@ from neomodel.async_.core import AsyncStructuredNode from neomodel.exceptions import DeflateConflict, InflateConflict -from neomodel.util import _get_node_properties +from neomodel.util import get_graph_entity_properties class AsyncSemiStructuredNode(AsyncStructuredNode): @@ -25,40 +25,44 @@ def hello(self): @classmethod def inflate(cls, node): - # support lazy loading - if isinstance(node, str) or isinstance(node, int): - snode = cls() - snode.element_id_property = node - else: - props = {} - node_properties = {} - for key, prop in cls.__all_properties__: - node_properties = _get_node_properties(node) - if key in node_properties: - props[key] = prop.inflate(node_properties[key], node) - elif prop.has_default: - props[key] = prop.default_value() - else: - props[key] = None - # handle properties not defined on the class - for free_key in (x for x in node_properties if x not in props): - if hasattr(cls, free_key): - raise InflateConflict( - cls, free_key, node_properties[free_key], node.element_id - ) - props[free_key] = node_properties[free_key] + # Inflate all properties registered in the class definition + snode = super().inflate(node) - snode = cls(**props) - snode.element_id_property = node.element_id + # Node can be a string or int for lazy loading (See StructuredNode.inflate). In that case, `node` has nothing + # that can be unpacked further. + if not hasattr(node, "items"): + return snode + + # Inflate all extra properties not registered in the class definition + registered_db_property_names = { + property.get_db_property_name(name) + for name, property in cls.defined_properties( + aliases=False, rels=False + ).items() + } + extra_keys = node.keys() - registered_db_property_names + for extra_key in extra_keys: + value = node[extra_key] + if hasattr(cls, extra_key): + raise InflateConflict(cls, extra_key, value, snode.element_id) + setattr(snode, extra_key, value) return snode @classmethod def deflate(cls, node_props, obj=None, skip_empty=False): + # Deflate all properties registered in the class definition deflated = super().deflate(node_props, obj, skip_empty=skip_empty) - for key in [k for k in node_props if k not in deflated]: - if hasattr(cls, key) and (getattr(cls, key).required or not skip_empty): - raise DeflateConflict(cls, key, deflated[key], obj.element_id) - node_props.update(deflated) - return node_props + # Deflate all extra properties not registered in the class definition + registered_names = cls.defined_properties(aliases=False, rels=False).keys() + extra_keys = node_props.keys() - registered_names + for extra_key in extra_keys: + value = node_props[extra_key] + if hasattr(cls, extra_key): + raise DeflateConflict( + cls, extra_key, value, node_props.get("element_id") + ) + deflated[extra_key] = node_props[extra_key] + + return deflated diff --git a/neomodel/contrib/sync_/semi_structured.py b/neomodel/contrib/sync_/semi_structured.py index 39b31488..97c43c39 100644 --- a/neomodel/contrib/sync_/semi_structured.py +++ b/neomodel/contrib/sync_/semi_structured.py @@ -1,6 +1,6 @@ from neomodel.exceptions import DeflateConflict, InflateConflict from neomodel.sync_.core import StructuredNode -from neomodel.util import _get_node_properties +from neomodel.util import get_graph_entity_properties class SemiStructuredNode(StructuredNode): @@ -53,9 +53,6 @@ def inflate(cls, node): def deflate(cls, node_props, obj=None, skip_empty=False): # Deflate all properties registered in the class definition deflated = super().deflate(node_props, obj, skip_empty=skip_empty) - for key in [k for k in node_props if k not in deflated]: - if hasattr(cls, key) and (getattr(cls, key).required or not skip_empty): - raise DeflateConflict(cls, key, deflated[key], obj.element_id) # Deflate all extra properties not registered in the class definition registered_names = cls.defined_properties(aliases=False, rels=False).keys() diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py index 6af7c5ab..0ed3e4e3 100644 --- a/neomodel/sync_/core.py +++ b/neomodel/sync_/core.py @@ -40,7 +40,6 @@ from neomodel.properties import Property from neomodel.sync_.property_manager import PropertyManager from neomodel.util import ( - _get_node_properties, _UnsavedNode, classproperty, deprecated, @@ -792,7 +791,7 @@ def _create_relationship_constraint( def _install_node(self, cls, name, property, quiet, stdout): # Create indexes and constraints for node property - db_property = property.db_property or name + db_property = property.get_db_property_name(name) if property.index: if not quiet: stdout.write( @@ -819,7 +818,7 @@ def _install_relationship(self, cls, relationship, quiet, stdout): for prop_name, property in relationship_cls.defined_properties( aliases=False, rels=False ).items(): - db_property = property.db_property or prop_name + db_property = property.get_db_property_name(prop_name) if property.index: if not quiet: stdout.write( @@ -1192,7 +1191,7 @@ def _build_merge_query( n_merge_labels = ":".join(cls.inherited_labels()) n_merge_prm = ", ".join( ( - f"{getattr(cls, p).db_property or p}: params.create.{getattr(cls, p).db_property or p}" + f"{getattr(cls, p).get_db_property_name(p)}: params.create.{getattr(cls, p).get_db_property_name(p)}" for p in cls.__required_properties__ ) ) @@ -1415,20 +1414,7 @@ def inflate(cls, node): snode = cls() snode.element_id_property = node else: - node_properties = _get_node_properties(node) - props = {} - for key, prop in cls.__all_properties__: - # map property name from database to object property - db_property = prop.db_property or key - - if db_property in node_properties: - props[key] = prop.inflate(node_properties[db_property], node) - elif prop.has_default: - props[key] = prop.default_value() - else: - props[key] = None - - snode = cls(**props) + snode = super().inflate(node) snode.element_id_property = node.element_id return snode diff --git a/neomodel/sync_/property_manager.py b/neomodel/sync_/property_manager.py index 85452f0b..2b4eeaaa 100644 --- a/neomodel/sync_/property_manager.py +++ b/neomodel/sync_/property_manager.py @@ -73,10 +73,18 @@ def __properties__(self): @classmethod def deflate(cls, properties, obj=None, skip_empty=False): - # deflate dict ready to be stored + """ + Deflate the properties of a PropertyManager subclass (a user-defined StructuredNode or StructuredRel) so that it + can be put into a neo4j.graph.Entity (a neo4j.graph.Node or neo4j.graph.Relationship) for storage. properties + can be constructed manually, or fetched from a PropertyManager subclass using __properties__. + + Includes mapping from python class attribute name -> database property name (see Property.db_property). + + Ignores any properties that are not defined as python attributes in the class definition. + """ deflated = {} for name, property in cls.defined_properties(aliases=False, rels=False).items(): - db_property = property.db_property or name + db_property = property.get_db_property_name(name) if properties.get(name) is not None: deflated[db_property] = property.deflate(properties[name], obj) elif property.has_default: @@ -87,6 +95,27 @@ def deflate(cls, properties, obj=None, skip_empty=False): deflated[db_property] = None return deflated + @classmethod + def inflate(cls, graph_entity): + """ + Inflate the properties of a neo4j.graph.Entity (a neo4j.graph.Node or neo4j.graph.Relationship) into an instance + of cls. + Includes mapping from database property name (see Property.db_property) -> python class attribute name. + Ignores any properties that are not defined as python attributes in the class definition. + """ + inflated = {} + for name, property in cls.defined_properties(aliases=False, rels=False).items(): + db_property = property.get_db_property_name(name) + if db_property in graph_entity: + inflated[name] = property.inflate( + graph_entity[db_property], graph_entity + ) + elif property.has_default: + inflated[name] = property.default_value() + else: + inflated[name] = None + return cls(**inflated) + @classmethod def defined_properties(cls, aliases=True, properties=True, rels=True): from neomodel.sync_.relationship_manager import RelationshipDefinition diff --git a/neomodel/sync_/relationship_manager.py b/neomodel/sync_/relationship_manager.py index d8ed0620..f975ec5f 100644 --- a/neomodel/sync_/relationship_manager.py +++ b/neomodel/sync_/relationship_manager.py @@ -11,8 +11,8 @@ EITHER, INCOMING, OUTGOING, - _get_node_properties, enumerate_traceback, + get_graph_entity_properties, ) # basestring python 3.x fallback diff --git a/test/async_/test_contrib/test_semi_structured.py b/test/async_/test_contrib/test_semi_structured.py index 3b88fcad..e4e6bbd0 100644 --- a/test/async_/test_contrib/test_semi_structured.py +++ b/test/async_/test_contrib/test_semi_structured.py @@ -1,6 +1,14 @@ from test._async_compat import mark_async_test -from neomodel import IntegerProperty, StringProperty +import pytest + +from neomodel import ( + DeflateConflict, + InflateConflict, + IntegerProperty, + StringProperty, + adb, +) from neomodel.contrib import AsyncSemiStructuredNode @@ -33,3 +41,42 @@ async def test_save_to_model_with_extras(): async def test_save_empty_model(): dummy = Dummy() assert await dummy.save() + + +@mark_async_test +async def test_inflate_conflict(): + class PersonForInflateTest(AsyncSemiStructuredNode): + name = StringProperty() + age = IntegerProperty() + + def hello(self): + print("Hi my names " + self.name) + + # An ok model + props = {"name": "Jim", "age": 8, "weight": 11} + await adb.cypher_query("CREATE (n:PersonForInflateTest $props)", {"props": props}) + jim = await PersonForInflateTest.nodes.get(name="Jim") + assert jim.name == "Jim" + assert jim.age == 8 + assert jim.weight == 11 + + # A model that conflicts on `hello` + props = {"name": "Tim", "age": 8, "hello": "goodbye"} + await adb.cypher_query("CREATE (n:PersonForInflateTest $props)", {"props": props}) + with pytest.raises(InflateConflict): + await PersonForInflateTest.nodes.get(name="Tim") + + +@mark_async_test +async def test_deflate_conflict(): + class PersonForDeflateTest(AsyncSemiStructuredNode): + name = StringProperty() + age = IntegerProperty() + + def hello(self): + print("Hi my names " + self.name) + + tim = await PersonForDeflateTest(name="Tim", age=8, weight=11).save() + tim.hello = "Hi" + with pytest.raises(DeflateConflict): + await tim.save() diff --git a/test/async_/test_properties.py b/test/async_/test_properties.py index 0679cf89..8102f71d 100644 --- a/test/async_/test_properties.py +++ b/test/async_/test_properties.py @@ -4,7 +4,8 @@ from pytest import mark, raises from pytz import timezone -from neomodel import AsyncStructuredNode, adb +from neomodel import AsyncRelationship, AsyncStructuredNode, AsyncStructuredRel, adb +from neomodel.contrib import AsyncSemiStructuredNode from neomodel.exceptions import ( DeflateError, InflateError, @@ -24,7 +25,7 @@ StringProperty, UniqueIdProperty, ) -from neomodel.util import _get_node_properties +from neomodel.util import get_graph_entity_properties class FooBar: @@ -227,22 +228,33 @@ class DefaultTestValueThree(AsyncStructuredNode): assert x.uid == "123" +class TestDBNamePropertyRel(AsyncStructuredRel): + known_for = StringProperty(db_property="knownFor") + + +# This must be defined outside of the test, otherwise the `Relationship` definition cannot look up +# `TestDBNamePropertyNode` +class TestDBNamePropertyNode(AsyncStructuredNode): + name_ = StringProperty(db_property="name") + knows = AsyncRelationship( + "TestDBNamePropertyNode", "KNOWS", model=TestDBNamePropertyRel + ) + + @mark_async_test async def test_independent_property_name(): - class TestDBNamePropertyNode(AsyncStructuredNode): - name_ = StringProperty(db_property="name") - + # -- test node -- x = TestDBNamePropertyNode() x.name_ = "jim" await x.save() # check database property name on low level results, meta = await adb.cypher_query("MATCH (n:TestDBNamePropertyNode) RETURN n") - node_properties = _get_node_properties(results[0][0]) + node_properties = get_graph_entity_properties(results[0][0]) assert node_properties["name"] == "jim" + assert "name_" not in node_properties - node_properties = _get_node_properties(results[0][0]) - assert not "name_" in node_properties + # check python class property name at a high level assert not hasattr(x, "name") assert hasattr(x, "name_") assert (await TestDBNamePropertyNode.nodes.filter(name_="jim").all())[ @@ -250,9 +262,66 @@ class TestDBNamePropertyNode(AsyncStructuredNode): ].name_ == x.name_ assert (await TestDBNamePropertyNode.nodes.get(name_="jim")).name_ == x.name_ + # -- test relationship -- + + r = await x.knows.connect(x) + r.known_for = "10 years" + await r.save() + + # check database property name on low level + results, meta = await adb.cypher_query( + "MATCH (:TestDBNamePropertyNode)-[r:KNOWS]->(:TestDBNamePropertyNode) RETURN r" + ) + rel_properties = get_graph_entity_properties(results[0][0]) + assert rel_properties["knownFor"] == "10 years" + assert not "known_for" in node_properties + + # check python class property name at a high level + assert not hasattr(r, "knownFor") + assert hasattr(r, "known_for") + rel = await x.knows.relationship(x) + assert rel.known_for == r.known_for + + # -- cleanup -- + await x.delete() +@mark_async_test +async def test_independent_property_name_for_semi_structured(): + class TestDBNamePropertySemiStructuredNode(AsyncSemiStructuredNode): + title_ = StringProperty(db_property="title") + + semi = TestDBNamePropertySemiStructuredNode(title_="sir", extra="data") + await semi.save() + + # check database property name on low level + results, meta = await adb.cypher_query( + "MATCH (n:TestDBNamePropertySemiStructuredNode) RETURN n" + ) + node_properties = get_graph_entity_properties(results[0][0]) + assert node_properties["title"] == "sir" + # assert "title_" not in node_properties + assert node_properties["extra"] == "data" + + # check python class property name at a high level + assert hasattr(semi, "title_") + assert not hasattr(semi, "title") + assert hasattr(semi, "extra") + from_filter = ( + await TestDBNamePropertySemiStructuredNode.nodes.filter(title_="sir").all() + )[0] + assert from_filter.title_ == "sir" + # assert not hasattr(from_filter, "title") + assert from_filter.extra == "data" + from_get = await TestDBNamePropertySemiStructuredNode.nodes.get(title_="sir") + assert from_get.title_ == "sir" + # assert not hasattr(from_get, "title") + assert from_get.extra == "data" + + await semi.delete() + + @mark_async_test async def test_independent_property_name_get_or_create(): class TestNode(AsyncStructuredNode): @@ -266,7 +335,7 @@ class TestNode(AsyncStructuredNode): # check database property name on low level results, _ = await adb.cypher_query("MATCH (n:TestNode) RETURN n") - node_properties = _get_node_properties(results[0][0]) + node_properties = get_graph_entity_properties(results[0][0]) assert node_properties["name"] == "jim" assert "name_" not in node_properties @@ -417,7 +486,7 @@ class ConstrainedTestNode(AsyncStructuredNode): # check database property name on low level results, meta = await adb.cypher_query("MATCH (n:ConstrainedTestNode) RETURN n") - node_properties = _get_node_properties(results[0][0]) + node_properties = get_graph_entity_properties(results[0][0]) assert node_properties["unique_required_property"] == "unique and required" # delete node afterwards diff --git a/test/sync_/test_contrib/test_semi_structured.py b/test/sync_/test_contrib/test_semi_structured.py index c24d29bc..feb717e8 100644 --- a/test/sync_/test_contrib/test_semi_structured.py +++ b/test/sync_/test_contrib/test_semi_structured.py @@ -1,6 +1,14 @@ from test._async_compat import mark_sync_test -from neomodel import IntegerProperty, StringProperty +import pytest + +from neomodel import ( + DeflateConflict, + InflateConflict, + IntegerProperty, + StringProperty, + db, +) from neomodel.contrib import SemiStructuredNode @@ -35,6 +43,7 @@ def test_save_empty_model(): assert dummy.save() +@mark_sync_test def test_inflate_conflict(): class PersonForInflateTest(SemiStructuredNode): name = StringProperty() @@ -58,6 +67,7 @@ def hello(self): PersonForInflateTest.nodes.get(name="Tim") +@mark_sync_test def test_deflate_conflict(): class PersonForDeflateTest(SemiStructuredNode): name = StringProperty() diff --git a/test/sync_/test_properties.py b/test/sync_/test_properties.py index 639028d6..0f9a162f 100644 --- a/test/sync_/test_properties.py +++ b/test/sync_/test_properties.py @@ -4,7 +4,8 @@ from pytest import mark, raises from pytz import timezone -from neomodel import StructuredNode, db +from neomodel import Relationship, StructuredNode, StructuredRel, db +from neomodel.contrib import SemiStructuredNode from neomodel.exceptions import ( DeflateError, InflateError, @@ -227,10 +228,20 @@ class DefaultTestValueThree(StructuredNode): assert x.uid == "123" +class TestDBNamePropertyRel(StructuredRel): + known_for = StringProperty(db_property="knownFor") + + +# This must be defined outside of the test, otherwise the `Relationship` definition cannot look up +# `TestDBNamePropertyNode` +class TestDBNamePropertyNode(StructuredNode): + name_ = StringProperty(db_property="name") + knows = Relationship("TestDBNamePropertyNode", "KNOWS", model=TestDBNamePropertyRel) + + @mark_sync_test def test_independent_property_name(): # -- test node -- - x = TestDBNamePropertyNode() x.name_ = "jim" x.save() @@ -264,13 +275,49 @@ def test_independent_property_name(): # check python class property name at a high level assert not hasattr(r, "knownFor") assert hasattr(r, "known_for") - assert x.knows.relationship(x).known_for == r.known_for + rel = x.knows.relationship(x) + assert rel.known_for == r.known_for # -- cleanup -- x.delete() +@mark_sync_test +def test_independent_property_name_for_semi_structured(): + class TestDBNamePropertySemiStructuredNode(SemiStructuredNode): + title_ = StringProperty(db_property="title") + + semi = TestDBNamePropertySemiStructuredNode(title_="sir", extra="data") + semi.save() + + # check database property name on low level + results, meta = db.cypher_query( + "MATCH (n:TestDBNamePropertySemiStructuredNode) RETURN n" + ) + node_properties = get_graph_entity_properties(results[0][0]) + assert node_properties["title"] == "sir" + # assert "title_" not in node_properties + assert node_properties["extra"] == "data" + + # check python class property name at a high level + assert hasattr(semi, "title_") + assert not hasattr(semi, "title") + assert hasattr(semi, "extra") + from_filter = ( + TestDBNamePropertySemiStructuredNode.nodes.filter(title_="sir").all() + )[0] + assert from_filter.title_ == "sir" + # assert not hasattr(from_filter, "title") + assert from_filter.extra == "data" + from_get = TestDBNamePropertySemiStructuredNode.nodes.get(title_="sir") + assert from_get.title_ == "sir" + # assert not hasattr(from_get, "title") + assert from_get.extra == "data" + + semi.delete() + + @mark_sync_test def test_independent_property_name_get_or_create(): class TestNode(StructuredNode): @@ -284,7 +331,7 @@ class TestNode(StructuredNode): # check database property name on low level results, _ = db.cypher_query("MATCH (n:TestNode) RETURN n") - node_properties = _get_node_properties(results[0][0]) + node_properties = get_graph_entity_properties(results[0][0]) assert node_properties["name"] == "jim" assert "name_" not in node_properties From 0aeac32b83a7e3e44c7efbf14a9ee1d349d25d0d Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Mon, 15 Apr 2024 11:33:30 +0200 Subject: [PATCH 73/73] Update changelog --- Changelog | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Changelog b/Changelog index 290f939e..37758648 100644 --- a/Changelog +++ b/Changelog @@ -2,7 +2,7 @@ Version 5.3.0 2024-04 * Add async support * Breaking change : config.AUTO_INSTALL_LABELS has been removed. Please use the neomodel_install_labels script instead * Bumps neo4j (driver) to 5.19.0 -* Various improvement : functools wrap to TransactionProxy, fix node equality check, q filter for IN in arrays. Thanks to @giosava94, @OlehChyhyryn, @icapora +* Various improvement : functools wrap to TransactionProxy, fix node equality check, q filter for IN in arrays, fix inflate on db_property. Thanks to @giosava94, @OlehChyhyryn, @icapora, @j-krose Version 5.2.1 2023-12 * Add options to inspection script to skip heavy operations - rel props or cardinality inspection #767