diff --git a/.gitignore b/.gitignore index a79c277d..562969f3 100644 --- a/.gitignore +++ b/.gitignore @@ -15,7 +15,6 @@ development.env .ropeproject \#*\# .eggs -bin lib .vscode pyvenv.cfg diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 81b274db..22db6a97 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,17 +1,18 @@ repos: - - repo: https://github.com/psf/black - rev: 22.3.0 - hooks: - - id: black - repo: https://github.com/PyCQA/isort rev: 5.11.5 hooks: - id: isort - # - repo: local - # hooks: - # - id: pylint - # name: pylint - # entry: pylint neomodel/ - # language: system - # always_run: true - # pass_filenames: false \ No newline at end of file + args: ["--profile", "black"] + - repo: https://github.com/psf/black + rev: 23.3.0 + hooks: + - id: black + - repo: local + hooks: + - id: unasync + name: unasync + entry: bin/make-unasync + language: python + files: "^(neomodel/async_|test/async_)/.*" + additional_dependencies: [unasync, isort, black] \ No newline at end of file diff --git a/README.md b/README.md index 5d4566a5..0bd3478e 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ GitHub repo found at . # Documentation -(Needs an update, but) Available on +Available on [readthedocs](http://neomodel.readthedocs.org). # Upcoming breaking changes notice - \>=5.3 @@ -47,7 +47,7 @@ support for Python 3.12. Another source of upcoming breaking changes is the addition async support to neomodel. No date is set yet, but the work has progressed a lot in the past weeks ; -and it will be part of a major release (potentially 6.0 to avoid misunderstandings). +and it will be part of a major release. You can see the progress in [this branch](https://github.com/neo4j-contrib/neomodel/tree/task/async). Finally, we are looking at refactoring some standalone methods into the @@ -67,6 +67,15 @@ To install from github: $ pip install git+git://github.com/neo4j-contrib/neomodel.git@HEAD#egg=neomodel-dev +# Performance comparison + +You can find some performance tests made using Locust [in this repo](https://github.com/mariusconjeaud/neomodel-locust). + +Two learnings from this : + +* The wrapping of the driver made by neomodel is very thin performance-wise : it does not add a lot of overhead ; +* When used in a concurrent fashion, async neomodel is faster than concurrent sync neomodel, and a lot of faster than serial queries. + # Contributing Ideas, bugs, tests and pull requests always welcome. Please use @@ -112,3 +121,42 @@ against all supported Python interpreters and neo4j versions: : # in the project's root folder: $ sh ./tests-with-docker-compose.sh + +## Developing with async + +### Transpiling async -> sync + +We use [this great library](https://github.com/python-trio/unasync) to automatically transpile async code into its sync version. + +In other words, when contributing to neomodel, only update the `async` code in `neomodel/async_`, then run : : + + bin/make-unasync + isort . + black . + +Note that you can also use the pre-commit hooks for this. + +### Specific async/sync code +This transpiling script mainly does two things : + +- It removes the await keywords, and the Async prefixes in class names +- It does some specific replacements, like `adb`->`db`, `mark_async_test`->`mark_sync_test` + +It might be that your code should only be run for `async`, or `sync` ; or you want different stubs to be run for `async` vs `sync`. +You can use the following utility function for this - taken from the official [Neo4j python driver code](https://github.com/neo4j/neo4j-python-driver) : + + # neomodel/async_/core.py + from neomodel._async_compat.util import AsyncUtil + + # AsyncUtil.is_async_code is always True + if AsyncUtil.is_async_code: + # Specific async code + # This one gets run when in async mode + assert await Coffee.nodes.check_contains(2) + else: + # Specific sync code + # This one gest run when in sync mode + assert 2 in Coffee.nodes + +You can check [test_match_api](test/async_/test_match_api.py) for some good examples, and how it's transpiled into sync. + diff --git a/bin/make-unasync b/bin/make-unasync new file mode 100755 index 00000000..96375526 --- /dev/null +++ b/bin/make-unasync @@ -0,0 +1,351 @@ +#!/usr/bin/env python3 + +import collections +import errno +import os +import re +import sys +import tokenize as std_tokenize +from pathlib import Path + +import black +import isort +import isort.files +import unasync + +ROOT_DIR = Path(__file__).parents[1].absolute() +ASYNC_DIR = ROOT_DIR / "neomodel" / "async_" +SYNC_DIR = ROOT_DIR / "neomodel" / "sync_" +ASYNC_CONTRIB_DIR = ROOT_DIR / "neomodel" / "contrib" / "async_" +SYNC_CONTRIB_DIR = ROOT_DIR / "neomodel" / "contrib" / "sync_" +ASYNC_INTEGRATION_TEST_DIR = ROOT_DIR / "test" / "async_" +SYNC_INTEGRATION_TEST_DIR = ROOT_DIR / "test" / "sync_" +INTEGRATION_TEST_EXCLUSION_LIST = ["conftest.py"] +UNASYNC_SUFFIX = ".unasync" + +PY_FILE_EXTENSIONS = {".py"} + + +# copy from unasync for customization ----------------------------------------- +# https://github.com/python-trio/unasync +# License: MIT or Apache2 + + +Token = collections.namedtuple("Token", ["type", "string", "start", "end", "line"]) + + +def _makedirs_existok(dir): + try: + os.makedirs(dir) + except OSError as e: + if e.errno != errno.EEXIST: + raise + + +def _get_tokens(f): + if sys.version_info[0] == 2: + for tok in std_tokenize.generate_tokens(f.readline): + type_, string, start, end, line = tok + yield Token(type_, string, start, end, line) + else: + for tok in std_tokenize.tokenize(f.readline): + if tok.type == std_tokenize.ENCODING: + continue + yield tok + + +def _tokenize(f): + last_end = (1, 0) + for tok in _get_tokens(f): + if last_end[0] < tok.start[0]: + yield "", std_tokenize.STRING, " \\\n" + last_end = (tok.start[0], 0) + + space = "" + if tok.start > last_end: + assert tok.start[0] == last_end[0] + space = " " * (tok.start[1] - last_end[1]) + yield space, tok.type, tok.string + + last_end = tok.end + if tok.type in [std_tokenize.NEWLINE, std_tokenize.NL]: + last_end = (tok.end[0] + 1, 0) + elif sys.version_info >= (3, 12) and tok.type == std_tokenize.FSTRING_MIDDLE: + last_end = ( + last_end[0], + last_end[1] + tok.string.count("{") + tok.string.count("}"), + ) + + +def _untokenize(tokens): + return "".join(space + tokval for space, tokval in tokens) + + +# end of copy ----------------------------------------------------------------- + + +class CustomRule(unasync.Rule): + def __init__(self, *args, **kwargs): + super(CustomRule, self).__init__(*args, **kwargs) + self.out_files = [] + + def _unasync_tokens(self, tokens): + # copy from unasync to fix handling of multiline strings + # https://github.com/python-trio/unasync + # License: MIT or Apache2 + + used_space = None + for space, toknum, tokval in tokens: + if tokval in ["async", "await"]: + # When removing async or await, we want to use the whitespace + # that was before async/await before the next token so that + # `print(await stuff)` becomes `print(stuff)` and not + # `print( stuff)` + used_space = space + else: + if toknum == std_tokenize.NAME: + tokval = self._unasync_name(tokval) + elif toknum == std_tokenize.STRING: + if tokval[0] == tokval[1] and len(tokval) > 2: + # multiline string (`"""..."""` or `'''...'''`) + left_quote, name, right_quote = ( + tokval[:3], + tokval[3:-3], + tokval[-3:], + ) + else: + # simple string (`"..."` or `'...'`) + left_quote, name, right_quote = ( + tokval[:1], + tokval[1:-1], + tokval[-1:], + ) + tokval = left_quote + self._unasync_string(name) + right_quote + elif ( + sys.version_info >= (3, 12) + and toknum == std_tokenize.FSTRING_MIDDLE + ): + tokval = tokval.replace("{", "{{").replace("}", "}}") + tokval = self._unasync_string(tokval) + if used_space is None: + used_space = space + yield (used_space, tokval) + used_space = None + + def _unasync_string(self, name): + start = 0 + end = 1 + out = "" + while end < len(name): + sub_name = name[start:end] + if sub_name.isidentifier(): + end += 1 + else: + if end == start + 1: + out += sub_name + start += 1 + end += 1 + else: + out += self._unasync_name(name[start : (end - 1)]) + start = end - 1 + + sub_name = name[start:] + if sub_name.isidentifier(): + out += self._unasync_name(name[start:]) + else: + out += sub_name + + # very boiled down unasync version that removes "async" and "await" + # substrings. + out = re.subn( + r"(^|\s+|(?<=\W))(?:async|await)\s+", r"\1", out, flags=re.MULTILINE + )[0] + # Convert doc-reference names from 'async-xyz' to 'xyz' + out = re.subn(r":ref:`async-", ":ref:`", out)[0] + return out + + def _unasync_prefix(self, name): + # Convert class names from 'AsyncXyz' to 'Xyz' + if len(name) > 5 and name.startswith("Async") and name[5].isupper(): + return name[5:] + # Convert class names from '_AsyncXyz' to '_Xyz' + elif len(name) > 6 and name.startswith("_Async") and name[6].isupper(): + return "_" + name[6:] + # Convert variable/method/function names from 'async_xyz' to 'xyz' + elif len(name) > 6 and name.startswith("async_"): + return name[6:] + return name + + def _unasync_name(self, name): + # copy from unasync to customize renaming rules + # https://github.com/python-trio/unasync + # License: MIT or Apache2 + if name in self.token_replacements: + return self.token_replacements[name] + return self._unasync_prefix(name) + + def _unasync_file(self, filepath): + # copy from unasync to append file suffix to out path + # https://github.com/python-trio/unasync + # License: MIT or Apache2 + with open(filepath, "rb") as f: + write_kwargs = {} + if sys.version_info[0] >= 3: + encoding, _ = std_tokenize.detect_encoding(f.readline) + write_kwargs["encoding"] = encoding + f.seek(0) + tokens = _tokenize(f) + tokens = self._unasync_tokens(tokens) + result = _untokenize(tokens) + outfile_path = filepath.replace(self.fromdir, self.todir) + outfile_path += UNASYNC_SUFFIX + self.out_files.append(outfile_path) + _makedirs_existok(os.path.dirname(outfile_path)) + with open(outfile_path, "w", **write_kwargs) as f: + print(result, file=f, end="") + + +def apply_unasync(files): + """Generate sync code from async code.""" + + additional_main_replacements = { + "adb": "db", + "async_": "sync_", + "check_bool": "__bool__", + "check_nonzero": "__nonzero__", + "check_contains": "__contains__", + "get_item": "__getitem__", + "get_len": "__len__", + } + additional_test_replacements = { + "async_": "sync_", + "check_bool": "__bool__", + "check_nonzero": "__nonzero__", + "check_contains": "__contains__", + "get_item": "__getitem__", + "get_len": "__len__", + "adb": "db", + "mark_async_test": "mark_sync_test", + "mark_async_session_auto_fixture": "mark_sync_session_auto_fixture", + } + rules = [ + CustomRule( + fromdir=str(ASYNC_DIR), + todir=str(SYNC_DIR), + additional_replacements=additional_main_replacements, + ), + CustomRule( + fromdir=str(ASYNC_CONTRIB_DIR), + todir=str(SYNC_CONTRIB_DIR), + additional_replacements=additional_main_replacements, + ), + CustomRule( + fromdir=str(ASYNC_INTEGRATION_TEST_DIR), + todir=str(SYNC_INTEGRATION_TEST_DIR), + additional_replacements=additional_test_replacements, + ), + ] + + if not files: + paths = list(ASYNC_DIR.rglob("*")) + paths += list(ASYNC_CONTRIB_DIR.rglob("*")) + paths += [ + path + for path in ASYNC_INTEGRATION_TEST_DIR.rglob("*") + if path.name not in INTEGRATION_TEST_EXCLUSION_LIST + ] + else: + paths = [ROOT_DIR / Path(f) for f in files] + filtered_paths = [] + for path in paths: + if path.suffix in PY_FILE_EXTENSIONS: + filtered_paths.append(path) + + unasync.unasync_files(map(str, filtered_paths), rules) + + return [Path(path) for rule in rules for path in rule.out_files] + + +def apply_black(paths): + """Prettify generated sync code. + + Since keywords are removed, black might expect a different result, + especially line breaks. + """ + for path in paths: + with open(path, "r") as file: + code = file.read() + + formatted_code = black.format_str(code, mode=black.FileMode()) + + with open(path, "w") as file: + file.write(formatted_code) + + return paths + + +def apply_isort(paths): + """Sort imports in generated sync code. + + Since classes in imports are renamed from AsyncXyz to Xyz, the alphabetical + order of the import can change. + """ + isort_config = isort.Config( + settings_path=str(ROOT_DIR), quiet=True, profile="black" + ) + + for path in paths: + isort.file(str(path), config=isort_config) + + return paths + + +def apply_changes(paths): + def files_equal(path1, path2): + with open(path1, "rb") as f1: + with open(path2, "rb") as f2: + data1 = f1.read(1024) + data2 = f2.read(1024) + while data1 or data2: + if data1 != data2: + changed_paths[path1] = path2 + return False + data1 = f1.read(1024) + data2 = f2.read(1024) + return True + + changed_paths = {} + + for in_path in paths: + out_path = Path(str(in_path)[: -len(UNASYNC_SUFFIX)]) + if not out_path.is_file(): + changed_paths[in_path] = out_path + continue + if not files_equal(in_path, out_path): + changed_paths[in_path] = out_path + continue + in_path.unlink() + + for in_path, out_path in changed_paths.items(): + in_path.replace(out_path) + + return list(changed_paths.values()) + + +def main(): + files = None + if len(sys.argv) >= 1: + files = sys.argv[1:] + paths = apply_unasync(files) + paths = apply_isort(paths) + paths = apply_black(paths) + changed_paths = apply_changes(paths) + + if changed_paths: + for path in changed_paths: + print("Updated " + str(path)) + exit(1) + + +if __name__ == "__main__": + main() diff --git a/doc/source/configuration.rst b/doc/source/configuration.rst index ed1af11b..3946ec98 100644 --- a/doc/source/configuration.rst +++ b/doc/source/configuration.rst @@ -59,7 +59,7 @@ Note that you have to manage the driver's lifecycle yourself. However, everything else is still handled by neomodel : sessions, transactions, etc... -NB : Only the synchronous driver will work in this way. The asynchronous driver is not supported yet. +NB : Only the synchronous driver will work in this way. See the next section for the preferred method, and how to pass an async driver instance. Change/Close the connection --------------------------- @@ -119,14 +119,7 @@ with something like: :: Enable automatic index and constraint creation ---------------------------------------------- -After the definition of a `StructuredNode`, Neomodel can install the corresponding -constraints and indexes at compile time. However this method is only recommended for testing:: - - from neomodel import config - # before loading your node definitions - config.AUTO_INSTALL_LABELS = True - -Neomodel also provides the :ref:`neomodel_install_labels` script for this task, +Neomodel provides the :ref:`neomodel_install_labels` script for this task, however if you want to handle this manually see below. Install indexes and constraints for a single class:: @@ -146,6 +139,9 @@ Or for an entire 'schema' :: # + Creating unique constraint for name on label User for class yourapp.models.User # ... +.. note:: + config.AUTO_INSTALL_LABELS has been removed from neomodel in version 5.3 + Require timezones on DateTimeProperty ------------------------------------- diff --git a/doc/source/extending.rst b/doc/source/extending.rst index 32067009..df50f84d 100644 --- a/doc/source/extending.rst +++ b/doc/source/extending.rst @@ -39,7 +39,7 @@ labels, the `__optional_labels__` property must be defined as a list of strings: __optional_labels__ = ["SuperSaver", "SeniorDiscount"] balance = IntegerProperty(index=True) -.. warning:: The size of the node class mapping grows exponentially with optional labels. Use with some caution. +.. note:: The size of the node class mapping grows exponentially with optional labels. Use with some caution. Mixins diff --git a/doc/source/getting_started.rst b/doc/source/getting_started.rst index 25e999e7..c756a6f6 100644 --- a/doc/source/getting_started.rst +++ b/doc/source/getting_started.rst @@ -22,6 +22,7 @@ Querying the graph neomodel is mainly used as an OGM (see next section), but you can also use it for direct Cypher queries : :: + from neomodel import db results, meta = db.cypher_query("RETURN 'Hello World' as message") @@ -104,7 +105,7 @@ and cardinality will be default (ZeroOrMore). Finally, relationship cardinality is guessed from the database by looking at existing relationships, so it might guess wrong on edge cases. -.. warning:: +.. note:: The script relies on the method apoc.meta.cypher.types to parse property types. So APOC must be installed on your Neo4j server for this script to work. @@ -246,7 +247,7 @@ the following syntax:: Person.nodes.all().fetch_relations('city__country', Optional('country')) -.. warning:: +.. note:: This feature is still a work in progress for extending path traversal and fecthing. It currently stops at returning the resolved objects as they are returned in Cypher. @@ -256,3 +257,60 @@ the following syntax:: If you want to go further in the resolution process, you have to develop your own parser (for now). + +Async neomodel +============== + +neomodel supports asynchronous operations using the async support of neo4j driver. The examples below take a few of the above examples, +but rewritten for async:: + + from neomodel import adb + results, meta = await adb.cypher_query("RETURN 'Hello World' as message") + +OGM with async :: + + # Note that properties do not change, but nodes and relationships now have an Async prefix + from neomodel import (AsyncStructuredNode, StringProperty, IntegerProperty, + UniqueIdProperty, AsyncRelationshipTo) + + class Country(AsyncStructuredNode): + code = StringProperty(unique_index=True, required=True) + + class City(AsyncStructuredNode): + name = StringProperty(required=True) + country = AsyncRelationshipTo(Country, 'FROM_COUNTRY') + + # Operations that interact with the database are now async + # Return all nodes + # Note that the nodes object is awaitable as is + all_nodes = await Country.nodes + + # Relationships + germany = await Country(code='DE').save() + await jim.country.connect(germany) + +Most _dunder_ methods for nodes and relationships had to be overriden to support async operations. The following methods are supported :: + + # Examples below are taken from the various tests. Please check them for more examples. + # Length + dogs_bonanza = await Dog.nodes.get_len() + # Sync equivalent - __len__ + dogs_bonanza = len(Dog.nodes) + # Note that len(Dog.nodes) is more efficient than Dog.nodes.__len__ + + # Existence + assert not await Customer.nodes.filter(email="jim7@aol.com").check_bool() + # Sync equivalent - __bool__ + assert not Customer.nodes.filter(email="jim7@aol.com") + # Also works for check_nonzero => __nonzero__ + + # Contains + assert await Coffee.nodes.check_contains(aCoffeeNode) + # Sync equivalent - __contains__ + assert aCoffeeNode in Coffee.nodes + + # Get item + assert len(list((await Coffee.nodes)[1:])) == 2 + # Sync equivalent - __getitem__ + assert len(list(Coffee.nodes[1:])) == 2 + diff --git a/doc/source/index.rst b/doc/source/index.rst index ec3372c9..1338ce27 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -9,6 +9,7 @@ An Object Graph Mapper (OGM) for the Neo4j_ graph database, built on the awesome - Enforce your schema through cardinality restrictions. - Full transaction support. - Thread safe. +- Async support. - pre/post save/delete hooks. - Django integration via django_neomodel_ @@ -40,6 +41,26 @@ To install from github:: $ pip install git+git://github.com/neo4j-contrib/neomodel.git@HEAD#egg=neomodel-dev +.. note:: + + **Breaking changes in 5.3** + + Introducing support for asynchronous programming to neomodel required to introduce some breaking changes: + + - config.AUTO_INSTALL_LABELS has been removed. Please use the `neomodel_install_labels` (:ref:`neomodel_install_labels`) command instead. + + **Deprecations in 5.3** + + - Some standalone methods are moved into the Database() class and will be removed in a future release : + - change_neo4j_password + - clear_neo4j_database + - drop_constraints + - drop_indexes + - remove_all_labels + - install_labels + - install_all_labels + - Additionally, to call these methods with async, use the ones in the AsyncDatabase() _adb_ singleton. + Contents ======== @@ -59,6 +80,8 @@ Contents configuration extending module_documentation + module_documentation_sync + module_documentation_async Indices and tables ================== diff --git a/doc/source/module_documentation.rst b/doc/source/module_documentation.rst index 16a4acf2..364e207e 100644 --- a/doc/source/module_documentation.rst +++ b/doc/source/module_documentation.rst @@ -1,24 +1,6 @@ -===================== -Modules documentation -===================== - -Database -======== -.. module:: neomodel.util -.. autoclass:: neomodel.util.Database - :members: - :undoc-members: - -Core -==== -.. automodule:: neomodel.core - :members: - -.. _semistructurednode_doc: - -``SemiStructuredNode`` ----------------------- -.. autoclass:: neomodel.contrib.SemiStructuredNode +============ +General API +============ Properties ========== @@ -32,43 +14,6 @@ Spatial Properties & Datatypes :members: :show-inheritance: -Relationships -============= -.. automodule:: neomodel.relationship - :members: - :show-inheritance: - -.. automodule:: neomodel.relationship_manager - :members: - :show-inheritance: - -.. automodule:: neomodel.cardinality - :members: - :show-inheritance: - -Paths -===== - -.. automodule:: neomodel.path - :members: - :show-inheritance: - - - - -Match -===== -.. module:: neomodel.match -.. autoclass:: neomodel.match.BaseSet - :members: - :undoc-members: -.. autoclass:: neomodel.match.NodeSet - :members: - :undoc-members: -.. autoclass:: neomodel.match.Traversal - :members: - :undoc-members: - Exceptions ========== @@ -78,16 +23,21 @@ Exceptions :undoc-members: :show-inheritance: + Scripts ======= -.. automodule:: neomodel.scripts.neomodel_install_labels +.. automodule:: neomodel.scripts.neomodel_inspect_database :members: :undoc-members: :show-inheritance: -.. automodule:: neomodel.scripts.neomodel_remove_labels +.. automodule:: neomodel.scripts.neomodel_install_labels :members: :undoc-members: :show-inheritance: +.. automodule:: neomodel.scripts.neomodel_remove_labels + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/doc/source/module_documentation_async.rst b/doc/source/module_documentation_async.rst new file mode 100644 index 00000000..2150b235 --- /dev/null +++ b/doc/source/module_documentation_async.rst @@ -0,0 +1,55 @@ +======================= +Async API Documentation +======================= + +Core +==== +.. automodule:: neomodel.async_.core + :members: + +.. _semistructurednode_doc: + +``AsyncSemiStructuredNode`` +--------------------------- +.. autoclass:: neomodel.contrib.AsyncSemiStructuredNode + +Relationships +============= +.. automodule:: neomodel.async_.relationship + :members: + :show-inheritance: + +.. automodule:: neomodel.async_.relationship_manager + :members: + :show-inheritance: + +.. automodule:: neomodel.async_.cardinality + :members: + :show-inheritance: + +Property Manager +================ +.. automodule:: neomodel.async_.property_manager + :members: + :show-inheritance: + +Paths +===== + +.. automodule:: neomodel.async_.path + :members: + :show-inheritance: + +Match +===== +.. module:: neomodel.async_.match +.. autoclass:: neomodel.async_.match.AsyncBaseSet + :members: + :undoc-members: +.. autoclass:: neomodel.async_.match.AsyncNodeSet + :members: + :undoc-members: +.. autoclass:: neomodel.async_.match.AsyncTraversal + :members: + :undoc-members: + diff --git a/doc/source/module_documentation_sync.rst b/doc/source/module_documentation_sync.rst new file mode 100644 index 00000000..5214a89a --- /dev/null +++ b/doc/source/module_documentation_sync.rst @@ -0,0 +1,55 @@ +====================== +Sync API Documentation +====================== + +Core +==== +.. automodule:: neomodel.sync_.core + :members: + +.. _semistructurednode_doc: + +``SemiStructuredNode`` +--------------------------- +.. autoclass:: neomodel.contrib.SemiStructuredNode + +Relationships +============= +.. automodule:: neomodel.sync_.relationship + :members: + :show-inheritance: + +.. automodule:: neomodel.sync_.relationship_manager + :members: + :show-inheritance: + +.. automodule:: neomodel.sync_.cardinality + :members: + :show-inheritance: + +Property Manager +================ +.. automodule:: neomodel.sync_.property_manager + :members: + :show-inheritance: + +Paths +===== + +.. automodule:: neomodel.sync_.path + :members: + :show-inheritance: + +Match +===== +.. module:: neomodel.sync_.match +.. autoclass:: neomodel.sync_.match.BaseSet + :members: + :undoc-members: +.. autoclass:: neomodel.sync_.match.NodeSet + :members: + :undoc-members: +.. autoclass:: neomodel.sync_.match.Traversal + :members: + :undoc-members: + diff --git a/doc/source/queries.rst b/doc/source/queries.rst index c3e37629..4c77a791 100644 --- a/doc/source/queries.rst +++ b/doc/source/queries.rst @@ -240,3 +240,19 @@ relationships to their relationship models *if such a model exists*. In other wo relationships with data (such as ``PersonLivesInCity`` above) will be instantiated to their respective objects or ``StrucuredRel`` otherwise. Relationships do not "reload" their end-points (unless this is required). + +Async neomodel - Caveats +======================== + +Python does not support async dunder methods. This means that we had to implement some overrides for those. +See the example below:: + + # This will not work as it uses the synchronous __bool__ method + assert await Customer.nodes.filter(prop="value") + + # Do this instead + assert await Customer.nodes.filter(prop="value").check_bool() + assert await Customer.nodes.filter(prop="value").check_nonzero() + + # Note : no changes are needed for sync so this still works : + assert Customer.nodes.filter(prop="value") diff --git a/neomodel/__init__.py b/neomodel/__init__.py index aee40a08..d7d0febb 100644 --- a/neomodel/__init__.py +++ b/neomodel/__init__.py @@ -1,20 +1,25 @@ # pep8: noqa - +from neomodel.async_.cardinality import ( + AsyncOne, + AsyncOneOrMore, + AsyncZeroOrMore, + AsyncZeroOrOne, +) +from neomodel.async_.core import AsyncStructuredNode, adb +from neomodel.async_.match import AsyncNodeSet, AsyncTraversal +from neomodel.async_.path import AsyncNeomodelPath +from neomodel.async_.property_manager import AsyncPropertyManager +from neomodel.async_.relationship import AsyncStructuredRel +from neomodel.async_.relationship_manager import ( + AsyncRelationship, + AsyncRelationshipDefinition, + AsyncRelationshipFrom, + AsyncRelationshipManager, + AsyncRelationshipTo, +) from neomodel.exceptions import * -from neomodel.match import EITHER, INCOMING, OUTGOING, NodeSet, Traversal from neomodel.match_q import Q # noqa -from neomodel.relationship_manager import ( - NotConnected, - Relationship, - RelationshipDefinition, - RelationshipFrom, - RelationshipManager, - RelationshipTo, -) - -from .cardinality import One, OneOrMore, ZeroOrMore, ZeroOrOne -from .core import * -from .properties import ( +from neomodel.properties import ( AliasProperty, ArrayProperty, BooleanProperty, @@ -30,9 +35,30 @@ StringProperty, UniqueIdProperty, ) -from .relationship import StructuredRel -from .util import change_neo4j_password, clear_neo4j_database -from .path import NeomodelPath +from neomodel.sync_.cardinality import One, OneOrMore, ZeroOrMore, ZeroOrOne +from neomodel.sync_.core import ( + StructuredNode, + change_neo4j_password, + clear_neo4j_database, + db, + drop_constraints, + drop_indexes, + install_all_labels, + install_labels, + remove_all_labels, +) +from neomodel.sync_.match import NodeSet, Traversal +from neomodel.sync_.path import NeomodelPath +from neomodel.sync_.property_manager import PropertyManager +from neomodel.sync_.relationship import StructuredRel +from neomodel.sync_.relationship_manager import ( + Relationship, + RelationshipDefinition, + RelationshipFrom, + RelationshipManager, + RelationshipTo, +) +from neomodel.util import EITHER, INCOMING, OUTGOING __author__ = "Robin Edwards" __email__ = "robin.ge@gmail.com" diff --git a/neomodel/_async_compat/util.py b/neomodel/_async_compat/util.py new file mode 100644 index 00000000..4868c3ba --- /dev/null +++ b/neomodel/_async_compat/util.py @@ -0,0 +1,9 @@ +import typing as t + + +class AsyncUtil: + is_async_code: t.ClassVar = True + + +class Util: + is_async_code: t.ClassVar = False diff --git a/test/test_contrib/__init__.py b/neomodel/async_/__init__.py similarity index 100% rename from test/test_contrib/__init__.py rename to neomodel/async_/__init__.py diff --git a/neomodel/async_/cardinality.py b/neomodel/async_/cardinality.py new file mode 100644 index 00000000..17101cec --- /dev/null +++ b/neomodel/async_/cardinality.py @@ -0,0 +1,135 @@ +from neomodel.async_.relationship_manager import ( # pylint:disable=unused-import + AsyncRelationshipManager, + AsyncZeroOrMore, +) +from neomodel.exceptions import AttemptedCardinalityViolation, CardinalityViolation + + +class AsyncZeroOrOne(AsyncRelationshipManager): + """A relationship to zero or one node.""" + + description = "zero or one relationship" + + async def single(self): + """ + Return the associated node. + + :return: node + """ + nodes = await super().all() + if len(nodes) == 1: + return nodes[0] + if len(nodes) > 1: + raise CardinalityViolation(self, len(nodes)) + return None + + async def all(self): + node = await self.single() + return [node] if node else [] + + async def connect(self, node, properties=None): + """ + Connect to a node. + + :param node: + :type: StructuredNode + :param properties: relationship properties + :type: dict + :return: True / rel instance + """ + if await super().get_len(): + raise AttemptedCardinalityViolation( + f"Node already has {self} can't connect more" + ) + return await super().connect(node, properties) + + +class AsyncOneOrMore(AsyncRelationshipManager): + """A relationship to zero or more nodes.""" + + description = "one or more relationships" + + async def single(self): + """ + Fetch one of the related nodes + + :return: Node + """ + nodes = await super().all() + if nodes: + return nodes[0] + raise CardinalityViolation(self, "none") + + async def all(self): + """ + Returns all related nodes. + + :return: [node1, node2...] + """ + nodes = await super().all() + if nodes: + return nodes + raise CardinalityViolation(self, "none") + + async def disconnect(self, node): + """ + Disconnect node + :param node: + :return: + """ + if await super().get_len() < 2: + raise AttemptedCardinalityViolation("One or more expected") + return await super().disconnect(node) + + +class AsyncOne(AsyncRelationshipManager): + """ + A relationship to a single node + """ + + description = "one relationship" + + async def single(self): + """ + Return the associated node. + + :return: node + """ + nodes = await super().all() + if nodes: + if len(nodes) == 1: + return nodes[0] + raise CardinalityViolation(self, len(nodes)) + raise CardinalityViolation(self, "none") + + async def all(self): + """ + Return single node in an array + + :return: [node] + """ + return [await self.single()] + + async def disconnect(self, node): + raise AttemptedCardinalityViolation( + "Cardinality one, cannot disconnect use reconnect." + ) + + async def disconnect_all(self): + raise AttemptedCardinalityViolation( + "Cardinality one, cannot disconnect_all use reconnect." + ) + + async def connect(self, node, properties=None): + """ + Connect a node + + :param node: + :param properties: relationship properties + :return: True / rel instance + """ + if not hasattr(self.source, "element_id") or self.source.element_id is None: + raise ValueError("Node has not been saved cannot connect!") + if await super().get_len(): + raise AttemptedCardinalityViolation("Node already has one relationship") + return await super().connect(node, properties) diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py new file mode 100644 index 00000000..21228cb9 --- /dev/null +++ b/neomodel/async_/core.py @@ -0,0 +1,1523 @@ +import logging +import os +import sys +import time +import warnings +from asyncio import iscoroutinefunction +from itertools import combinations +from threading import local +from typing import Optional, Sequence +from urllib.parse import quote, unquote, urlparse + +from neo4j import ( + DEFAULT_DATABASE, + AsyncDriver, + AsyncGraphDatabase, + AsyncResult, + AsyncSession, + AsyncTransaction, + basic_auth, +) +from neo4j.api import Bookmarks +from neo4j.exceptions import ClientError, ServiceUnavailable, SessionExpired +from neo4j.graph import Node, Path, Relationship + +from neomodel import config +from neomodel._async_compat.util import AsyncUtil +from neomodel.async_.property_manager import AsyncPropertyManager +from neomodel.exceptions import ( + ConstraintValidationFailed, + DoesNotExist, + FeatureNotSupported, + NodeClassAlreadyDefined, + NodeClassNotDefined, + RelationshipClassNotDefined, + UniqueProperty, +) +from neomodel.hooks import hooks +from neomodel.properties import Property +from neomodel.util import ( + _get_node_properties, + _UnsavedNode, + classproperty, + deprecated, + version_tag_to_integer, +) + +logger = logging.getLogger(__name__) + +RULE_ALREADY_EXISTS = "Neo.ClientError.Schema.EquivalentSchemaRuleAlreadyExists" +INDEX_ALREADY_EXISTS = "Neo.ClientError.Schema.IndexAlreadyExists" +CONSTRAINT_ALREADY_EXISTS = "Neo.ClientError.Schema.ConstraintAlreadyExists" +STREAMING_WARNING = "streaming is not supported by bolt, please remove the kwarg" +NOT_COROUTINE_ERROR = "The decorated function must be a coroutine" + + +# make sure the connection url has been set prior to executing the wrapped function +def ensure_connection(func): + """Decorator that ensures a connection is established before executing the decorated function. + + Args: + func (callable): The function to be decorated. + + Returns: + callable: The decorated function. + + """ + + async def wrapper(self, *args, **kwargs): + # Sort out where to find url + if hasattr(self, "db"): + _db = self.db + else: + _db = self + + if not _db.driver: + if hasattr(config, "DATABASE_URL") and config.DATABASE_URL: + await _db.set_connection(url=config.DATABASE_URL) + elif hasattr(config, "DRIVER") and config.DRIVER: + await _db.set_connection(driver=config.DRIVER) + + return await func(self, *args, **kwargs) + + return wrapper + + +class AsyncDatabase(local): + """ + A singleton object via which all operations from neomodel to the Neo4j backend are handled with. + """ + + _NODE_CLASS_REGISTRY = {} + + def __init__(self): + self._active_transaction = None + self.url = None + self.driver = None + self._session = None + self._pid = None + self._database_name = DEFAULT_DATABASE + self.protocol_version = None + self._database_version = None + self._database_edition = None + self.impersonated_user = None + + async def set_connection(self, url: str = None, driver: AsyncDriver = None): + """ + Sets the connection up and relevant internal. This can be done using a Neo4j URL or a driver instance. + + Args: + url (str): Optionally, Neo4j URL in the form protocol://username:password@hostname:port/dbname. + When provided, a Neo4j driver instance will be created by neomodel. + + driver (neo4j.Driver): Optionally, a pre-created driver instance. + When provided, neomodel will not create a driver instance but use this one instead. + """ + if driver: + self.driver = driver + if hasattr(config, "DATABASE_NAME") and config.DATABASE_NAME: + self._database_name = config.DATABASE_NAME + elif url: + self._parse_driver_from_url(url=url) + + self._pid = os.getpid() + self._active_transaction = None + # Set to default database if it hasn't been set before + if self._database_name is None: + self._database_name = DEFAULT_DATABASE + + # Getting the information about the database version requires a connection to the database + self._database_version = None + self._database_edition = None + await self._update_database_version() + + def _parse_driver_from_url(self, url: str) -> None: + """Parse the driver information from the given URL and initialize the driver. + + Args: + url (str): The URL to parse. + + Raises: + ValueError: If the URL format is not as expected. + + Returns: + None - Sets the driver and database_name as class properties + """ + p_start = url.replace(":", "", 1).find(":") + 2 + p_end = url.rfind("@") + password = url[p_start:p_end] + url = url.replace(password, quote(password)) + parsed_url = urlparse(url) + + valid_schemas = [ + "bolt", + "bolt+s", + "bolt+ssc", + "bolt+routing", + "neo4j", + "neo4j+s", + "neo4j+ssc", + ] + + if parsed_url.netloc.find("@") > -1 and parsed_url.scheme in valid_schemas: + credentials, hostname = parsed_url.netloc.rsplit("@", 1) + username, password = credentials.split(":") + password = unquote(password) + database_name = parsed_url.path.strip("/") + else: + raise ValueError( + f"Expecting url format: bolt://user:password@localhost:7687 got {url}" + ) + + options = { + "auth": basic_auth(username, password), + "connection_acquisition_timeout": config.CONNECTION_ACQUISITION_TIMEOUT, + "connection_timeout": config.CONNECTION_TIMEOUT, + "keep_alive": config.KEEP_ALIVE, + "max_connection_lifetime": config.MAX_CONNECTION_LIFETIME, + "max_connection_pool_size": config.MAX_CONNECTION_POOL_SIZE, + "max_transaction_retry_time": config.MAX_TRANSACTION_RETRY_TIME, + "resolver": config.RESOLVER, + "user_agent": config.USER_AGENT, + } + + if "+s" not in parsed_url.scheme: + options["encrypted"] = config.ENCRYPTED + options["trusted_certificates"] = config.TRUSTED_CERTIFICATES + + self.driver = AsyncGraphDatabase.driver( + parsed_url.scheme + "://" + hostname, **options + ) + self.url = url + # The database name can be provided through the url or the config + if database_name == "": + if hasattr(config, "DATABASE_NAME") and config.DATABASE_NAME: + self._database_name = config.DATABASE_NAME + else: + self._database_name = database_name + + async def close_connection(self): + """ + Closes the currently open driver. + The driver should always be closed at the end of the application's lifecyle. + """ + self._database_version = None + self._database_edition = None + self._database_name = None + await self.driver.close() + self.driver = None + + @property + async def database_version(self): + if self._database_version is None: + await self._update_database_version() + + return self._database_version + + @property + async def database_edition(self): + if self._database_edition is None: + await self._update_database_version() + + return self._database_edition + + @property + def transaction(self): + """ + Returns the current transaction object + """ + return AsyncTransactionProxy(self) + + @property + def write_transaction(self): + return AsyncTransactionProxy(self, access_mode="WRITE") + + @property + def read_transaction(self): + return AsyncTransactionProxy(self, access_mode="READ") + + async def impersonate(self, user: str) -> "ImpersonationHandler": + """All queries executed within this context manager will be executed as impersonated user + + Args: + user (str): User to impersonate + + Returns: + ImpersonationHandler: Context manager to set/unset the user to impersonate + """ + db_edition = await self.database_edition + if db_edition != "enterprise": + raise FeatureNotSupported( + "Impersonation is only available in Neo4j Enterprise edition" + ) + return ImpersonationHandler(self, impersonated_user=user) + + @ensure_connection + async def begin(self, access_mode=None, **parameters): + """ + Begins a new transaction. Raises SystemError if a transaction is already active. + """ + if ( + hasattr(self, "_active_transaction") + and self._active_transaction is not None + ): + raise SystemError("Transaction in progress") + self._session: AsyncSession = self.driver.session( + default_access_mode=access_mode, + database=self._database_name, + impersonated_user=self.impersonated_user, + **parameters, + ) + self._active_transaction: AsyncTransaction = ( + await self._session.begin_transaction() + ) + + @ensure_connection + async def commit(self): + """ + Commits the current transaction and closes its session + + :return: last_bookmarks + """ + try: + await self._active_transaction.commit() + last_bookmarks: Bookmarks = await self._session.last_bookmarks() + finally: + # In case when something went wrong during + # committing changes to the database + # we have to close an active transaction and session. + await self._active_transaction.close() + await self._session.close() + self._active_transaction = None + self._session = None + + return last_bookmarks + + @ensure_connection + async def rollback(self): + """ + Rolls back the current transaction and closes its session + """ + try: + await self._active_transaction.rollback() + finally: + # In case when something went wrong during changes rollback, + # we have to close an active transaction and session + await self._active_transaction.close() + await self._session.close() + self._active_transaction = None + self._session = None + + async def _update_database_version(self): + """ + Updates the database server information when it is required + """ + try: + results = await self.cypher_query( + "CALL dbms.components() yield versions, edition return versions[0], edition" + ) + self._database_version = results[0][0][0] + self._database_edition = results[0][0][1] + except ServiceUnavailable: + # The database server is not running yet + pass + + def _object_resolution(self, object_to_resolve): + """ + Performs in place automatic object resolution on a result + returned by cypher_query. + + The function operates recursively in order to be able to resolve Nodes + within nested list structures and Path objects. Not meant to be called + directly, used primarily by _result_resolution. + + :param object_to_resolve: A result as returned by cypher_query. + :type Any: + + :return: An instantiated object. + """ + # Below is the original comment that came with the code extracted in + # this method. It is not very clear but I decided to keep it just in + # case + # + # + # For some reason, while the type of `a_result_attribute[1]` + # as reported by the neo4j driver is `Node` for Node-type data + # retrieved from the database. + # When the retrieved data are Relationship-Type, + # the returned type is `abc.[REL_LABEL]` which is however + # a descendant of Relationship. + # Consequently, the type checking was changed for both + # Node, Relationship objects + if isinstance(object_to_resolve, Node): + return self._NODE_CLASS_REGISTRY[ + frozenset(object_to_resolve.labels) + ].inflate(object_to_resolve) + + if isinstance(object_to_resolve, Relationship): + rel_type = frozenset([object_to_resolve.type]) + return self._NODE_CLASS_REGISTRY[rel_type].inflate(object_to_resolve) + + if isinstance(object_to_resolve, Path): + from neomodel.async_.path import AsyncNeomodelPath + + return AsyncNeomodelPath(object_to_resolve) + + if isinstance(object_to_resolve, list): + return self._result_resolution([object_to_resolve]) + + return object_to_resolve + + def _result_resolution(self, result_list): + """ + Performs in place automatic object resolution on a set of results + returned by cypher_query. + + The function operates recursively in order to be able to resolve Nodes + within nested list structures. Not meant to be called directly, + used primarily by cypher_query. + + :param result_list: A list of results as returned by cypher_query. + :type list: + + :return: A list of instantiated objects. + """ + + # Object resolution occurs in-place + for a_result_item in enumerate(result_list): + for a_result_attribute in enumerate(a_result_item[1]): + try: + # Primitive types should remain primitive types, + # Nodes to be resolved to native objects + resolved_object = a_result_attribute[1] + + resolved_object = self._object_resolution(resolved_object) + + result_list[a_result_item[0]][ + a_result_attribute[0] + ] = resolved_object + + except KeyError as exc: + # Not being able to match the label set of a node with a known object results + # in a KeyError in the internal dictionary used for resolution. If it is impossible + # to match, then raise an exception with more details about the error. + if isinstance(a_result_attribute[1], Node): + raise NodeClassNotDefined( + a_result_attribute[1], self._NODE_CLASS_REGISTRY + ) from exc + + if isinstance(a_result_attribute[1], Relationship): + raise RelationshipClassNotDefined( + a_result_attribute[1], self._NODE_CLASS_REGISTRY + ) from exc + + return result_list + + @ensure_connection + async def cypher_query( + self, + query, + params=None, + handle_unique=True, + retry_on_session_expire=False, + resolve_objects=False, + ): + """ + Runs a query on the database and returns a list of results and their headers. + + :param query: A CYPHER query + :type: str + :param params: Dictionary of parameters + :type: dict + :param handle_unique: Whether or not to raise UniqueProperty exception on Cypher's ConstraintValidation errors + :type: bool + :param retry_on_session_expire: Whether or not to attempt the same query again if the transaction has expired. + If you use neomodel with your own driver, you must catch SessionExpired exceptions yourself and retry with a new driver instance. + :type: bool + :param resolve_objects: Whether to attempt to resolve the returned nodes to data model objects automatically + :type: bool + + :return: A tuple containing a list of results and a tuple of headers. + """ + + if self._active_transaction: + # Use current session is a transaction is currently active + results, meta = await self._run_cypher_query( + self._active_transaction, + query, + params, + handle_unique, + retry_on_session_expire, + resolve_objects, + ) + else: + # Otherwise create a new session in a with to dispose of it after it has been run + async with self.driver.session( + database=self._database_name, impersonated_user=self.impersonated_user + ) as session: + results, meta = await self._run_cypher_query( + session, + query, + params, + handle_unique, + retry_on_session_expire, + resolve_objects, + ) + + return results, meta + + async def _run_cypher_query( + self, + session: AsyncSession, + query, + params, + handle_unique, + retry_on_session_expire, + resolve_objects, + ): + try: + # Retrieve the data + start = time.time() + response: AsyncResult = await session.run(query, params) + results, meta = [list(r.values()) async for r in response], response.keys() + end = time.time() + + if resolve_objects: + # Do any automatic resolution required + results = self._result_resolution(results) + + except ClientError as e: + if e.code == "Neo.ClientError.Schema.ConstraintValidationFailed": + if "already exists with label" in e.message and handle_unique: + raise UniqueProperty(e.message) from e + + raise ConstraintValidationFailed(e.message) from e + exc_info = sys.exc_info() + raise exc_info[1].with_traceback(exc_info[2]) + except SessionExpired: + if retry_on_session_expire: + await self.set_connection(url=self.url) + return await self.cypher_query( + query=query, + params=params, + handle_unique=handle_unique, + retry_on_session_expire=False, + ) + raise + + tte = end - start + if os.environ.get("NEOMODEL_CYPHER_DEBUG", False) and tte > float( + os.environ.get("NEOMODEL_SLOW_QUERIES", 0) + ): + logger.debug( + "query: " + + query + + "\nparams: " + + repr(params) + + f"\ntook: {tte:.2g}s\n" + ) + + return results, meta + + async def get_id_method(self) -> str: + db_version = await self.database_version + if db_version.startswith("4"): + return "id" + else: + return "elementId" + + async def parse_element_id(self, element_id: str): + db_version = await self.database_version + return int(element_id) if db_version.startswith("4") else element_id + + async def list_indexes(self, exclude_token_lookup=False) -> Sequence[dict]: + """Returns all indexes existing in the database + + Arguments: + exclude_token_lookup[bool]: Exclude automatically create token lookup indexes + + Returns: + Sequence[dict]: List of dictionaries, each entry being an index definition + """ + indexes, meta_indexes = await self.cypher_query("SHOW INDEXES") + indexes_as_dict = [dict(zip(meta_indexes, row)) for row in indexes] + + if exclude_token_lookup: + indexes_as_dict = [ + obj for obj in indexes_as_dict if obj["type"] != "LOOKUP" + ] + + return indexes_as_dict + + async def list_constraints(self) -> Sequence[dict]: + """Returns all constraints existing in the database + + Returns: + Sequence[dict]: List of dictionaries, each entry being a constraint definition + """ + constraints, meta_constraints = await self.cypher_query("SHOW CONSTRAINTS") + constraints_as_dict = [dict(zip(meta_constraints, row)) for row in constraints] + + return constraints_as_dict + + @ensure_connection + async def version_is_higher_than(self, version_tag: str) -> bool: + """Returns true if the database version is higher or equal to a given tag + + Args: + version_tag (str): The version to compare against + + Returns: + bool: True if the database version is higher or equal to the given version + """ + db_version = await self.database_version + return version_tag_to_integer(db_version) >= version_tag_to_integer(version_tag) + + @ensure_connection + async def edition_is_enterprise(self) -> bool: + """Returns true if the database edition is enterprise + + Returns: + bool: True if the database edition is enterprise + """ + edition = await self.database_edition + return edition == "enterprise" + + async def change_neo4j_password(self, user, new_password): + await self.cypher_query(f"ALTER USER {user} SET PASSWORD '{new_password}'") + + async def clear_neo4j_database(self, clear_constraints=False, clear_indexes=False): + await self.cypher_query( + """ + MATCH (a) + CALL { WITH a DETACH DELETE a } + IN TRANSACTIONS OF 5000 rows + """ + ) + if clear_constraints: + await drop_constraints() + if clear_indexes: + await drop_indexes() + + async def drop_constraints(self, quiet=True, stdout=None): + """ + Discover and drop all constraints. + + :type: bool + :return: None + """ + if not stdout or stdout is None: + stdout = sys.stdout + + results, meta = await self.cypher_query("SHOW CONSTRAINTS") + + results_as_dict = [dict(zip(meta, row)) for row in results] + for constraint in results_as_dict: + await self.cypher_query("DROP CONSTRAINT " + constraint["name"]) + if not quiet: + stdout.write( + ( + " - Dropping unique constraint and index" + f" on label {constraint['labelsOrTypes'][0]}" + f" with property {constraint['properties'][0]}.\n" + ) + ) + if not quiet: + stdout.write("\n") + + async def drop_indexes(self, quiet=True, stdout=None): + """ + Discover and drop all indexes, except the automatically created token lookup indexes. + + :type: bool + :return: None + """ + if not stdout or stdout is None: + stdout = sys.stdout + + indexes = await self.list_indexes(exclude_token_lookup=True) + for index in indexes: + await self.cypher_query("DROP INDEX " + index["name"]) + if not quiet: + stdout.write( + f' - Dropping index on labels {",".join(index["labelsOrTypes"])} with properties {",".join(index["properties"])}.\n' + ) + if not quiet: + stdout.write("\n") + + async def remove_all_labels(self, stdout=None): + """ + Calls functions for dropping constraints and indexes. + + :param stdout: output stream + :return: None + """ + + if not stdout: + stdout = sys.stdout + + stdout.write("Dropping constraints...\n") + await self.drop_constraints(quiet=False, stdout=stdout) + + stdout.write("Dropping indexes...\n") + await self.drop_indexes(quiet=False, stdout=stdout) + + async def install_all_labels(self, stdout=None): + """ + Discover all subclasses of StructuredNode in your application and execute install_labels on each. + Note: code must be loaded (imported) in order for a class to be discovered. + + :param stdout: output stream + :return: None + """ + + if not stdout or stdout is None: + stdout = sys.stdout + + def subsub(cls): # recursively return all subclasses + subclasses = cls.__subclasses__() + if not subclasses: # base case: no more subclasses + return [] + return subclasses + [g for s in cls.__subclasses__() for g in subsub(s)] + + stdout.write("Setting up indexes and constraints...\n\n") + + i = 0 + for cls in subsub(AsyncStructuredNode): + stdout.write(f"Found {cls.__module__}.{cls.__name__}\n") + await install_labels(cls, quiet=False, stdout=stdout) + i += 1 + + if i: + stdout.write("\n") + + stdout.write(f"Finished {i} classes.\n") + + async def install_labels(self, cls, quiet=True, stdout=None): + """ + Setup labels with indexes and constraints for a given class + + :param cls: StructuredNode class + :type: class + :param quiet: (default true) enable standard output + :param stdout: stdout stream + :type: bool + :return: None + """ + if not stdout or stdout is None: + stdout = sys.stdout + + if not hasattr(cls, "__label__"): + if not quiet: + stdout.write( + f" ! Skipping class {cls.__module__}.{cls.__name__} is abstract\n" + ) + return + + for name, property in cls.defined_properties(aliases=False, rels=False).items(): + await self._install_node(cls, name, property, quiet, stdout) + + for _, relationship in cls.defined_properties( + aliases=False, rels=True, properties=False + ).items(): + await self._install_relationship(cls, relationship, quiet, stdout) + + async def _create_node_index(self, label: str, property_name: str, stdout): + try: + await self.cypher_query( + f"CREATE INDEX index_{label}_{property_name} FOR (n:{label}) ON (n.{property_name}); " + ) + except ClientError as e: + if e.code in ( + RULE_ALREADY_EXISTS, + INDEX_ALREADY_EXISTS, + ): + stdout.write(f"{str(e)}\n") + else: + raise + + async def _create_node_constraint(self, label: str, property_name: str, stdout): + try: + await self.cypher_query( + f"""CREATE CONSTRAINT constraint_unique_{label}_{property_name} + FOR (n:{label}) REQUIRE n.{property_name} IS UNIQUE""" + ) + except ClientError as e: + if e.code in ( + RULE_ALREADY_EXISTS, + CONSTRAINT_ALREADY_EXISTS, + ): + stdout.write(f"{str(e)}\n") + else: + raise + + async def _create_relationship_index( + self, relationship_type: str, property_name: str, stdout + ): + try: + await self.cypher_query( + f"CREATE INDEX index_{relationship_type}_{property_name} FOR ()-[r:{relationship_type}]-() ON (r.{property_name}); " + ) + except ClientError as e: + if e.code in ( + RULE_ALREADY_EXISTS, + INDEX_ALREADY_EXISTS, + ): + stdout.write(f"{str(e)}\n") + else: + raise + + async def _create_relationship_constraint( + self, relationship_type: str, property_name: str, stdout + ): + if await self.version_is_higher_than("5.7"): + try: + await self.cypher_query( + f"""CREATE CONSTRAINT constraint_unique_{relationship_type}_{property_name} + FOR ()-[r:{relationship_type}]-() REQUIRE r.{property_name} IS UNIQUE""" + ) + except ClientError as e: + if e.code in ( + RULE_ALREADY_EXISTS, + CONSTRAINT_ALREADY_EXISTS, + ): + stdout.write(f"{str(e)}\n") + else: + raise + else: + raise FeatureNotSupported( + f"Unique indexes on relationships are not supported in Neo4j version {await self.database_version}. Please upgrade to Neo4j 5.7 or higher." + ) + + async def _install_node(self, cls, name, property, quiet, stdout): + # Create indexes and constraints for node property + db_property = property.db_property or name + if property.index: + if not quiet: + stdout.write( + f" + Creating node index {name} on label {cls.__label__} for class {cls.__module__}.{cls.__name__}\n" + ) + await self._create_node_index( + label=cls.__label__, property_name=db_property, stdout=stdout + ) + + elif property.unique_index: + if not quiet: + stdout.write( + f" + Creating node unique constraint for {name} on label {cls.__label__} for class {cls.__module__}.{cls.__name__}\n" + ) + await self._create_node_constraint( + label=cls.__label__, property_name=db_property, stdout=stdout + ) + + async def _install_relationship(self, cls, relationship, quiet, stdout): + # Create indexes and constraints for relationship property + relationship_cls = relationship.definition["model"] + if relationship_cls is not None: + relationship_type = relationship.definition["relation_type"] + for prop_name, property in relationship_cls.defined_properties( + aliases=False, rels=False + ).items(): + db_property = property.db_property or prop_name + if property.index: + if not quiet: + stdout.write( + f" + Creating relationship index {prop_name} on relationship type {relationship_type} for relationship model {cls.__module__}.{relationship_cls.__name__}\n" + ) + await self._create_relationship_index( + relationship_type=relationship_type, + property_name=db_property, + stdout=stdout, + ) + elif property.unique_index: + if not quiet: + stdout.write( + f" + Creating relationship unique constraint for {prop_name} on relationship type {relationship_type} for relationship model {cls.__module__}.{relationship_cls.__name__}\n" + ) + await self._create_relationship_constraint( + relationship_type=relationship_type, + property_name=db_property, + stdout=stdout, + ) + + +# Create a singleton instance of the database object +adb = AsyncDatabase() + + +# Deprecated methods +async def change_neo4j_password(db: AsyncDatabase, user, new_password): + deprecated( + """ + This method has been moved to the Database singleton (db for sync, adb for async). + Please use adb.change_neo4j_password(user, new_password) instead. + This direct call will be removed in an upcoming version. + """ + ) + await db.change_neo4j_password(user, new_password) + + +async def clear_neo4j_database( + db: AsyncDatabase, clear_constraints=False, clear_indexes=False +): + deprecated( + """ + This method has been moved to the Database singleton (db for sync, adb for async). + Please use adb.clear_neo4j_database(clear_constraints, clear_indexes) instead. + This direct call will be removed in an upcoming version. + """ + ) + await db.clear_neo4j_database(clear_constraints, clear_indexes) + + +async def drop_constraints(quiet=True, stdout=None): + deprecated( + """ + This method has been moved to the Database singleton (db for sync, adb for async). + Please use adb.drop_constraints(quiet, stdout) instead. + This direct call will be removed in an upcoming version. + """ + ) + await adb.drop_constraints(quiet, stdout) + + +async def drop_indexes(quiet=True, stdout=None): + deprecated( + """ + This method has been moved to the Database singleton (db for sync, adb for async). + Please use adb.drop_indexes(quiet, stdout) instead. + This direct call will be removed in an upcoming version. + """ + ) + await adb.drop_indexes(quiet, stdout) + + +async def remove_all_labels(stdout=None): + deprecated( + """ + This method has been moved to the Database singleton (db for sync, adb for async). + Please use adb.remove_all_labels(stdout) instead. + This direct call will be removed in an upcoming version. + """ + ) + await adb.remove_all_labels(stdout) + + +async def install_labels(cls, quiet=True, stdout=None): + deprecated( + """ + This method has been moved to the Database singleton (db for sync, adb for async). + Please use adb.install_labels(cls, quiet, stdout) instead. + This direct call will be removed in an upcoming version. + """ + ) + await adb.install_labels(cls, quiet, stdout) + + +async def install_all_labels(stdout=None): + deprecated( + """ + This method has been moved to the Database singleton (db for sync, adb for async). + Please use adb.install_all_labels(stdout) instead. + This direct call will be removed in an upcoming version. + """ + ) + await adb.install_all_labels(stdout) + + +class AsyncTransactionProxy: + bookmarks: Optional[Bookmarks] = None + + def __init__(self, db: AsyncDatabase, access_mode=None): + self.db = db + self.access_mode = access_mode + + @ensure_connection + async def __aenter__(self): + print("aenter called") + await self.db.begin(access_mode=self.access_mode, bookmarks=self.bookmarks) + self.bookmarks = None + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + print("aexit called") + if exc_value: + await self.db.rollback() + + if ( + exc_type is ClientError + and exc_value.code == "Neo.ClientError.Schema.ConstraintValidationFailed" + ): + raise UniqueProperty(exc_value.message) + + if not exc_value: + self.last_bookmark = await self.db.commit() + + def __call__(self, func): + if AsyncUtil.is_async_code and not iscoroutinefunction(func): + raise TypeError(NOT_COROUTINE_ERROR) + + async def wrapper(*args, **kwargs): + async with self: + print("call called") + return await func(*args, **kwargs) + + return wrapper + + @property + def with_bookmark(self): + return BookmarkingAsyncTransactionProxy(self.db, self.access_mode) + + +class BookmarkingAsyncTransactionProxy(AsyncTransactionProxy): + def __call__(self, func): + if AsyncUtil.is_async_code and not iscoroutinefunction(func): + raise TypeError(NOT_COROUTINE_ERROR) + + async def wrapper(*args, **kwargs): + self.bookmarks = kwargs.pop("bookmarks", None) + + async with self: + result = await func(*args, **kwargs) + self.last_bookmark = None + + return result, self.last_bookmark + + return wrapper + + +class ImpersonationHandler: + def __init__(self, db: AsyncDatabase, impersonated_user: str): + self.db = db + self.impersonated_user = impersonated_user + + def __enter__(self): + self.db.impersonated_user = self.impersonated_user + return self + + def __exit__(self, exception_type, exception_value, exception_traceback): + self.db.impersonated_user = None + + print("\nException type:", exception_type) + print("\nException value:", exception_value) + print("\nTraceback:", exception_traceback) + + def __call__(self, func): + def wrapper(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return wrapper + + +class NodeMeta(type): + def __new__(mcs, name, bases, namespace): + namespace["DoesNotExist"] = type(name + "DoesNotExist", (DoesNotExist,), {}) + cls = super().__new__(mcs, name, bases, namespace) + cls.DoesNotExist._model_class = cls + + if hasattr(cls, "__abstract_node__"): + delattr(cls, "__abstract_node__") + else: + if "deleted" in namespace: + raise ValueError( + "Property name 'deleted' is not allowed as it conflicts with neomodel internals." + ) + elif "id" in namespace: + raise ValueError( + """ + Property name 'id' is not allowed as it conflicts with neomodel internals. + Consider using 'uid' or 'identifier' as id is also a Neo4j internal. + """ + ) + elif "element_id" in namespace: + raise ValueError( + """ + Property name 'element_id' is not allowed as it conflicts with neomodel internals. + Consider using 'uid' or 'identifier' as element_id is also a Neo4j internal. + """ + ) + for key, value in ( + (x, y) for x, y in namespace.items() if isinstance(y, Property) + ): + value.name, value.owner = key, cls + if hasattr(value, "setup") and callable(value.setup): + value.setup() + + # cache various groups of properies + cls.__required_properties__ = tuple( + name + for name, property in cls.defined_properties( + aliases=False, rels=False + ).items() + if property.required or property.unique_index + ) + cls.__all_properties__ = tuple( + cls.defined_properties(aliases=False, rels=False).items() + ) + cls.__all_aliases__ = tuple( + cls.defined_properties(properties=False, rels=False).items() + ) + cls.__all_relationships__ = tuple( + cls.defined_properties(aliases=False, properties=False).items() + ) + + cls.__label__ = namespace.get("__label__", name) + cls.__optional_labels__ = namespace.get("__optional_labels__", []) + + build_class_registry(cls) + + return cls + + +def build_class_registry(cls): + base_label_set = frozenset(cls.inherited_labels()) + optional_label_set = set(cls.inherited_optional_labels()) + + # Construct all possible combinations of labels + optional labels + possible_label_combinations = [ + frozenset(set(x).union(base_label_set)) + for i in range(1, len(optional_label_set) + 1) + for x in combinations(optional_label_set, i) + ] + possible_label_combinations.append(base_label_set) + + for label_set in possible_label_combinations: + if label_set not in adb._NODE_CLASS_REGISTRY: + adb._NODE_CLASS_REGISTRY[label_set] = cls + else: + raise NodeClassAlreadyDefined(cls, adb._NODE_CLASS_REGISTRY) + + +NodeBase = NodeMeta("NodeBase", (AsyncPropertyManager,), {"__abstract_node__": True}) + + +class AsyncStructuredNode(NodeBase): + """ + Base class for all node definitions to inherit from. + + If you want to create your own abstract classes set: + __abstract_node__ = True + """ + + # static properties + + __abstract_node__ = True + + # magic methods + + def __init__(self, *args, **kwargs): + if "deleted" in kwargs: + raise ValueError("deleted property is reserved for neomodel") + + for key, val in self.__all_relationships__: + self.__dict__[key] = val.build_manager(self, key) + + super().__init__(*args, **kwargs) + + def __eq__(self, other): + if not isinstance(other, (AsyncStructuredNode,)): + return False + if hasattr(self, "element_id") and hasattr(other, "element_id"): + return self.element_id == other.element_id + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def __repr__(self): + return f"<{self.__class__.__name__}: {self}>" + + def __str__(self): + return repr(self.__properties__) + + # dynamic properties + + @classproperty + def nodes(cls): + """ + Returns a NodeSet object representing all nodes of the classes label + :return: NodeSet + :rtype: NodeSet + """ + from neomodel.async_.match import AsyncNodeSet + + return AsyncNodeSet(cls) + + @property + def element_id(self): + if hasattr(self, "element_id_property"): + return self.element_id_property + return None + + # Version 4.4 support - id is deprecated in version 5.x + @property + def id(self): + try: + return int(self.element_id_property) + except (TypeError, ValueError): + raise ValueError( + "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()." + ) + + # methods + + @classmethod + async def _build_merge_query( + cls, merge_params, update_existing=False, lazy=False, relationship=None + ): + """ + Get a tuple of a CYPHER query and a params dict for the specified MERGE query. + + :param merge_params: The target node match parameters, each node must have a "create" key and optional "update". + :type merge_params: list of dict + :param update_existing: True to update properties of existing nodes, default False to keep existing values. + :type update_existing: bool + :rtype: tuple + """ + query_params = dict(merge_params=merge_params) + n_merge_labels = ":".join(cls.inherited_labels()) + n_merge_prm = ", ".join( + ( + f"{getattr(cls, p).db_property or p}: params.create.{getattr(cls, p).db_property or p}" + for p in cls.__required_properties__ + ) + ) + n_merge = f"n:{n_merge_labels} {{{n_merge_prm}}}" + if relationship is None: + # create "simple" unwind query + query = f"UNWIND $merge_params as params\n MERGE ({n_merge})\n " + else: + # validate relationship + if not isinstance(relationship.source, AsyncStructuredNode): + raise ValueError( + f"relationship source [{repr(relationship.source)}] is not a StructuredNode" + ) + relation_type = relationship.definition.get("relation_type") + if not relation_type: + raise ValueError( + "No relation_type is specified on provided relationship" + ) + + from neomodel.async_.match import _rel_helper + + query_params["source_id"] = await adb.parse_element_id( + relationship.source.element_id + ) + query = f"MATCH (source:{relationship.source.__label__}) WHERE {await adb.get_id_method()}(source) = $source_id\n " + query += "WITH source\n UNWIND $merge_params as params \n " + query += "MERGE " + query += _rel_helper( + lhs="source", + rhs=n_merge, + ident=None, + relation_type=relation_type, + direction=relationship.definition["direction"], + ) + + query += "ON CREATE SET n = params.create\n " + # if update_existing, write properties on match as well + if update_existing is True: + query += "ON MATCH SET n += params.update\n" + + # close query + if lazy: + query += f"RETURN {await adb.get_id_method()}(n)" + else: + query += "RETURN n" + + return query, query_params + + @classmethod + async def create(cls, *props, **kwargs): + """ + Call to CREATE with parameters map. A new instance will be created and saved. + + :param props: dict of properties to create the nodes. + :type props: tuple + :param lazy: False by default, specify True to get nodes with id only without the parameters. + :type: bool + :rtype: list + """ + + if "streaming" in kwargs: + warnings.warn( + STREAMING_WARNING, + category=DeprecationWarning, + stacklevel=1, + ) + + lazy = kwargs.get("lazy", False) + # create mapped query + query = f"CREATE (n:{':'.join(cls.inherited_labels())} $create_params)" + + # close query + if lazy: + query += f" RETURN {await adb.get_id_method()}(n)" + else: + query += " RETURN n" + + results = [] + for item in [ + cls.deflate(p, obj=_UnsavedNode(), skip_empty=True) for p in props + ]: + node, _ = await adb.cypher_query(query, {"create_params": item}) + results.extend(node[0]) + + nodes = [cls.inflate(node) for node in results] + + if not lazy and hasattr(cls, "post_create"): + for node in nodes: + node.post_create() + + return nodes + + @classmethod + async def create_or_update(cls, *props, **kwargs): + """ + Call to MERGE with parameters map. A new instance will be created and saved if does not already exists, + this is an atomic operation. If an instance already exists all optional properties specified will be updated. + + Note that the post_create hook isn't called after create_or_update + + :param props: List of dict arguments to get or create the entities with. + :type props: tuple + :param relationship: Optional, relationship to get/create on when new entity is created. + :param lazy: False by default, specify True to get nodes with id only without the parameters. + :rtype: list + """ + lazy = kwargs.get("lazy", False) + relationship = kwargs.get("relationship") + + # build merge query, make sure to update only explicitly specified properties + create_or_update_params = [] + for specified, deflated in [ + (p, cls.deflate(p, skip_empty=True)) for p in props + ]: + create_or_update_params.append( + { + "create": deflated, + "update": dict( + (k, v) for k, v in deflated.items() if k in specified + ), + } + ) + query, params = await cls._build_merge_query( + create_or_update_params, + update_existing=True, + relationship=relationship, + lazy=lazy, + ) + + if "streaming" in kwargs: + warnings.warn( + STREAMING_WARNING, + category=DeprecationWarning, + stacklevel=1, + ) + + # fetch and build instance for each result + results = await adb.cypher_query(query, params) + return [cls.inflate(r[0]) for r in results[0]] + + async def cypher(self, query, params=None): + """ + Execute a cypher query with the param 'self' pre-populated with the nodes neo4j id. + + :param query: cypher query string + :type: string + :param params: query parameters + :type: dict + :return: list containing query results + :rtype: list + """ + self._pre_action_check("cypher") + params = params or {} + element_id = await adb.parse_element_id(self.element_id) + params.update({"self": element_id}) + return await adb.cypher_query(query, params) + + @hooks + async def delete(self): + """ + Delete a node and its relationships + + :return: True + """ + self._pre_action_check("delete") + await self.cypher( + f"MATCH (self) WHERE {await adb.get_id_method()}(self)=$self DETACH DELETE self" + ) + delattr(self, "element_id_property") + self.deleted = True + return True + + @classmethod + async def get_or_create(cls, *props, **kwargs): + """ + Call to MERGE with parameters map. A new instance will be created and saved if does not already exist, + this is an atomic operation. + Parameters must contain all required properties, any non required properties with defaults will be generated. + + Note that the post_create hook isn't called after get_or_create + + :param props: Arguments to get_or_create as tuple of dict with property names and values to get or create + the entities with. + :type props: tuple + :param relationship: Optional, relationship to get/create on when new entity is created. + :param lazy: False by default, specify True to get nodes with id only without the parameters. + :rtype: list + """ + lazy = kwargs.get("lazy", False) + relationship = kwargs.get("relationship") + + # build merge query + get_or_create_params = [ + {"create": cls.deflate(p, skip_empty=True)} for p in props + ] + query, params = await cls._build_merge_query( + get_or_create_params, relationship=relationship, lazy=lazy + ) + + if "streaming" in kwargs: + warnings.warn( + STREAMING_WARNING, + category=DeprecationWarning, + stacklevel=1, + ) + + # fetch and build instance for each result + results = await adb.cypher_query(query, params) + return [cls.inflate(r[0]) for r in results[0]] + + @classmethod + def inflate(cls, node): + """ + Inflate a raw neo4j_driver node to a neomodel node + :param node: + :return: node object + """ + # support lazy loading + if isinstance(node, str) or isinstance(node, int): + snode = cls() + snode.element_id_property = node + else: + node_properties = _get_node_properties(node) + props = {} + for key, prop in cls.__all_properties__: + # map property name from database to object property + db_property = prop.db_property or key + + if db_property in node_properties: + props[key] = prop.inflate(node_properties[db_property], node) + elif prop.has_default: + props[key] = prop.default_value() + else: + props[key] = None + + snode = cls(**props) + snode.element_id_property = node.element_id + + return snode + + @classmethod + def inherited_labels(cls): + """ + Return list of labels from nodes class hierarchy. + + :return: list + """ + return [ + scls.__label__ + for scls in cls.mro() + if hasattr(scls, "__label__") and not hasattr(scls, "__abstract_node__") + ] + + @classmethod + def inherited_optional_labels(cls): + """ + Return list of optional labels from nodes class hierarchy. + + :return: list + :rtype: list + """ + return [ + label + for scls in cls.mro() + for label in getattr(scls, "__optional_labels__", []) + if not hasattr(scls, "__abstract_node__") + ] + + async def labels(self): + """ + Returns list of labels tied to the node from neo4j. + + :return: list of labels + :rtype: list + """ + self._pre_action_check("labels") + result = await self.cypher( + f"MATCH (n) WHERE {await adb.get_id_method()}(n)=$self " "RETURN labels(n)" + ) + return result[0][0][0] + + def _pre_action_check(self, action): + if hasattr(self, "deleted") and self.deleted: + raise ValueError( + f"{self.__class__.__name__}.{action}() attempted on deleted node" + ) + if not hasattr(self, "element_id"): + raise ValueError( + f"{self.__class__.__name__}.{action}() attempted on unsaved node" + ) + + async def refresh(self): + """ + Reload the node from neo4j + """ + self._pre_action_check("refresh") + if hasattr(self, "element_id"): + results = await self.cypher( + f"MATCH (n) WHERE {await adb.get_id_method()}(n)=$self RETURN n" + ) + request = results[0] + if not request or not request[0]: + raise self.__class__.DoesNotExist("Can't refresh non existent node") + node = self.inflate(request[0][0]) + for key, val in node.__properties__.items(): + setattr(self, key, val) + else: + raise ValueError("Can't refresh unsaved node") + + @hooks + async def save(self): + """ + Save the node to neo4j or raise an exception + + :return: the node instance + """ + + # create or update instance node + if hasattr(self, "element_id_property"): + # update + params = self.deflate(self.__properties__, self) + query = f"MATCH (n) WHERE {await adb.get_id_method()}(n)=$self\n" + + if params: + query += "SET " + query += ",\n".join([f"n.{key} = ${key}" for key in params]) + query += "\n" + if self.inherited_labels(): + query += "\n".join( + [f"SET n:`{label}`" for label in self.inherited_labels()] + ) + await self.cypher(query, params) + elif hasattr(self, "deleted") and self.deleted: + raise ValueError( + f"{self.__class__.__name__}.save() attempted on deleted node" + ) + else: # create + result = await self.create(self.__properties__) + created_node = result[0] + self.element_id_property = created_node.element_id + return self diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py new file mode 100644 index 00000000..e7e82adb --- /dev/null +++ b/neomodel/async_/match.py @@ -0,0 +1,1068 @@ +import inspect +import re +from collections import defaultdict +from dataclasses import dataclass +from typing import Optional + +from neomodel.async_.core import AsyncStructuredNode, adb +from neomodel.exceptions import MultipleNodesReturned +from neomodel.match_q import Q, QBase +from neomodel.properties import AliasProperty +from neomodel.util import INCOMING, OUTGOING + + +def _rel_helper( + lhs, + rhs, + ident=None, + relation_type=None, + direction=None, + relation_properties=None, + **kwargs, # NOSONAR +): + """ + Generate a relationship matching string, with specified parameters. + Examples: + relation_direction = OUTGOING: (lhs)-[relation_ident:relation_type]->(rhs) + relation_direction = INCOMING: (lhs)<-[relation_ident:relation_type]-(rhs) + relation_direction = EITHER: (lhs)-[relation_ident:relation_type]-(rhs) + + :param lhs: The left hand statement. + :type lhs: str + :param rhs: The right hand statement. + :type rhs: str + :param ident: A specific identity to name the relationship, or None. + :type ident: str + :param relation_type: None for all direct rels, * for all of any length, or a name of an explicit rel. + :type relation_type: str + :param direction: None or EITHER for all OUTGOING,INCOMING,EITHER. Otherwise OUTGOING or INCOMING. + :param relation_properties: dictionary of relationship properties to match + :returns: string + """ + rel_props = "" + + if relation_properties: + rel_props_str = ", ".join( + (f"{key}: {value}" for key, value in relation_properties.items()) + ) + rel_props = f" {{{rel_props_str}}}" + + rel_def = "" + # relation_type is unspecified + if relation_type is None: + rel_def = "" + # all("*" wildcard) relation_type + elif relation_type == "*": + rel_def = "[*]" + else: + # explicit relation_type + rel_def = f"[{ident if ident else ''}:`{relation_type}`{rel_props}]" + + stmt = "" + if direction == OUTGOING: + stmt = f"-{rel_def}->" + elif direction == INCOMING: + stmt = f"<-{rel_def}-" + else: + stmt = f"-{rel_def}-" + + # Make sure not to add parenthesis when they are already present + if lhs[-1] != ")": + lhs = f"({lhs})" + if rhs[-1] != ")": + rhs = f"({rhs})" + + return f"{lhs}{stmt}{rhs}" + + +def _rel_merge_helper( + lhs, + rhs, + ident="neomodelident", + relation_type=None, + direction=None, + relation_properties=None, + **kwargs, # NOSONAR +): + """ + Generate a relationship merging string, with specified parameters. + Examples: + relation_direction = OUTGOING: (lhs)-[relation_ident:relation_type]->(rhs) + relation_direction = INCOMING: (lhs)<-[relation_ident:relation_type]-(rhs) + relation_direction = EITHER: (lhs)-[relation_ident:relation_type]-(rhs) + + :param lhs: The left hand statement. + :type lhs: str + :param rhs: The right hand statement. + :type rhs: str + :param ident: A specific identity to name the relationship, or None. + :type ident: str + :param relation_type: None for all direct rels, * for all of any length, or a name of an explicit rel. + :type relation_type: str + :param direction: None or EITHER for all OUTGOING,INCOMING,EITHER. Otherwise OUTGOING or INCOMING. + :param relation_properties: dictionary of relationship properties to merge + :returns: string + """ + + if direction == OUTGOING: + stmt = "-{0}->" + elif direction == INCOMING: + stmt = "<-{0}-" + else: + stmt = "-{0}-" + + rel_props = "" + rel_none_props = "" + + if relation_properties: + rel_props_str = ", ".join( + ( + f"{key}: {value}" + for key, value in relation_properties.items() + if value is not None + ) + ) + rel_props = f" {{{rel_props_str}}}" + if None in relation_properties.values(): + rel_prop_val_str = ", ".join( + ( + f"{ident}.{key}=${key!s}" + for key, value in relation_properties.items() + if value is None + ) + ) + rel_none_props = ( + f" ON CREATE SET {rel_prop_val_str} ON MATCH SET {rel_prop_val_str}" + ) + # relation_type is unspecified + if relation_type is None: + stmt = stmt.format("") + # all("*" wildcard) relation_type + elif relation_type == "*": + stmt = stmt.format("[*]") + else: + # explicit relation_type + stmt = stmt.format(f"[{ident}:`{relation_type}`{rel_props}]") + + return f"({lhs}){stmt}({rhs}){rel_none_props}" + + +# special operators +_SPECIAL_OPERATOR_IN = "IN" +_SPECIAL_OPERATOR_INSENSITIVE = "(?i)" +_SPECIAL_OPERATOR_ISNULL = "IS NULL" +_SPECIAL_OPERATOR_ISNOTNULL = "IS NOT NULL" +_SPECIAL_OPERATOR_REGEX = "=~" + +_UNARY_OPERATORS = (_SPECIAL_OPERATOR_ISNULL, _SPECIAL_OPERATOR_ISNOTNULL) + +_REGEX_INSESITIVE = _SPECIAL_OPERATOR_INSENSITIVE + "{}" +_REGEX_CONTAINS = ".*{}.*" +_REGEX_STARTSWITH = "{}.*" +_REGEX_ENDSWITH = ".*{}" + +# regex operations that require escaping +_STRING_REGEX_OPERATOR_TABLE = { + "iexact": _REGEX_INSESITIVE, + "contains": _REGEX_CONTAINS, + "icontains": _SPECIAL_OPERATOR_INSENSITIVE + _REGEX_CONTAINS, + "startswith": _REGEX_STARTSWITH, + "istartswith": _SPECIAL_OPERATOR_INSENSITIVE + _REGEX_STARTSWITH, + "endswith": _REGEX_ENDSWITH, + "iendswith": _SPECIAL_OPERATOR_INSENSITIVE + _REGEX_ENDSWITH, +} +# regex operations that do not require escaping +_REGEX_OPERATOR_TABLE = { + "iregex": _REGEX_INSESITIVE, +} +# list all regex operations, these will require formatting of the value +_REGEX_OPERATOR_TABLE.update(_STRING_REGEX_OPERATOR_TABLE) + +# list all supported operators +OPERATOR_TABLE = { + "lt": "<", + "gt": ">", + "lte": "<=", + "gte": ">=", + "ne": "<>", + "in": _SPECIAL_OPERATOR_IN, + "isnull": _SPECIAL_OPERATOR_ISNULL, + "regex": _SPECIAL_OPERATOR_REGEX, + "exact": "=", +} +# add all regex operators +OPERATOR_TABLE.update(_REGEX_OPERATOR_TABLE) + + +def install_traversals(cls, node_set): + """ + For a StructuredNode class install Traversal objects for each + relationship definition on a NodeSet instance + """ + rels = cls.defined_properties(rels=True, aliases=False, properties=False) + + for key in rels.keys(): + if hasattr(node_set, key): + raise ValueError(f"Cannot install traversal '{key}' exists on NodeSet") + + rel = getattr(cls, key) + rel.lookup_node_class() + + traversal = AsyncTraversal(source=node_set, name=key, definition=rel.definition) + setattr(node_set, key, traversal) + + +def process_filter_args(cls, kwargs): + """ + loop through properties in filter parameters check they match class definition + deflate them and convert into something easy to generate cypher from + """ + + output = {} + + for key, value in kwargs.items(): + if "__" in key: + prop, operator = key.rsplit("__") + operator = OPERATOR_TABLE[operator] + else: + prop = key + operator = "=" + + if prop not in cls.defined_properties(rels=False): + raise ValueError( + f"No such property {prop} on {cls.__name__}. Note that Neo4j internals like id or element_id are not allowed for use in this operation." + ) + + property_obj = getattr(cls, prop) + if isinstance(property_obj, AliasProperty): + prop = property_obj.aliased_to() + deflated_value = getattr(cls, prop).deflate(value) + else: + operator, deflated_value = transform_operator_to_filter( + operator=operator, + filter_key=key, + filter_value=value, + property_obj=property_obj, + ) + + # map property to correct property name in the database + db_property = cls.defined_properties(rels=False)[prop].db_property or prop + + output[db_property] = (operator, deflated_value) + + return output + + +def transform_operator_to_filter(operator, filter_key, filter_value, property_obj): + # handle special operators + if operator == _SPECIAL_OPERATOR_IN: + if not isinstance(filter_value, tuple) and not isinstance(filter_value, list): + raise ValueError( + f"Value must be a tuple or list for IN operation {filter_key}={filter_value}" + ) + deflated_value = [property_obj.deflate(v) for v in filter_value] + elif operator == _SPECIAL_OPERATOR_ISNULL: + if not isinstance(filter_value, bool): + raise ValueError( + f"Value must be a bool for isnull operation on {filter_key}" + ) + operator = "IS NULL" if filter_value else "IS NOT NULL" + deflated_value = None + elif operator in _REGEX_OPERATOR_TABLE.values(): + deflated_value = property_obj.deflate(filter_value) + if not isinstance(deflated_value, str): + raise ValueError(f"Must be a string value for {filter_key}") + if operator in _STRING_REGEX_OPERATOR_TABLE.values(): + deflated_value = re.escape(deflated_value) + deflated_value = operator.format(deflated_value) + operator = _SPECIAL_OPERATOR_REGEX + else: + deflated_value = property_obj.deflate(filter_value) + + return operator, deflated_value + + +def process_has_args(cls, kwargs): + """ + loop through has parameters check they correspond to class rels defined + """ + rel_definitions = cls.defined_properties(properties=False, rels=True, aliases=False) + + match, dont_match = {}, {} + + for key, value in kwargs.items(): + if key not in rel_definitions: + raise ValueError(f"No such relation {key} defined on a {cls.__name__}") + + rhs_ident = key + + rel_definitions[key].lookup_node_class() + + if value is True: + match[rhs_ident] = rel_definitions[key].definition + elif value is False: + dont_match[rhs_ident] = rel_definitions[key].definition + elif isinstance(value, AsyncNodeSet): + raise NotImplementedError("Not implemented yet") + else: + raise ValueError("Expecting True / False / NodeSet got: " + repr(value)) + + return match, dont_match + + +class QueryAST: + match: Optional[list] + optional_match: Optional[list] + where: Optional[list] + with_clause: Optional[str] + return_clause: Optional[str] + order_by: Optional[str] + skip: Optional[int] + limit: Optional[int] + result_class: Optional[type] + lookup: Optional[str] + additional_return: Optional[list] + is_count: Optional[bool] + + def __init__( + self, + match: Optional[list] = None, + optional_match: Optional[list] = None, + where: Optional[list] = None, + with_clause: Optional[str] = None, + return_clause: Optional[str] = None, + order_by: Optional[str] = None, + skip: Optional[int] = None, + limit: Optional[int] = None, + result_class: Optional[type] = None, + lookup: Optional[str] = None, + additional_return: Optional[list] = None, + is_count: Optional[bool] = False, + ): + self.match = match if match else [] + self.optional_match = optional_match if optional_match else [] + self.where = where if where else [] + self.with_clause = with_clause + self.return_clause = return_clause + self.order_by = order_by + self.skip = skip + self.limit = limit + self.result_class = result_class + self.lookup = lookup + self.additional_return = additional_return if additional_return else [] + self.is_count = is_count + + +class AsyncQueryBuilder: + def __init__(self, node_set): + self.node_set = node_set + self._ast = QueryAST() + self._query_params = {} + self._place_holder_registry = {} + self._ident_count = 0 + self._node_counters = defaultdict(int) + + async def build_ast(self): + if hasattr(self.node_set, "relations_to_fetch"): + for relation in self.node_set.relations_to_fetch: + self.build_traversal_from_path(relation, self.node_set.source) + + await self.build_source(self.node_set) + + if hasattr(self.node_set, "skip"): + self._ast.skip = self.node_set.skip + if hasattr(self.node_set, "limit"): + self._ast.limit = self.node_set.limit + + return self + + async def build_source(self, source): + if isinstance(source, AsyncTraversal): + return await self.build_traversal(source) + if isinstance(source, AsyncNodeSet): + if inspect.isclass(source.source) and issubclass( + source.source, AsyncStructuredNode + ): + ident = self.build_label(source.source.__label__.lower(), source.source) + else: + ident = await self.build_source(source.source) + + self.build_additional_match(ident, source) + + if hasattr(source, "order_by_elements"): + self.build_order_by(ident, source) + + if source.filters or source.q_filters: + self.build_where_stmt( + ident, + source.filters, + source.q_filters, + source_class=source.source_class, + ) + + return ident + if isinstance(source, AsyncStructuredNode): + return await self.build_node(source) + raise ValueError("Unknown source type " + repr(source)) + + def create_ident(self): + self._ident_count += 1 + return "r" + str(self._ident_count) + + def build_order_by(self, ident, source): + if "?" in source.order_by_elements: + self._ast.with_clause = f"{ident}, rand() as r" + self._ast.order_by = "r" + else: + self._ast.order_by = [f"{ident}.{p}" for p in source.order_by_elements] + + async def build_traversal(self, traversal): + """ + traverse a relationship from a node to a set of nodes + """ + # build source + rhs_label = ":" + traversal.target_class.__label__ + + # build source + rel_ident = self.create_ident() + lhs_ident = await self.build_source(traversal.source) + traversal_ident = f"{traversal.name}_{rel_ident}" + rhs_ident = traversal_ident + rhs_label + self._ast.return_clause = traversal_ident + self._ast.result_class = traversal.target_class + + stmt = _rel_helper( + lhs=lhs_ident, + rhs=rhs_ident, + ident=rel_ident, + **traversal.definition, + ) + self._ast.match.append(stmt) + + if traversal.filters: + self.build_where_stmt(rel_ident, traversal.filters) + + return traversal_ident + + def _additional_return(self, name): + if name not in self._ast.additional_return and name != self._ast.return_clause: + self._ast.additional_return.append(name) + + def build_traversal_from_path(self, relation: dict, source_class) -> str: + path: str = relation["path"] + stmt: str = "" + source_class_iterator = source_class + for index, part in enumerate(path.split("__")): + relationship = getattr(source_class_iterator, part) + # build source + if "node_class" not in relationship.definition: + relationship.lookup_node_class() + rhs_label = relationship.definition["node_class"].__label__ + rel_reference = f'{relationship.definition["node_class"]}_{part}' + self._node_counters[rel_reference] += 1 + rhs_name = ( + f"{rhs_label.lower()}_{part}_{self._node_counters[rel_reference]}" + ) + rhs_ident = f"{rhs_name}:{rhs_label}" + self._additional_return(rhs_name) + if not stmt: + lhs_label = source_class_iterator.__label__ + lhs_name = lhs_label.lower() + lhs_ident = f"{lhs_name}:{lhs_label}" + if not index: + # This is the first one, we make sure that 'return' + # contains the primary node so _contains() works + # as usual + self._ast.return_clause = lhs_name + else: + self._additional_return(lhs_name) + else: + lhs_ident = stmt + + rel_ident = self.create_ident() + self._additional_return(rel_ident) + stmt = _rel_helper( + lhs=lhs_ident, + rhs=rhs_ident, + ident=rel_ident, + direction=relationship.definition["direction"], + relation_type=relationship.definition["relation_type"], + ) + source_class_iterator = relationship.definition["node_class"] + + if relation.get("optional"): + self._ast.optional_match.append(stmt) + else: + self._ast.match.append(stmt) + return rhs_name + + async def build_node(self, node): + ident = node.__class__.__name__.lower() + place_holder = self._register_place_holder(ident) + + # Hack to emulate START to lookup a node by id + _node_lookup = f"MATCH ({ident}) WHERE {await adb.get_id_method()}({ident})=${place_holder} WITH {ident}" + self._ast.lookup = _node_lookup + + self._query_params[place_holder] = await adb.parse_element_id(node.element_id) + + self._ast.return_clause = ident + self._ast.result_class = node.__class__ + return ident + + def build_label(self, ident, cls): + """ + match nodes by a label + """ + ident_w_label = ident + ":" + cls.__label__ + + if not self._ast.return_clause and ( + not self._ast.additional_return or ident not in self._ast.additional_return + ): + self._ast.match.append(f"({ident_w_label})") + self._ast.return_clause = ident + self._ast.result_class = cls + return ident + + def build_additional_match(self, ident, node_set): + """ + handle additional matches supplied by 'has()' calls + """ + source_ident = ident + + for _, value in node_set.must_match.items(): + if isinstance(value, dict): + label = ":" + value["node_class"].__label__ + stmt = _rel_helper(lhs=source_ident, rhs=label, ident="", **value) + self._ast.where.append(stmt) + else: + raise ValueError("Expecting dict got: " + repr(value)) + + for _, val in node_set.dont_match.items(): + if isinstance(val, dict): + label = ":" + val["node_class"].__label__ + stmt = _rel_helper(lhs=source_ident, rhs=label, ident="", **val) + self._ast.where.append("NOT " + stmt) + else: + raise ValueError("Expecting dict got: " + repr(val)) + + def _register_place_holder(self, key): + if key in self._place_holder_registry: + self._place_holder_registry[key] += 1 + else: + self._place_holder_registry[key] = 1 + return key + "_" + str(self._place_holder_registry[key]) + + def _parse_q_filters(self, ident, q, source_class): + target = [] + for child in q.children: + if isinstance(child, QBase): + q_childs = self._parse_q_filters(ident, child, source_class) + if child.connector == Q.OR: + q_childs = "(" + q_childs + ")" + target.append(q_childs) + else: + kwargs = {child[0]: child[1]} + filters = process_filter_args(source_class, kwargs) + for prop, op_and_val in filters.items(): + operator, val = op_and_val + if operator in _UNARY_OPERATORS: + # unary operators do not have a parameter + statement = f"{ident}.{prop} {operator}" + else: + place_holder = self._register_place_holder(ident + "_" + prop) + statement = f"{ident}.{prop} {operator} ${place_holder}" + self._query_params[place_holder] = val + target.append(statement) + ret = f" {q.connector} ".join(target) + if q.negated: + ret = f"NOT ({ret})" + return ret + + def build_where_stmt(self, ident, filters, q_filters=None, source_class=None): + """ + construct a where statement from some filters + """ + if q_filters is not None: + stmts = self._parse_q_filters(ident, q_filters, source_class) + if stmts: + self._ast.where.append(stmts) + else: + stmts = [] + for row in filters: + negate = False + + # pre-process NOT cases as they are nested dicts + if "__NOT__" in row and len(row) == 1: + negate = True + row = row["__NOT__"] + + for prop, operator_and_val in row.items(): + operator, val = operator_and_val + if operator in _UNARY_OPERATORS: + # unary operators do not have a parameter + statement = ( + f"{'NOT' if negate else ''} {ident}.{prop} {operator}" + ) + else: + place_holder = self._register_place_holder(ident + "_" + prop) + statement = f"{'NOT' if negate else ''} {ident}.{prop} {operator} ${place_holder}" + self._query_params[place_holder] = val + stmts.append(statement) + + self._ast.where.append(" AND ".join(stmts)) + + def build_query(self): + query = "" + + if self._ast.lookup: + query += self._ast.lookup + + # Instead of using only one MATCH statement for every relation + # to follow, we use one MATCH per relation (to avoid cartesian + # product issues...). + # There might be optimizations to be done, using projections, + # or pusing patterns instead of a chain of OPTIONAL MATCH. + if self._ast.match: + query += " MATCH " + query += " MATCH ".join(i for i in self._ast.match) + + if self._ast.optional_match: + query += " OPTIONAL MATCH " + query += " OPTIONAL MATCH ".join(i for i in self._ast.optional_match) + + if self._ast.where: + query += " WHERE " + query += " AND ".join(self._ast.where) + + if self._ast.with_clause: + query += " WITH " + query += self._ast.with_clause + + query += " RETURN " + if self._ast.return_clause: + query += self._ast.return_clause + if self._ast.additional_return: + if self._ast.return_clause: + query += ", " + query += ", ".join(self._ast.additional_return) + + if self._ast.order_by: + query += " ORDER BY " + query += ", ".join(self._ast.order_by) + + # If we return a count with pagination, pagination has to happen before RETURN + # It will then be included in the WITH clause already + if self._ast.skip and not self._ast.is_count: + query += f" SKIP {self._ast.skip}" + + if self._ast.limit and not self._ast.is_count: + query += f" LIMIT {self._ast.limit}" + + return query + + async def _count(self): + self._ast.is_count = True + # If we return a count with pagination, pagination has to happen before RETURN + # Like : WITH my_var SKIP 10 LIMIT 10 RETURN count(my_var) + self._ast.with_clause = f"{self._ast.return_clause}" + if self._ast.skip: + self._ast.with_clause += f" SKIP {self._ast.skip}" + + if self._ast.limit: + self._ast.with_clause += f" LIMIT {self._ast.limit}" + + self._ast.return_clause = f"count({self._ast.return_clause})" + # drop order_by, results in an invalid query + self._ast.order_by = None + # drop additional_return to avoid unexpected result + self._ast.additional_return = None + query = self.build_query() + results, _ = await adb.cypher_query(query, self._query_params) + return int(results[0][0]) + + async def _contains(self, node_element_id): + # inject id = into ast + if not self._ast.return_clause: + print(self._ast.additional_return) + self._ast.return_clause = self._ast.additional_return[0] + ident = self._ast.return_clause + place_holder = self._register_place_holder(ident + "_contains") + self._ast.where.append( + f"{await adb.get_id_method()}({ident}) = ${place_holder}" + ) + self._query_params[place_holder] = node_element_id + return await self._count() >= 1 + + async def _execute(self, lazy=False): + if lazy: + # inject id() into return or return_set + if self._ast.return_clause: + self._ast.return_clause = ( + f"{await adb.get_id_method()}({self._ast.return_clause})" + ) + else: + self._ast.additional_return = [ + f"{await adb.get_id_method()}({item})" + for item in self._ast.additional_return + ] + query = self.build_query() + results, _ = await adb.cypher_query( + query, self._query_params, resolve_objects=True + ) + # The following is not as elegant as it could be but had to be copied from the + # version prior to cypher_query with the resolve_objects capability. + # It seems that certain calls are only supposed to be focusing to the first + # result item returned (?) + if results and len(results[0]) == 1: + return [n[0] for n in results] + return results + + +class AsyncBaseSet: + """ + Base class for all node sets. + + Contains common python magic methods, __len__, __contains__ etc + """ + + query_cls = AsyncQueryBuilder + + async def all(self, lazy=False): + """ + Return all nodes belonging to the set + :param lazy: False by default, specify True to get nodes with id only without the parameters. + :return: list of nodes + :rtype: list + """ + ast = await self.query_cls(self).build_ast() + return await ast._execute(lazy) + + async def __aiter__(self): + ast = await self.query_cls(self).build_ast() + async for i in await ast._execute(): + yield i + + async def get_len(self): + ast = await self.query_cls(self).build_ast() + return await ast._count() + + async def check_bool(self): + """ + Override for __bool__ dunder method. + :return: True if the set contains any nodes, False otherwise + :rtype: bool + """ + ast = await self.query_cls(self).build_ast() + _count = await ast._count() + return _count > 0 + + async def check_nonzero(self): + """ + Override for __bool__ dunder method. + :return: True if the set contains any node, False otherwise + :rtype: bool + """ + return await self.check_bool() + + async def check_contains(self, obj): + if isinstance(obj, AsyncStructuredNode): + if hasattr(obj, "element_id") and obj.element_id is not None: + ast = await self.query_cls(self).build_ast() + obj_element_id = await adb.parse_element_id(obj.element_id) + return await ast._contains(obj_element_id) + raise ValueError("Unsaved node: " + repr(obj)) + + raise ValueError("Expecting StructuredNode instance") + + async def get_item(self, key): + if isinstance(key, slice): + if key.stop and key.start: + self.limit = key.stop - key.start + self.skip = key.start + elif key.stop: + self.limit = key.stop + elif key.start: + self.skip = key.start + + return self + + if isinstance(key, int): + self.skip = key + self.limit = 1 + + ast = await self.query_cls(self).build_ast() + _items = ast._execute() + return _items[0] + + return None + + +@dataclass +class Optional: + """Simple relation qualifier.""" + + relation: str + + +class AsyncNodeSet(AsyncBaseSet): + """ + A class representing as set of nodes matching common query parameters + """ + + def __init__(self, source): + self.source = source # could be a Traverse object or a node class + if isinstance(source, AsyncTraversal): + self.source_class = source.target_class + elif inspect.isclass(source) and issubclass(source, AsyncStructuredNode): + self.source_class = source + elif isinstance(source, AsyncStructuredNode): + self.source_class = source.__class__ + else: + raise ValueError("Bad source for nodeset " + repr(source)) + + # setup Traversal objects using relationship definitions + install_traversals(self.source_class, self) + + self.filters = [] + self.q_filters = Q() + + # used by has() + self.must_match = {} + self.dont_match = {} + + self.relations_to_fetch: list = [] + + def __await__(self): + return self.all().__await__() + + async def _get(self, limit=None, lazy=False, **kwargs): + self.filter(**kwargs) + if limit: + self.limit = limit + ast = await self.query_cls(self).build_ast() + return await ast._execute(lazy) + + async def get(self, lazy=False, **kwargs): + """ + Retrieve one node from the set matching supplied parameters + :param lazy: False by default, specify True to get nodes with id only without the parameters. + :param kwargs: same syntax as `filter()` + :return: node + """ + result = await self._get(limit=2, lazy=lazy, **kwargs) + if len(result) > 1: + raise MultipleNodesReturned(repr(kwargs)) + if not result: + raise self.source_class.DoesNotExist(repr(kwargs)) + return result[0] + + async def get_or_none(self, **kwargs): + """ + Retrieve a node from the set matching supplied parameters or return none + + :param kwargs: same syntax as `filter()` + :return: node or none + """ + try: + return await self.get(**kwargs) + except self.source_class.DoesNotExist: + return None + + async def first(self, **kwargs): + """ + Retrieve the first node from the set matching supplied parameters + + :param kwargs: same syntax as `filter()` + :return: node + """ + result = await self._get(limit=1, **kwargs) + if result: + return result[0] + else: + raise self.source_class.DoesNotExist(repr(kwargs)) + + async def first_or_none(self, **kwargs): + """ + Retrieve the first node from the set matching supplied parameters or return none + + :param kwargs: same syntax as `filter()` + :return: node or none + """ + try: + return await self.first(**kwargs) + except self.source_class.DoesNotExist: + pass + return None + + def filter(self, *args, **kwargs): + """ + Apply filters to the existing nodes in the set. + + :param args: a Q object + + e.g `.filter(Q(salary__lt=10000) | Q(salary__gt=20000))`. + + :param kwargs: filter parameters + + Filters mimic Django's syntax with the double '__' to separate field and operators. + + e.g `.filter(salary__gt=20000)` results in `salary > 20000`. + + The following operators are available: + + * 'lt': less than + * 'gt': greater than + * 'lte': less than or equal to + * 'gte': greater than or equal to + * 'ne': not equal to + * 'in': matches one of list (or tuple) + * 'isnull': is null + * 'regex': matches supplied regex (neo4j regex format) + * 'exact': exactly match string (just '=') + * 'iexact': case insensitive match string + * 'contains': contains string + * 'icontains': case insensitive contains + * 'startswith': string starts with + * 'istartswith': case insensitive string starts with + * 'endswith': string ends with + * 'iendswith': case insensitive string ends with + + :return: self + """ + if args or kwargs: + self.q_filters = Q(self.q_filters & Q(*args, **kwargs)) + return self + + def exclude(self, *args, **kwargs): + """ + Exclude nodes from the NodeSet via filters. + + :param kwargs: filter parameters see syntax for the filter method + :return: self + """ + if args or kwargs: + self.q_filters = Q(self.q_filters & ~Q(*args, **kwargs)) + return self + + def has(self, **kwargs): + must_match, dont_match = process_has_args(self.source_class, kwargs) + self.must_match.update(must_match) + self.dont_match.update(dont_match) + return self + + def order_by(self, *props): + """ + Order by properties. Prepend with minus to do descending. Pass None to + remove ordering. + """ + should_remove = len(props) == 1 and props[0] is None + if not hasattr(self, "order_by_elements") or should_remove: + self.order_by_elements = [] + if should_remove: + return self + if "?" in props: + self.order_by_elements.append("?") + else: + for prop in props: + prop = prop.strip() + if prop.startswith("-"): + prop = prop[1:] + desc = True + else: + desc = False + + if prop not in self.source_class.defined_properties(rels=False): + raise ValueError( + f"No such property {prop} on {self.source_class.__name__}. Note that Neo4j internals like id or element_id are not allowed for use in this operation." + ) + + property_obj = getattr(self.source_class, prop) + if isinstance(property_obj, AliasProperty): + prop = property_obj.aliased_to() + + self.order_by_elements.append(prop + (" DESC" if desc else "")) + + return self + + def fetch_relations(self, *relation_names): + """Specify a set of relations to return.""" + relations = [] + for relation_name in relation_names: + if isinstance(relation_name, Optional): + item = {"path": relation_name.relation, "optional": True} + else: + item = {"path": relation_name} + relations.append(item) + self.relations_to_fetch = relations + return self + + +class AsyncTraversal(AsyncBaseSet): + """ + Models a traversal from a node to another. + + :param source: Starting of the traversal. + :type source: A :class:`~neomodel.core.StructuredNode` subclass, an + instance of such, a :class:`~neomodel.match.NodeSet` instance + or a :class:`~neomodel.match.Traversal` instance. + :param name: A name for the traversal. + :type name: :class:`str` + :param definition: A relationship definition that most certainly deserves + a documentation here. + :type defintion: :class:`dict` + """ + + def __await__(self): + return self.all().__await__() + + def __init__(self, source, name, definition): + """ + Create a traversal + + """ + self.source = source + + if isinstance(source, AsyncTraversal): + self.source_class = source.target_class + elif inspect.isclass(source) and issubclass(source, AsyncStructuredNode): + self.source_class = source + elif isinstance(source, AsyncStructuredNode): + self.source_class = source.__class__ + elif isinstance(source, AsyncNodeSet): + self.source_class = source.source_class + else: + raise TypeError(f"Bad source for traversal: {type(source)}") + + invalid_keys = set(definition) - { + "direction", + "model", + "node_class", + "relation_type", + } + if invalid_keys: + raise ValueError(f"Prohibited keys in Traversal definition: {invalid_keys}") + + self.definition = definition + self.target_class = definition["node_class"] + self.name = name + self.filters = [] + + def match(self, **kwargs): + """ + Traverse relationships with properties matching the given parameters. + + e.g: `.match(price__lt=10)` + + :param kwargs: see `NodeSet.filter()` for syntax + :return: self + """ + if kwargs: + if self.definition.get("model") is None: + raise ValueError( + "match() with filter only available on relationships with a model" + ) + output = process_filter_args(self.definition["model"], kwargs) + if output: + self.filters.append(output) + return self diff --git a/neomodel/async_/path.py b/neomodel/async_/path.py new file mode 100644 index 00000000..6128347e --- /dev/null +++ b/neomodel/async_/path.py @@ -0,0 +1,53 @@ +from neo4j.graph import Path + +from neomodel.async_.core import adb +from neomodel.async_.relationship import AsyncStructuredRel + + +class AsyncNeomodelPath(Path): + """ + Represents paths within neomodel. + + This object is instantiated when you include whole paths in your ``cypher_query()`` + result sets and turn ``resolve_objects`` to True. + + That is, any query of the form: + :: + + MATCH p=(:SOME_NODE_LABELS)-[:SOME_REL_LABELS]-(:SOME_OTHER_NODE_LABELS) return p + + ``NeomodelPath`` are simple objects that reference their nodes and relationships, each of which is already + resolved to their neomodel objects if such mapping is possible. + + + :param nodes: Neomodel nodes appearing in the path in order of appearance. + :param relationships: Neomodel relationships appearing in the path in order of appearance. + :type nodes: List[StructuredNode] + :type relationships: List[StructuredRel] + """ + + def __init__(self, a_neopath): + self._nodes = [] + self._relationships = [] + + for a_node in a_neopath.nodes: + self._nodes.append(adb._object_resolution(a_node)) + + for a_relationship in a_neopath.relationships: + # This check is required here because if the relationship does not bear data + # then it does not have an entry in the registry. In that case, we instantiate + # an "unspecified" StructuredRel. + rel_type = frozenset([a_relationship.type]) + if rel_type in adb._NODE_CLASS_REGISTRY: + new_rel = adb._object_resolution(a_relationship) + else: + new_rel = AsyncStructuredRel.inflate(a_relationship) + self._relationships.append(new_rel) + + @property + def nodes(self): + return self._nodes + + @property + def relationships(self): + return self._relationships diff --git a/neomodel/async_/property_manager.py b/neomodel/async_/property_manager.py new file mode 100644 index 00000000..b9401dab --- /dev/null +++ b/neomodel/async_/property_manager.py @@ -0,0 +1,109 @@ +import types + +from neomodel.exceptions import RequiredProperty +from neomodel.properties import AliasProperty, Property + + +def display_for(key): + def display_choice(self): + return getattr(self.__class__, key).choices[getattr(self, key)] + + return display_choice + + +class AsyncPropertyManager: + """ + Common methods for handling properties on node and relationship objects. + """ + + def __init__(self, **kwargs): + properties = getattr(self, "__all_properties__", None) + if properties is None: + properties = self.defined_properties(rels=False, aliases=False).items() + for name, property in properties: + if kwargs.get(name) is None: + if getattr(property, "has_default", False): + setattr(self, name, property.default_value()) + else: + setattr(self, name, None) + else: + setattr(self, name, kwargs[name]) + + if getattr(property, "choices", None): + setattr( + self, + f"get_{name}_display", + types.MethodType(display_for(name), self), + ) + + if name in kwargs: + del kwargs[name] + + aliases = getattr(self, "__all_aliases__", None) + if aliases is None: + aliases = self.defined_properties( + aliases=True, rels=False, properties=False + ).items() + for name, property in aliases: + if name in kwargs: + setattr(self, name, kwargs[name]) + del kwargs[name] + + # undefined properties (for magic @prop.setters etc) + for name, property in kwargs.items(): + setattr(self, name, property) + + @property + def __properties__(self): + from neomodel.async_.relationship_manager import AsyncRelationshipManager + + return dict( + (name, value) + for name, value in vars(self).items() + if not name.startswith("_") + and not callable(value) + and not isinstance( + value, + ( + AsyncRelationshipManager, + AliasProperty, + ), + ) + ) + + @classmethod + def deflate(cls, properties, obj=None, skip_empty=False): + # deflate dict ready to be stored + deflated = {} + for name, property in cls.defined_properties(aliases=False, rels=False).items(): + db_property = property.db_property or name + if properties.get(name) is not None: + deflated[db_property] = property.deflate(properties[name], obj) + elif property.has_default: + deflated[db_property] = property.deflate(property.default_value(), obj) + elif property.required: + raise RequiredProperty(name, cls) + elif not skip_empty: + deflated[db_property] = None + return deflated + + @classmethod + def defined_properties(cls, aliases=True, properties=True, rels=True): + from neomodel.async_.relationship_manager import AsyncRelationshipDefinition + + props = {} + for baseclass in reversed(cls.__mro__): + props.update( + dict( + (name, property) + for name, property in vars(baseclass).items() + if (aliases and isinstance(property, AliasProperty)) + or ( + properties + and isinstance(property, Property) + and not isinstance(property, AliasProperty) + ) + or (rels and isinstance(property, AsyncRelationshipDefinition)) + ) + ) + return props diff --git a/neomodel/async_/relationship.py b/neomodel/async_/relationship.py new file mode 100644 index 00000000..5653137a --- /dev/null +++ b/neomodel/async_/relationship.py @@ -0,0 +1,168 @@ +from neomodel.async_.core import adb +from neomodel.async_.property_manager import AsyncPropertyManager +from neomodel.hooks import hooks +from neomodel.properties import Property + +ELEMENT_ID_MIGRATION_NOTICE = "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()." + + +class RelationshipMeta(type): + def __new__(mcs, name, bases, dct): + inst = super().__new__(mcs, name, bases, dct) + for key, value in dct.items(): + if issubclass(value.__class__, Property): + if key == "source" or key == "target": + raise ValueError( + "Property names 'source' and 'target' are not allowed as they conflict with neomodel internals." + ) + elif key == "id": + raise ValueError( + """ + Property name 'id' is not allowed as it conflicts with neomodel internals. + Consider using 'uid' or 'identifier' as id is also a Neo4j internal. + """ + ) + elif key == "element_id": + raise ValueError( + """ + Property name 'element_id' is not allowed as it conflicts with neomodel internals. + Consider using 'uid' or 'identifier' as element_id is also a Neo4j internal. + """ + ) + value.name = key + value.owner = inst + + # support for 'magic' properties + if hasattr(value, "setup") and hasattr(value.setup, "__call__"): + value.setup() + return inst + + +StructuredRelBase = RelationshipMeta("RelationshipBase", (AsyncPropertyManager,), {}) + + +class AsyncStructuredRel(StructuredRelBase): + """ + Base class for relationship objects + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @property + def element_id(self): + if hasattr(self, "element_id_property"): + return self.element_id_property + + @property + def _start_node_element_id(self): + if hasattr(self, "_start_node_element_id_property"): + return self._start_node_element_id_property + + @property + def _end_node_element_id(self): + if hasattr(self, "_end_node_element_id_property"): + return self._end_node_element_id_property + + # Version 4.4 support - id is deprecated in version 5.x + @property + def id(self): + try: + return int(self.element_id_property) + except (TypeError, ValueError) as exc: + raise ValueError(ELEMENT_ID_MIGRATION_NOTICE) from exc + + # Version 4.4 support - id is deprecated in version 5.x + @property + def _start_node_id(self): + try: + return int(self._start_node_element_id_property) + except (TypeError, ValueError) as exc: + raise ValueError(ELEMENT_ID_MIGRATION_NOTICE) from exc + + # Version 4.4 support - id is deprecated in version 5.x + @property + def _end_node_id(self): + try: + return int(self._end_node_element_id_property) + except (TypeError, ValueError) as exc: + raise ValueError(ELEMENT_ID_MIGRATION_NOTICE) from exc + + @hooks + async def save(self): + """ + Save the relationship + + :return: self + """ + props = self.deflate(self.__properties__) + query = f"MATCH ()-[r]->() WHERE {await adb.get_id_method()}(r)=$self " + query += "".join([f" SET r.{key} = ${key}" for key in props]) + props["self"] = await adb.parse_element_id(self.element_id) + + await adb.cypher_query(query, props) + + return self + + async def start_node(self): + """ + Get start node + + :return: StructuredNode + """ + results = await adb.cypher_query( + f""" + MATCH (aNode) + WHERE {await adb.get_id_method()}(aNode)=$start_node_element_id + RETURN aNode + """, + { + "start_node_element_id": await adb.parse_element_id( + self._start_node_element_id + ) + }, + resolve_objects=True, + ) + return results[0][0][0] + + async def end_node(self): + """ + Get end node + + :return: StructuredNode + """ + results = await adb.cypher_query( + f""" + MATCH (aNode) + WHERE {await adb.get_id_method()}(aNode)=$end_node_element_id + RETURN aNode + """, + { + "end_node_element_id": await adb.parse_element_id( + self._end_node_element_id + ) + }, + resolve_objects=True, + ) + return results[0][0][0] + + @classmethod + def inflate(cls, rel): + """ + Inflate a neo4j_driver relationship object to a neomodel object + :param rel: + :return: StructuredRel + """ + props = {} + for key, prop in cls.defined_properties(aliases=False, rels=False).items(): + if key in rel: + props[key] = prop.inflate(rel[key], obj=rel) + elif prop.has_default: + props[key] = prop.default_value() + else: + props[key] = None + srel = cls(**props) + srel._start_node_element_id_property = rel.start_node.element_id + srel._end_node_element_id_property = rel.end_node.element_id + srel.element_id_property = rel.element_id + return srel diff --git a/neomodel/async_/relationship_manager.py b/neomodel/async_/relationship_manager.py new file mode 100644 index 00000000..9c5e7398 --- /dev/null +++ b/neomodel/async_/relationship_manager.py @@ -0,0 +1,561 @@ +import functools +import inspect +import sys +from importlib import import_module + +from neomodel.async_.core import adb +from neomodel.async_.match import ( + AsyncNodeSet, + AsyncTraversal, + _rel_helper, + _rel_merge_helper, +) +from neomodel.async_.relationship import AsyncStructuredRel +from neomodel.exceptions import NotConnected, RelationshipClassRedefined +from neomodel.util import ( + EITHER, + INCOMING, + OUTGOING, + _get_node_properties, + enumerate_traceback, +) + +# basestring python 3.x fallback +try: + basestring +except NameError: + basestring = str + + +# check source node is saved and not deleted +def check_source(fn): + fn_name = fn.func_name if hasattr(fn, "func_name") else fn.__name__ + + @functools.wraps(fn) + def checker(self, *args, **kwargs): + self.source._pre_action_check(self.name + "." + fn_name) + return fn(self, *args, **kwargs) + + return checker + + +# checks if obj is a direct subclass, 1 level +def is_direct_subclass(obj, classinfo): + for base in obj.__bases__: + if base == classinfo: + return True + return False + + +class AsyncRelationshipManager(object): + """ + Base class for all relationships managed through neomodel. + + I.e the 'friends' object in `user.friends.all()` + """ + + def __init__(self, source, key, definition): + self.source = source + self.source_class = source.__class__ + self.name = key + self.definition = definition + + def __str__(self): + direction = "either" + if self.definition["direction"] == OUTGOING: + direction = "a outgoing" + elif self.definition["direction"] == INCOMING: + direction = "a incoming" + + return f"{self.description} in {direction} direction of type {self.definition['relation_type']} on node ({self.source.element_id}) of class '{self.source_class.__name__}'" + + def __await__(self): + return self.all().__await__() + + def _check_node(self, obj): + """check for valid node i.e correct class and is saved""" + if not issubclass(type(obj), self.definition["node_class"]): + raise ValueError( + "Expected node of class " + self.definition["node_class"].__name__ + ) + if not hasattr(obj, "element_id"): + raise ValueError("Can't perform operation on unsaved node " + repr(obj)) + + @check_source + async def connect(self, node, properties=None): + """ + Connect a node + + :param node: + :param properties: for the new relationship + :type: dict + :return: + """ + self._check_node(node) + + if not self.definition["model"] and properties: + raise NotImplementedError( + "Relationship properties without using a relationship model " + "is no longer supported." + ) + + params = {} + rel_model = self.definition["model"] + rel_prop = None + + if rel_model: + rel_prop = {} + # need to generate defaults etc to create fake instance + tmp = rel_model(**properties) if properties else rel_model() + # build params and place holders to pass to rel_helper + for prop, val in rel_model.deflate(tmp.__properties__).items(): + if val is not None: + rel_prop[prop] = "$" + prop + else: + rel_prop[prop] = None + params[prop] = val + + if hasattr(tmp, "pre_save"): + tmp.pre_save() + + new_rel = _rel_merge_helper( + lhs="us", + rhs="them", + ident="r", + relation_properties=rel_prop, + **self.definition, + ) + q = ( + f"MATCH (them), (us) WHERE {await adb.get_id_method()}(them)=$them and {await adb.get_id_method()}(us)=$self " + "MERGE" + new_rel + ) + + params["them"] = await adb.parse_element_id(node.element_id) + + if not rel_model: + await self.source.cypher(q, params) + return True + + results = await self.source.cypher(q + " RETURN r", params) + rel_ = results[0][0][0] + rel_instance = self._set_start_end_cls(rel_model.inflate(rel_), node) + + if hasattr(rel_instance, "post_save"): + rel_instance.post_save() + + return rel_instance + + @check_source + async def replace(self, node, properties=None): + """ + Disconnect all existing nodes and connect the supplied node + + :param node: + :param properties: for the new relationship + :type: dict + :return: + """ + await self.disconnect_all() + await self.connect(node, properties) + + @check_source + async def relationship(self, node): + """ + Retrieve the relationship object for this first relationship between self and node. + + :param node: + :return: StructuredRel + """ + self._check_node(node) + my_rel = _rel_helper(lhs="us", rhs="them", ident="r", **self.definition) + q = ( + "MATCH " + + my_rel + + f" WHERE {await adb.get_id_method()}(them)=$them and {await adb.get_id_method()}(us)=$self RETURN r LIMIT 1" + ) + results = await self.source.cypher( + q, {"them": await adb.parse_element_id(node.element_id)} + ) + rels = results[0] + if not rels: + return + + rel_model = self.definition.get("model") or AsyncStructuredRel + + return self._set_start_end_cls(rel_model.inflate(rels[0][0]), node) + + @check_source + async def all_relationships(self, node): + """ + Retrieve all relationship objects between self and node. + + :param node: + :return: [StructuredRel] + """ + self._check_node(node) + + my_rel = _rel_helper(lhs="us", rhs="them", ident="r", **self.definition) + q = f"MATCH {my_rel} WHERE {await adb.get_id_method()}(them)=$them and {await adb.get_id_method()}(us)=$self RETURN r " + results = await self.source.cypher( + q, {"them": await adb.parse_element_id(node.element_id)} + ) + rels = results[0] + if not rels: + return [] + + rel_model = self.definition.get("model") or AsyncStructuredRel + return [ + self._set_start_end_cls(rel_model.inflate(rel[0]), node) for rel in rels + ] + + def _set_start_end_cls(self, rel_instance, obj): + if self.definition["direction"] == INCOMING: + rel_instance._start_node_class = obj.__class__ + rel_instance._end_node_class = self.source_class + else: + rel_instance._start_node_class = self.source_class + rel_instance._end_node_class = obj.__class__ + return rel_instance + + @check_source + async def reconnect(self, old_node, new_node): + """ + Disconnect old_node and connect new_node copying over any properties on the original relationship. + + Useful for preventing cardinality violations + + :param old_node: + :param new_node: + :return: None + """ + + self._check_node(old_node) + self._check_node(new_node) + if old_node.element_id == new_node.element_id: + return + old_rel = _rel_helper(lhs="us", rhs="old", ident="r", **self.definition) + + # get list of properties on the existing rel + old_node_element_id = await adb.parse_element_id(old_node.element_id) + new_node_element_id = await adb.parse_element_id(new_node.element_id) + result, _ = await self.source.cypher( + f""" + MATCH (us), (old) WHERE {await adb.get_id_method()}(us)=$self and {await adb.get_id_method()}(old)=$old + MATCH {old_rel} RETURN r + """, + {"old": old_node_element_id}, + ) + if result: + node_properties = _get_node_properties(result[0][0]) + existing_properties = node_properties.keys() + else: + raise NotConnected("reconnect", self.source, old_node) + + # remove old relationship and create new one + new_rel = _rel_merge_helper(lhs="us", rhs="new", ident="r2", **self.definition) + q = ( + "MATCH (us), (old), (new) " + f"WHERE {await adb.get_id_method()}(us)=$self and {await adb.get_id_method()}(old)=$old and {await adb.get_id_method()}(new)=$new " + "MATCH " + old_rel + ) + q += " MERGE" + new_rel + + # copy over properties if we have + q += "".join([f" SET r2.{prop} = r.{prop}" for prop in existing_properties]) + q += " WITH r DELETE r" + + await self.source.cypher( + q, {"old": old_node_element_id, "new": new_node_element_id} + ) + + @check_source + async def disconnect(self, node): + """ + Disconnect a node + + :param node: + :return: + """ + rel = _rel_helper(lhs="a", rhs="b", ident="r", **self.definition) + q = f""" + MATCH (a), (b) WHERE {await adb.get_id_method()}(a)=$self and {await adb.get_id_method()}(b)=$them + MATCH {rel} DELETE r + """ + await self.source.cypher( + q, {"them": await adb.parse_element_id(node.element_id)} + ) + + @check_source + async def disconnect_all(self): + """ + Disconnect all nodes + + :return: + """ + rhs = "b:" + self.definition["node_class"].__label__ + rel = _rel_helper(lhs="a", rhs=rhs, ident="r", **self.definition) + q = ( + f"MATCH (a) WHERE {await adb.get_id_method()}(a)=$self MATCH " + + rel + + " DELETE r" + ) + await self.source.cypher(q) + + @check_source + def _new_traversal(self): + return AsyncTraversal(self.source, self.name, self.definition) + + # The methods below simply proxy the match engine. + def get(self, **kwargs): + """ + Retrieve a related node with the matching node properties. + + :param kwargs: same syntax as `NodeSet.filter()` + :return: node + """ + return AsyncNodeSet(self._new_traversal()).get(**kwargs) + + def get_or_none(self, **kwargs): + """ + Retrieve a related node with the matching node properties or return None. + + :param kwargs: same syntax as `NodeSet.filter()` + :return: node + """ + return AsyncNodeSet(self._new_traversal()).get_or_none(**kwargs) + + def filter(self, *args, **kwargs): + """ + Retrieve related nodes matching the provided properties. + + :param args: a Q object + :param kwargs: same syntax as `NodeSet.filter()` + :return: NodeSet + """ + return AsyncNodeSet(self._new_traversal()).filter(*args, **kwargs) + + def order_by(self, *props): + """ + Order related nodes by specified properties + + :param props: + :return: NodeSet + """ + return AsyncNodeSet(self._new_traversal()).order_by(*props) + + def exclude(self, *args, **kwargs): + """ + Exclude nodes that match the provided properties. + + :param args: a Q object + :param kwargs: same syntax as `NodeSet.filter()` + :return: NodeSet + """ + return AsyncNodeSet(self._new_traversal()).exclude(*args, **kwargs) + + async def is_connected(self, node): + """ + Check if a node is connected with this relationship type + :param node: + :return: bool + """ + return await self._new_traversal().check_contains(node) + + async def single(self): + """ + Get a single related node or none. + + :return: StructuredNode + """ + try: + rels = await self + return rels[0] + except IndexError: + pass + + def match(self, **kwargs): + """ + Return set of nodes who's relationship properties match supplied args + + :param kwargs: same syntax as `NodeSet.filter()` + :return: NodeSet + """ + return self._new_traversal().match(**kwargs) + + async def all(self): + """ + Return all related nodes. + + :return: list + """ + return await self._new_traversal().all() + + async def __aiter__(self): + return self._new_traversal().__aiter__() + + async def get_len(self): + return await self._new_traversal().get_len() + + async def check_bool(self): + return await self._new_traversal().check_bool() + + async def check_nonzero(self): + return self._new_traversal().check_nonzero() + + async def check_contains(self, obj): + return self._new_traversal().check_contains(obj) + + async def get_item(self, key): + return self._new_traversal().get_item(key) + + +class AsyncRelationshipDefinition: + def __init__( + self, + relation_type, + cls_name, + direction, + manager=AsyncRelationshipManager, + model=None, + ): + self._validate_class(cls_name, model) + + current_frame = inspect.currentframe() + + frame_number = 3 + for i, frame in enumerate_traceback(current_frame): + if cls_name in frame.f_globals: + frame_number = i + break + self.module_name = sys._getframe(frame_number).f_globals["__name__"] + if "__file__" in sys._getframe(frame_number).f_globals: + self.module_file = sys._getframe(frame_number).f_globals["__file__"] + self._raw_class = cls_name + self.manager = manager + self.definition = { + "relation_type": relation_type, + "direction": direction, + "model": model, + } + + if model is not None: + # Relationships are easier to instantiate because + # they cannot have multiple labels. + # So, a relationship's type determines the class that should be + # instantiated uniquely. + # Here however, we still use a `frozenset([relation_type])` + # to preserve the mapping type. + label_set = frozenset([relation_type]) + try: + # If the relationship mapping exists then it is attempted + # to be redefined so that it applies to the same label. + # In this case, it has to be ensured that the class + # that is overriding the relationship is a descendant + # of the already existing class. + model_from_registry = adb._NODE_CLASS_REGISTRY[label_set] + if not issubclass(model, model_from_registry): + is_parent = issubclass(model_from_registry, model) + if is_direct_subclass(model, AsyncStructuredRel) and not is_parent: + raise RelationshipClassRedefined( + relation_type, adb._NODE_CLASS_REGISTRY, model + ) + else: + adb._NODE_CLASS_REGISTRY[label_set] = model + except KeyError: + # If the mapping does not exist then it is simply created. + adb._NODE_CLASS_REGISTRY[label_set] = model + + def _validate_class(self, cls_name, model): + if not isinstance(cls_name, (basestring, object)): + raise ValueError("Expected class name or class got " + repr(cls_name)) + + if model and not issubclass(model, (AsyncStructuredRel,)): + raise ValueError("model must be a StructuredRel") + + def lookup_node_class(self): + if not isinstance(self._raw_class, basestring): + self.definition["node_class"] = self._raw_class + else: + name = self._raw_class + if name.find(".") == -1: + module = self.module_name + else: + module, _, name = name.rpartition(".") + + if module not in sys.modules: + # yet another hack to get around python semantics + # __name__ is the namespace of the parent module for __init__.py files, + # and the namespace of the current module for other .py files, + # therefore there's a need to define the namespace differently for + # these two cases in order for . in relative imports to work correctly + # (i.e. to mean the same thing for both cases). + # For example in the comments below, namespace == myapp, always + if not hasattr(self, "module_file"): + raise ImportError(f"Couldn't lookup '{name}'") + + if "__init__.py" in self.module_file: + # e.g. myapp/__init__.py -[__name__]-> myapp + namespace = self.module_name + else: + # e.g. myapp/models.py -[__name__]-> myapp.models + namespace = self.module_name.rpartition(".")[0] + + # load a module from a namespace (e.g. models from myapp) + if module: + module = import_module(module, namespace).__name__ + # load the namespace itself (e.g. myapp) + # (otherwise it would look like import . from myapp) + else: + module = import_module(namespace).__name__ + self.definition["node_class"] = getattr(sys.modules[module], name) + + def build_manager(self, source, name): + self.lookup_node_class() + return self.manager(source, name, self.definition) + + +class AsyncZeroOrMore(AsyncRelationshipManager): + """ + A relationship of zero or more nodes (the default) + """ + + description = "zero or more relationships" + + +class AsyncRelationshipTo(AsyncRelationshipDefinition): + def __init__( + self, + cls_name, + relation_type, + cardinality=AsyncZeroOrMore, + model=None, + ): + super().__init__( + relation_type, cls_name, OUTGOING, manager=cardinality, model=model + ) + + +class AsyncRelationshipFrom(AsyncRelationshipDefinition): + def __init__( + self, + cls_name, + relation_type, + cardinality=AsyncZeroOrMore, + model=None, + ): + super().__init__( + relation_type, cls_name, INCOMING, manager=cardinality, model=model + ) + + +class AsyncRelationship(AsyncRelationshipDefinition): + def __init__( + self, + cls_name, + relation_type, + cardinality=AsyncZeroOrMore, + model=None, + ): + super().__init__( + relation_type, cls_name, EITHER, manager=cardinality, model=model + ) diff --git a/neomodel/config.py b/neomodel/config.py index b54aa806..85c9ed8a 100644 --- a/neomodel/config.py +++ b/neomodel/config.py @@ -1,8 +1,6 @@ import neo4j -from ._version import __version__ - -AUTO_INSTALL_LABELS = False +from neomodel._version import __version__ # Use this to connect with automatically created driver # The following options are the default ones that will be used as driver config @@ -24,3 +22,6 @@ # DRIVER = neo4j.GraphDatabase().driver( # "bolt://localhost:7687", auth=("neo4j", "foobarbaz") # ) +DRIVER = None +# Use this to connect to a specific database when using the self-managed driver +DATABASE_NAME = None diff --git a/neomodel/contrib/__init__.py b/neomodel/contrib/__init__.py index 3be00b41..a852965d 100644 --- a/neomodel/contrib/__init__.py +++ b/neomodel/contrib/__init__.py @@ -1 +1,2 @@ -from .semi_structured import SemiStructuredNode +from neomodel.contrib.async_.semi_structured import AsyncSemiStructuredNode +from neomodel.contrib.sync_.semi_structured import SemiStructuredNode diff --git a/neomodel/contrib/async_/semi_structured.py b/neomodel/contrib/async_/semi_structured.py new file mode 100644 index 00000000..c333ae0e --- /dev/null +++ b/neomodel/contrib/async_/semi_structured.py @@ -0,0 +1,64 @@ +from neomodel.async_.core import AsyncStructuredNode +from neomodel.exceptions import DeflateConflict, InflateConflict +from neomodel.util import _get_node_properties + + +class AsyncSemiStructuredNode(AsyncStructuredNode): + """ + A base class allowing properties to be stored on a node that aren't + specified in its definition. Conflicting properties are signaled with the + :class:`DeflateConflict` exception:: + + class Person(AsyncSemiStructuredNode): + name = StringProperty() + age = IntegerProperty() + + def hello(self): + print("Hi my names " + self.name) + + tim = await Person(name='Tim', age=8, weight=11).save() + tim.hello = "Hi" + await tim.save() # DeflateConflict + """ + + __abstract_node__ = True + + @classmethod + def inflate(cls, node): + # support lazy loading + if isinstance(node, str) or isinstance(node, int): + snode = cls() + snode.element_id_property = node + else: + props = {} + node_properties = {} + for key, prop in cls.__all_properties__: + node_properties = _get_node_properties(node) + if key in node_properties: + props[key] = prop.inflate(node_properties[key], node) + elif prop.has_default: + props[key] = prop.default_value() + else: + props[key] = None + # handle properties not defined on the class + for free_key in (x for x in node_properties if x not in props): + if hasattr(cls, free_key): + raise InflateConflict( + cls, free_key, node_properties[free_key], node.element_id + ) + props[free_key] = node_properties[free_key] + + snode = cls(**props) + snode.element_id_property = node.element_id + + return snode + + @classmethod + def deflate(cls, node_props, obj=None, skip_empty=False): + deflated = super().deflate(node_props, obj, skip_empty=skip_empty) + for key in [k for k in node_props if k not in deflated]: + if hasattr(cls, key) and (getattr(cls, key).required or not skip_empty): + raise DeflateConflict(cls, key, deflated[key], obj.element_id) + + node_props.update(deflated) + return node_props diff --git a/neomodel/contrib/spatial_properties.py b/neomodel/contrib/spatial_properties.py index 7a48018b..982e6921 100644 --- a/neomodel/contrib/spatial_properties.py +++ b/neomodel/contrib/spatial_properties.py @@ -25,9 +25,9 @@ # If shapely is not installed, its import will fail and the spatial properties will not be available try: - from shapely.geometry import Point as ShapelyPoint from shapely import __version__ as shapely_version - from shapely.coords import CoordinateSequence + from shapely.coords import CoordinateSequence + from shapely.geometry import Point as ShapelyPoint except ImportError as exc: raise ImportError( "NEOMODEL ERROR: Shapely not found. If required, you can install Shapely via " @@ -53,10 +53,11 @@ # Taking into account the Shapely 2.0.0 changes in the way POINT objects are initialisd. if int("".join(shapely_version.split(".")[0:3])) < 200: + class NeomodelPoint(ShapelyPoint): """ Abstracts the Point spatial data type of Neo4j. - + Note: At the time of writing, Neo4j supports 2 main variants of Point: 1. A generic point defined over a Cartesian plane @@ -65,12 +66,12 @@ class NeomodelPoint(ShapelyPoint): * The minimum data to define a point is longitude, latitude [,Height] and the crs is then assumed to be "wgs-84". """ - + # def __init__(self, *args, crs=None, x=None, y=None, z=None, latitude=None, longitude=None, height=None, **kwargs): def __init__(self, *args, **kwargs): """ Creates a NeomodelPoint. - + :param args: Positional arguments to emulate the behaviour of Shapely's Point (and specifically the copy constructor) :type args: list @@ -91,7 +92,7 @@ def __init__(self, *args, **kwargs): :param kwargs: Dictionary of keyword arguments :type kwargs: dict """ - + # Python2.7 Workaround for the order that the arguments get passed to the functions crs = kwargs.pop("crs", None) x = kwargs.pop("x", None) @@ -100,14 +101,16 @@ def __init__(self, *args, **kwargs): longitude = kwargs.pop("longitude", None) latitude = kwargs.pop("latitude", None) height = kwargs.pop("height", None) - + _x, _y, _z = None, None, None - + # CRS validity check is common to both types of constructors that follow if crs is not None and crs not in ACCEPTABLE_CRS: - raise ValueError(f"Invalid CRS({crs}). Expected one of {','.join(ACCEPTABLE_CRS)}") + raise ValueError( + f"Invalid CRS({crs}). Expected one of {','.join(ACCEPTABLE_CRS)}" + ) self._crs = crs - + # If positional arguments have been supplied, then this is a possible call to the copy constructor or # initialisation by a coordinate iterable as per ShapelyPoint constructor. if len(args) > 0: @@ -115,7 +118,9 @@ def __init__(self, *args, **kwargs): if isinstance(args[0], (tuple, list)): # Check dimensionality of tuple if len(args[0]) < 2 or len(args[0]) > 3: - raise ValueError(f"Invalid vector dimensions. Expected 2 or 3, received {len(args[0])}") + raise ValueError( + f"Invalid vector dimensions. Expected 2 or 3, received {len(args[0])}" + ) x = args[0][0] y = args[0][1] if len(args[0]) == 3: @@ -143,11 +148,15 @@ def __init__(self, *args, **kwargs): if crs is None: self._crs = "cartesian-3d" else: - raise ValueError(f"Invalid vector dimensions. Expected 2 or 3, received {len(args[0].coords[0])}") + raise ValueError( + f"Invalid vector dimensions. Expected 2 or 3, received {len(args[0].coords[0])}" + ) return else: - raise TypeError(f"Invalid object passed to copy constructor. Expected NeomodelPoint or shapely Point, received {type(args[0])}") - + raise TypeError( + f"Invalid object passed to copy constructor. Expected NeomodelPoint or shapely Point, received {type(args[0])}" + ) + # Initialisation is either via x,y[,z] XOR longitude,latitude[,height]. Specifying both leads to an error. if any(i is not None for i in [x, y, z]) and any( i is not None for i in [latitude, longitude, height] @@ -157,14 +166,14 @@ def __init__(self, *args, **kwargs): "A Point can be defined either by x,y,z coordinates OR latitude,longitude,height but not " "a combination of these terms" ) - + # Specifying no initialisation argument at this point in the constructor is flagged as an error if all(i is None for i in [x, y, z, latitude, longitude, height]): raise ValueError( "Invalid instantiation via no arguments. " "A Point needs default values either in x,y,z or longitude, latitude, height coordinates" ) - + # Geographical Point Initialisation if latitude is not None and longitude is not None: if height is not None: @@ -176,7 +185,7 @@ def __init__(self, *args, **kwargs): self._crs = "wgs-84" _x = longitude _y = latitude - + # Geometrical Point Initialisation if x is not None and y is not None: if z is not None: @@ -188,22 +197,26 @@ def __init__(self, *args, **kwargs): self._crs = "cartesian" _x = x _y = y - + if _z is None: if "-3d" not in self._crs: super().__init__((float(_x), float(_y)), **kwargs) else: - raise ValueError(f"Invalid vector dimensions(2) for given CRS({self._crs}).") + raise ValueError( + f"Invalid vector dimensions(2) for given CRS({self._crs})." + ) else: if "-3d" in self._crs: super().__init__((float(_x), float(_y), float(_z)), **kwargs) else: - raise ValueError(f"Invalid vector dimensions(3) for given CRS({self._crs}).") - + raise ValueError( + f"Invalid vector dimensions(3) for given CRS({self._crs})." + ) + @property def crs(self): return self._crs - + @property def x(self): if not self._crs.startswith("cartesian"): @@ -211,7 +224,7 @@ def x(self): f'Invalid coordinate ("x") for points defined over {self.crs}' ) return super().x - + @property def y(self): if not self._crs.startswith("cartesian"): @@ -219,7 +232,7 @@ def y(self): f'Invalid coordinate ("y") for points defined over {self.crs}' ) return super().y - + @property def z(self): if self._crs != "cartesian-3d": @@ -227,7 +240,7 @@ def z(self): f'Invalid coordinate ("z") for points defined over {self.crs}' ) return super().z - + @property def latitude(self): if not self._crs.startswith("wgs-84"): @@ -235,7 +248,7 @@ def latitude(self): f'Invalid coordinate ("latitude") for points defined over {self.crs}' ) return super().y - + @property def longitude(self): if not self._crs.startswith("wgs-84"): @@ -243,7 +256,7 @@ def longitude(self): f'Invalid coordinate ("longitude") for points defined over {self.crs}' ) return super().x - + @property def height(self): if self._crs != "wgs-84-3d": @@ -251,21 +264,22 @@ def height(self): f'Invalid coordinate ("height") for points defined over {self.crs}' ) return super().z - + # The following operations are necessary here due to the way queries (and more importantly their parameters) get # combined and evaluated in neomodel. Specifically, query expressions get duplicated with deep copies and any valid # datatype values should also implement these operations. def __copy__(self): return NeomodelPoint(self) - + def __deepcopy__(self, memo): return NeomodelPoint(self) else: + class NeomodelPoint: """ Abstracts the Point spatial data type of Neo4j. - + Note: At the time of writing, Neo4j supports 2 main variants of Point: 1. A generic point defined over a Cartesian plane @@ -274,12 +288,12 @@ class NeomodelPoint: * The minimum data to define a point is longitude, latitude [,Height] and the crs is then assumed to be "wgs-84". """ - + # def __init__(self, *args, crs=None, x=None, y=None, z=None, latitude=None, longitude=None, height=None, **kwargs): def __init__(self, *args, **kwargs): """ Creates a NeomodelPoint. - + :param args: Positional arguments to emulate the behaviour of Shapely's Point (and specifically the copy constructor) :type args: list @@ -300,7 +314,7 @@ def __init__(self, *args, **kwargs): :param kwargs: Dictionary of keyword arguments :type kwargs: dict """ - + # Python2.7 Workaround for the order that the arguments get passed to the functions crs = kwargs.pop("crs", None) x = kwargs.pop("x", None) @@ -309,14 +323,16 @@ def __init__(self, *args, **kwargs): longitude = kwargs.pop("longitude", None) latitude = kwargs.pop("latitude", None) height = kwargs.pop("height", None) - + _x, _y, _z = None, None, None - + # CRS validity check is common to both types of constructors that follow if crs is not None and crs not in ACCEPTABLE_CRS: - raise ValueError(f"Invalid CRS({crs}). Expected one of {','.join(ACCEPTABLE_CRS)}") + raise ValueError( + f"Invalid CRS({crs}). Expected one of {','.join(ACCEPTABLE_CRS)}" + ) self._crs = crs - + # If positional arguments have been supplied, then this is a possible call to the copy constructor or # initialisation by a coordinate iterable as per ShapelyPoint constructor. if len(args) > 0: @@ -324,7 +340,9 @@ def __init__(self, *args, **kwargs): if isinstance(args[0], (tuple, list)): # Check dimensionality of tuple if len(args[0]) < 2 or len(args[0]) > 3: - raise ValueError(f"Invalid vector dimensions. Expected 2 or 3, received {len(args[0])}") + raise ValueError( + f"Invalid vector dimensions. Expected 2 or 3, received {len(args[0])}" + ) x = args[0][0] y = args[0][1] if len(args[0]) == 3: @@ -354,11 +372,15 @@ def __init__(self, *args, **kwargs): if crs is None: self._crs = "cartesian-3d" else: - raise ValueError(f"Invalid vector dimensions. Expected 2 or 3, received {len(args[0].coords[0])}") + raise ValueError( + f"Invalid vector dimensions. Expected 2 or 3, received {len(args[0].coords[0])}" + ) return else: - raise TypeError(f"Invalid object passed to copy constructor. Expected NeomodelPoint or shapely Point, received {type(args[0])}") - + raise TypeError( + f"Invalid object passed to copy constructor. Expected NeomodelPoint or shapely Point, received {type(args[0])}" + ) + # Initialisation is either via x,y[,z] XOR longitude,latitude[,height]. Specifying both leads to an error. if any(i is not None for i in [x, y, z]) and any( i is not None for i in [latitude, longitude, height] @@ -368,14 +390,14 @@ def __init__(self, *args, **kwargs): "A Point can be defined either by x,y,z coordinates OR latitude,longitude,height but not " "a combination of these terms" ) - + # Specifying no initialisation argument at this point in the constructor is flagged as an error if all(i is None for i in [x, y, z, latitude, longitude, height]): raise ValueError( "Invalid instantiation via no arguments. " "A Point needs default values either in x,y,z or longitude, latitude, height coordinates" ) - + # Geographical Point Initialisation if latitude is not None and longitude is not None: if height is not None: @@ -387,7 +409,7 @@ def __init__(self, *args, **kwargs): self._crs = "wgs-84" _x = longitude _y = latitude - + # Geometrical Point Initialisation if x is not None and y is not None: if z is not None: @@ -399,23 +421,28 @@ def __init__(self, *args, **kwargs): self._crs = "cartesian" _x = x _y = y - + if _z is None: if "-3d" not in self._crs: self._shapely_point = ShapelyPoint((float(_x), float(_y))) else: - raise ValueError(f"Invalid vector dimensions(2) for given CRS({self._crs}).") + raise ValueError( + f"Invalid vector dimensions(2) for given CRS({self._crs})." + ) else: if "-3d" in self._crs: - self._shapely_point = ShapelyPoint((float(_x), float(_y), float(_z))) + self._shapely_point = ShapelyPoint( + (float(_x), float(_y), float(_z)) + ) else: - raise ValueError(f"Invalid vector dimensions(3) for given CRS({self._crs}).") + raise ValueError( + f"Invalid vector dimensions(3) for given CRS({self._crs})." + ) - @property def crs(self): return self._crs - + @property def x(self): if not self._crs.startswith("cartesian"): @@ -423,7 +450,7 @@ def x(self): f'Invalid coordinate ("x") for points defined over {self.crs}' ) return self._shapely_point.x - + @property def y(self): if not self._crs.startswith("cartesian"): @@ -431,7 +458,7 @@ def y(self): f'Invalid coordinate ("y") for points defined over {self.crs}' ) return self._shapely_point.y - + @property def z(self): if self._crs != "cartesian-3d": @@ -439,7 +466,7 @@ def z(self): f'Invalid coordinate ("z") for points defined over {self.crs}' ) return self._shapely_point.z - + @property def latitude(self): if not self._crs.startswith("wgs-84"): @@ -447,7 +474,7 @@ def latitude(self): f'Invalid coordinate ("latitude") for points defined over {self.crs}' ) return self._shapely_point.y - + @property def longitude(self): if not self._crs.startswith("wgs-84"): @@ -455,7 +482,7 @@ def longitude(self): f'Invalid coordinate ("longitude") for points defined over {self.crs}' ) return self._shapely_point.x - + @property def height(self): if self._crs != "wgs-84-3d": @@ -463,13 +490,13 @@ def height(self): f'Invalid coordinate ("height") for points defined over {self.crs}' ) return self._shapely_point.z - + # The following operations are necessary here due to the way queries (and more importantly their parameters) get # combined and evaluated in neomodel. Specifically, query expressions get duplicated with deep copies and any valid # datatype values should also implement these operations. def __copy__(self): return NeomodelPoint(self) - + def __deepcopy__(self, memo): return NeomodelPoint(self) @@ -484,7 +511,9 @@ def __eq__(self, other): Compare objects by value """ if not isinstance(other, (ShapelyPoint, NeomodelPoint)): - raise ValueException(f"NeomodelPoint equality comparison expected NeomodelPoint or Shapely Point, received {type(other)}") + raise ValueException( + f"NeomodelPoint equality comparison expected NeomodelPoint or Shapely Point, received {type(other)}" + ) else: if isinstance(other, ShapelyPoint): return self.coords[0] == other.coords[0] @@ -517,12 +546,19 @@ def __init__(self, *args, **kwargs): crs = None if crs is None or (crs not in ACCEPTABLE_CRS): - raise ValueError(f"Invalid CRS({crs}). Point properties require CRS to be one of {','.join(ACCEPTABLE_CRS)}") + raise ValueError( + f"Invalid CRS({crs}). Point properties require CRS to be one of {','.join(ACCEPTABLE_CRS)}" + ) # If a default value is passed and it is not a callable, then make sure it is in the right type - if "default" in kwargs and not hasattr(kwargs["default"], "__call__") and not isinstance(kwargs["default"], NeomodelPoint): - raise TypeError(f"Invalid default value. Expected NeomodelPoint, received {type(kwargs['default'])}" - ) + if ( + "default" in kwargs + and not hasattr(kwargs["default"], "__call__") + and not isinstance(kwargs["default"], NeomodelPoint) + ): + raise TypeError( + f"Invalid default value. Expected NeomodelPoint, received {type(kwargs['default'])}" + ) super().__init__(*args, **kwargs) self._crs = crs @@ -544,10 +580,14 @@ def inflate(self, value): try: value_point_crs = SRID_TO_CRS[value.srid] except KeyError as e: - raise ValueError(f"Invalid SRID to inflate. Expected one of {SRID_TO_CRS.keys()}, received {value.srid}") from e + raise ValueError( + f"Invalid SRID to inflate. Expected one of {SRID_TO_CRS.keys()}, received {value.srid}" + ) from e if self._crs != value_point_crs: - raise ValueError(f"Invalid CRS. Expected POINT defined over {self._crs}, received {value_point_crs}") + raise ValueError( + f"Invalid CRS. Expected POINT defined over {self._crs}, received {value_point_crs}" + ) # cartesian if value.srid == 7203: return NeomodelPoint(x=value.x, y=value.y) @@ -581,7 +621,9 @@ def deflate(self, value): ) if value.crs != self._crs: - raise ValueError(f"Invalid CRS. Expected NeomodelPoint defined over {self._crs}, received NeomodelPoint defined over {value.crs}") + raise ValueError( + f"Invalid CRS. Expected NeomodelPoint defined over {self._crs}, received NeomodelPoint defined over {value.crs}" + ) if value.crs == "cartesian-3d": return neo4j.spatial.CartesianPoint((value.x, value.y, value.z)) diff --git a/neomodel/contrib/semi_structured.py b/neomodel/contrib/sync_/semi_structured.py similarity index 94% rename from neomodel/contrib/semi_structured.py rename to neomodel/contrib/sync_/semi_structured.py index 9c719983..86a5a140 100644 --- a/neomodel/contrib/semi_structured.py +++ b/neomodel/contrib/sync_/semi_structured.py @@ -1,5 +1,5 @@ -from neomodel.core import StructuredNode from neomodel.exceptions import DeflateConflict, InflateConflict +from neomodel.sync_.core import StructuredNode from neomodel.util import _get_node_properties @@ -57,7 +57,7 @@ def inflate(cls, node): def deflate(cls, node_props, obj=None, skip_empty=False): deflated = super().deflate(node_props, obj, skip_empty=skip_empty) for key in [k for k in node_props if k not in deflated]: - if hasattr(cls, key) and (getattr(cls,key).required or not skip_empty): + if hasattr(cls, key) and (getattr(cls, key).required or not skip_empty): raise DeflateConflict(cls, key, deflated[key], obj.element_id) node_props.update(deflated) diff --git a/neomodel/core.py b/neomodel/core.py deleted file mode 100644 index 415a97af..00000000 --- a/neomodel/core.py +++ /dev/null @@ -1,784 +0,0 @@ -import sys -import warnings -from itertools import combinations - -from neo4j.exceptions import ClientError - -from neomodel import config -from neomodel.exceptions import ( - DoesNotExist, - FeatureNotSupported, - NodeClassAlreadyDefined, -) -from neomodel.hooks import hooks -from neomodel.properties import Property, PropertyManager -from neomodel.util import Database, _get_node_properties, _UnsavedNode, classproperty - -db = Database() - -RULE_ALREADY_EXISTS = "Neo.ClientError.Schema.EquivalentSchemaRuleAlreadyExists" -INDEX_ALREADY_EXISTS = "Neo.ClientError.Schema.IndexAlreadyExists" -CONSTRAINT_ALREADY_EXISTS = "Neo.ClientError.Schema.ConstraintAlreadyExists" -STREAMING_WARNING = "streaming is not supported by bolt, please remove the kwarg" - - -def drop_constraints(quiet=True, stdout=None): - """ - Discover and drop all constraints. - - :type: bool - :return: None - """ - if not stdout or stdout is None: - stdout = sys.stdout - - results, meta = db.cypher_query("SHOW CONSTRAINTS") - - results_as_dict = [dict(zip(meta, row)) for row in results] - for constraint in results_as_dict: - db.cypher_query("DROP CONSTRAINT " + constraint["name"]) - if not quiet: - stdout.write( - ( - " - Dropping unique constraint and index" - f" on label {constraint['labelsOrTypes'][0]}" - f" with property {constraint['properties'][0]}.\n" - ) - ) - if not quiet: - stdout.write("\n") - - -def drop_indexes(quiet=True, stdout=None): - """ - Discover and drop all indexes, except the automatically created token lookup indexes. - - :type: bool - :return: None - """ - if not stdout or stdout is None: - stdout = sys.stdout - - indexes = db.list_indexes(exclude_token_lookup=True) - for index in indexes: - db.cypher_query("DROP INDEX " + index["name"]) - if not quiet: - stdout.write( - f' - Dropping index on labels {",".join(index["labelsOrTypes"])} with properties {",".join(index["properties"])}.\n' - ) - if not quiet: - stdout.write("\n") - - -def remove_all_labels(stdout=None): - """ - Calls functions for dropping constraints and indexes. - - :param stdout: output stream - :return: None - """ - - if not stdout: - stdout = sys.stdout - - stdout.write("Dropping constraints...\n") - drop_constraints(quiet=False, stdout=stdout) - - stdout.write("Dropping indexes...\n") - drop_indexes(quiet=False, stdout=stdout) - - -def install_labels(cls, quiet=True, stdout=None): - """ - Setup labels with indexes and constraints for a given class - - :param cls: StructuredNode class - :type: class - :param quiet: (default true) enable standard output - :param stdout: stdout stream - :type: bool - :return: None - """ - if not stdout or stdout is None: - stdout = sys.stdout - - if not hasattr(cls, "__label__"): - if not quiet: - stdout.write( - f" ! Skipping class {cls.__module__}.{cls.__name__} is abstract\n" - ) - return - - for name, property in cls.defined_properties(aliases=False, rels=False).items(): - _install_node(cls, name, property, quiet, stdout) - - for _, relationship in cls.defined_properties( - aliases=False, rels=True, properties=False - ).items(): - _install_relationship(cls, relationship, quiet, stdout) - - -def _create_node_index(label: str, property_name: str, stdout): - try: - db.cypher_query( - f"CREATE INDEX index_{label}_{property_name} FOR (n:{label}) ON (n.{property_name}); " - ) - except ClientError as e: - if e.code in ( - RULE_ALREADY_EXISTS, - INDEX_ALREADY_EXISTS, - ): - stdout.write(f"{str(e)}\n") - else: - raise - - -def _create_node_constraint(label: str, property_name: str, stdout): - try: - db.cypher_query( - f"""CREATE CONSTRAINT constraint_unique_{label}_{property_name} - FOR (n:{label}) REQUIRE n.{property_name} IS UNIQUE""" - ) - except ClientError as e: - if e.code in ( - RULE_ALREADY_EXISTS, - CONSTRAINT_ALREADY_EXISTS, - ): - stdout.write(f"{str(e)}\n") - else: - raise - - -def _create_relationship_index(relationship_type: str, property_name: str, stdout): - try: - db.cypher_query( - f"CREATE INDEX index_{relationship_type}_{property_name} FOR ()-[r:{relationship_type}]-() ON (r.{property_name}); " - ) - except ClientError as e: - if e.code in ( - RULE_ALREADY_EXISTS, - INDEX_ALREADY_EXISTS, - ): - stdout.write(f"{str(e)}\n") - else: - raise - - -def _create_relationship_constraint(relationship_type: str, property_name: str, stdout): - if db.version_is_higher_than("5.7"): - try: - db.cypher_query( - f"""CREATE CONSTRAINT constraint_unique_{relationship_type}_{property_name} - FOR ()-[r:{relationship_type}]-() REQUIRE r.{property_name} IS UNIQUE""" - ) - except ClientError as e: - if e.code in ( - RULE_ALREADY_EXISTS, - CONSTRAINT_ALREADY_EXISTS, - ): - stdout.write(f"{str(e)}\n") - else: - raise - else: - raise FeatureNotSupported( - f"Unique indexes on relationships are not supported in Neo4j version {db.database_version}. Please upgrade to Neo4j 5.7 or higher." - ) - - -def _install_node(cls, name, property, quiet, stdout): - # Create indexes and constraints for node property - db_property = property.db_property or name - if property.index: - if not quiet: - stdout.write( - f" + Creating node index {name} on label {cls.__label__} for class {cls.__module__}.{cls.__name__}\n" - ) - _create_node_index( - label=cls.__label__, property_name=db_property, stdout=stdout - ) - - elif property.unique_index: - if not quiet: - stdout.write( - f" + Creating node unique constraint for {name} on label {cls.__label__} for class {cls.__module__}.{cls.__name__}\n" - ) - _create_node_constraint( - label=cls.__label__, property_name=db_property, stdout=stdout - ) - - -def _install_relationship(cls, relationship, quiet, stdout): - # Create indexes and constraints for relationship property - relationship_cls = relationship.definition["model"] - if relationship_cls is not None: - relationship_type = relationship.definition["relation_type"] - for prop_name, property in relationship_cls.defined_properties( - aliases=False, rels=False - ).items(): - db_property = property.db_property or prop_name - if property.index: - if not quiet: - stdout.write( - f" + Creating relationship index {prop_name} on relationship type {relationship_type} for relationship model {cls.__module__}.{relationship_cls.__name__}\n" - ) - _create_relationship_index( - relationship_type=relationship_type, - property_name=db_property, - stdout=stdout, - ) - elif property.unique_index: - if not quiet: - stdout.write( - f" + Creating relationship unique constraint for {prop_name} on relationship type {relationship_type} for relationship model {cls.__module__}.{relationship_cls.__name__}\n" - ) - _create_relationship_constraint( - relationship_type=relationship_type, - property_name=db_property, - stdout=stdout, - ) - - -def install_all_labels(stdout=None): - """ - Discover all subclasses of StructuredNode in your application and execute install_labels on each. - Note: code must be loaded (imported) in order for a class to be discovered. - - :param stdout: output stream - :return: None - """ - - if not stdout or stdout is None: - stdout = sys.stdout - - def subsub(cls): # recursively return all subclasses - subclasses = cls.__subclasses__() - if not subclasses: # base case: no more subclasses - return [] - return subclasses + [g for s in cls.__subclasses__() for g in subsub(s)] - - stdout.write("Setting up indexes and constraints...\n\n") - - i = 0 - for cls in subsub(StructuredNode): - stdout.write(f"Found {cls.__module__}.{cls.__name__}\n") - install_labels(cls, quiet=False, stdout=stdout) - i += 1 - - if i: - stdout.write("\n") - - stdout.write(f"Finished {i} classes.\n") - - -class NodeMeta(type): - def __new__(mcs, name, bases, namespace): - namespace["DoesNotExist"] = type(name + "DoesNotExist", (DoesNotExist,), {}) - cls = super().__new__(mcs, name, bases, namespace) - cls.DoesNotExist._model_class = cls - - if hasattr(cls, "__abstract_node__"): - delattr(cls, "__abstract_node__") - else: - if "deleted" in namespace: - raise ValueError( - "Property name 'deleted' is not allowed as it conflicts with neomodel internals." - ) - elif "id" in namespace: - raise ValueError( - """ - Property name 'id' is not allowed as it conflicts with neomodel internals. - Consider using 'uid' or 'identifier' as id is also a Neo4j internal. - """ - ) - elif "element_id" in namespace: - raise ValueError( - """ - Property name 'element_id' is not allowed as it conflicts with neomodel internals. - Consider using 'uid' or 'identifier' as element_id is also a Neo4j internal. - """ - ) - for key, value in ( - (x, y) for x, y in namespace.items() if isinstance(y, Property) - ): - value.name, value.owner = key, cls - if hasattr(value, "setup") and callable(value.setup): - value.setup() - - # cache various groups of properies - cls.__required_properties__ = tuple( - name - for name, property in cls.defined_properties( - aliases=False, rels=False - ).items() - if property.required or property.unique_index - ) - cls.__all_properties__ = tuple( - cls.defined_properties(aliases=False, rels=False).items() - ) - cls.__all_aliases__ = tuple( - cls.defined_properties(properties=False, rels=False).items() - ) - cls.__all_relationships__ = tuple( - cls.defined_properties(aliases=False, properties=False).items() - ) - - cls.__label__ = namespace.get("__label__", name) - cls.__optional_labels__ = namespace.get("__optional_labels__", []) - - if config.AUTO_INSTALL_LABELS: - install_labels(cls, quiet=False) - - build_class_registry(cls) - - return cls - - -def build_class_registry(cls): - base_label_set = frozenset(cls.inherited_labels()) - optional_label_set = set(cls.inherited_optional_labels()) - - # Construct all possible combinations of labels + optional labels - possible_label_combinations = [ - frozenset(set(x).union(base_label_set)) - for i in range(1, len(optional_label_set) + 1) - for x in combinations(optional_label_set, i) - ] - possible_label_combinations.append(base_label_set) - - for label_set in possible_label_combinations: - if label_set not in db._NODE_CLASS_REGISTRY: - db._NODE_CLASS_REGISTRY[label_set] = cls - else: - raise NodeClassAlreadyDefined(cls, db._NODE_CLASS_REGISTRY) - - -NodeBase = NodeMeta("NodeBase", (PropertyManager,), {"__abstract_node__": True}) - - -class StructuredNode(NodeBase): - """ - Base class for all node definitions to inherit from. - - If you want to create your own abstract classes set: - __abstract_node__ = True - """ - - # static properties - - __abstract_node__ = True - - # magic methods - - def __init__(self, *args, **kwargs): - if "deleted" in kwargs: - raise ValueError("deleted property is reserved for neomodel") - - for key, val in self.__all_relationships__: - self.__dict__[key] = val.build_manager(self, key) - - super().__init__(*args, **kwargs) - - def __eq__(self, other): - if not isinstance(other, (StructuredNode,)): - return False - if hasattr(self, "element_id") and hasattr(other, "element_id"): - return self.element_id == other.element_id - return False - - def __ne__(self, other): - return not self.__eq__(other) - - def __repr__(self): - return f"<{self.__class__.__name__}: {self}>" - - def __str__(self): - return repr(self.__properties__) - - # dynamic properties - - @classproperty - def nodes(cls): - """ - Returns a NodeSet object representing all nodes of the classes label - :return: NodeSet - :rtype: NodeSet - """ - from .match import NodeSet - - return NodeSet(cls) - - @property - def element_id(self): - if hasattr(self, "element_id_property"): - return ( - int(self.element_id_property) - if db.database_version.startswith("4") - else self.element_id_property - ) - return None - - # Version 4.4 support - id is deprecated in version 5.x - @property - def id(self): - try: - return int(self.element_id_property) - except (TypeError, ValueError): - raise ValueError( - "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()." - ) - - # methods - - @classmethod - def _build_merge_query( - cls, merge_params, update_existing=False, lazy=False, relationship=None - ): - """ - Get a tuple of a CYPHER query and a params dict for the specified MERGE query. - - :param merge_params: The target node match parameters, each node must have a "create" key and optional "update". - :type merge_params: list of dict - :param update_existing: True to update properties of existing nodes, default False to keep existing values. - :type update_existing: bool - :rtype: tuple - """ - query_params = dict(merge_params=merge_params) - n_merge_labels = ":".join(cls.inherited_labels()) - n_merge_prm = ", ".join( - ( - f"{getattr(cls, p).db_property or p}: params.create.{getattr(cls, p).db_property or p}" - for p in cls.__required_properties__ - ) - ) - n_merge = f"n:{n_merge_labels} {{{n_merge_prm}}}" - if relationship is None: - # create "simple" unwind query - query = f"UNWIND $merge_params as params\n MERGE ({n_merge})\n " - else: - # validate relationship - if not isinstance(relationship.source, StructuredNode): - raise ValueError( - f"relationship source [{repr(relationship.source)}] is not a StructuredNode" - ) - relation_type = relationship.definition.get("relation_type") - if not relation_type: - raise ValueError( - "No relation_type is specified on provided relationship" - ) - - from .match import _rel_helper - - query_params["source_id"] = relationship.source.element_id - query = f"MATCH (source:{relationship.source.__label__}) WHERE {db.get_id_method()}(source) = $source_id\n " - query += "WITH source\n UNWIND $merge_params as params \n " - query += "MERGE " - query += _rel_helper( - lhs="source", - rhs=n_merge, - ident=None, - relation_type=relation_type, - direction=relationship.definition["direction"], - ) - - query += "ON CREATE SET n = params.create\n " - # if update_existing, write properties on match as well - if update_existing is True: - query += "ON MATCH SET n += params.update\n" - - # close query - if lazy: - query += f"RETURN {db.get_id_method()}(n)" - else: - query += "RETURN n" - - return query, query_params - - @classmethod - def create(cls, *props, **kwargs): - """ - Call to CREATE with parameters map. A new instance will be created and saved. - - :param props: dict of properties to create the nodes. - :type props: tuple - :param lazy: False by default, specify True to get nodes with id only without the parameters. - :type: bool - :rtype: list - """ - - if "streaming" in kwargs: - warnings.warn( - STREAMING_WARNING, - category=DeprecationWarning, - stacklevel=1, - ) - - lazy = kwargs.get("lazy", False) - # create mapped query - query = f"CREATE (n:{':'.join(cls.inherited_labels())} $create_params)" - - # close query - if lazy: - query += f" RETURN {db.get_id_method()}(n)" - else: - query += " RETURN n" - - results = [] - for item in [ - cls.deflate(p, obj=_UnsavedNode(), skip_empty=True) for p in props - ]: - node, _ = db.cypher_query(query, {"create_params": item}) - results.extend(node[0]) - - nodes = [cls.inflate(node) for node in results] - - if not lazy and hasattr(cls, "post_create"): - for node in nodes: - node.post_create() - - return nodes - - @classmethod - def create_or_update(cls, *props, **kwargs): - """ - Call to MERGE with parameters map. A new instance will be created and saved if does not already exists, - this is an atomic operation. If an instance already exists all optional properties specified will be updated. - - Note that the post_create hook isn't called after create_or_update - - :param props: List of dict arguments to get or create the entities with. - :type props: tuple - :param relationship: Optional, relationship to get/create on when new entity is created. - :param lazy: False by default, specify True to get nodes with id only without the parameters. - :rtype: list - """ - lazy = kwargs.get("lazy", False) - relationship = kwargs.get("relationship") - - # build merge query, make sure to update only explicitly specified properties - create_or_update_params = [] - for specified, deflated in [ - (p, cls.deflate(p, skip_empty=True)) for p in props - ]: - create_or_update_params.append( - { - "create": deflated, - "update": dict( - (k, v) for k, v in deflated.items() if k in specified - ), - } - ) - query, params = cls._build_merge_query( - create_or_update_params, - update_existing=True, - relationship=relationship, - lazy=lazy, - ) - - if "streaming" in kwargs: - warnings.warn( - STREAMING_WARNING, - category=DeprecationWarning, - stacklevel=1, - ) - - # fetch and build instance for each result - results = db.cypher_query(query, params) - return [cls.inflate(r[0]) for r in results[0]] - - def cypher(self, query, params=None): - """ - Execute a cypher query with the param 'self' pre-populated with the nodes neo4j id. - - :param query: cypher query string - :type: string - :param params: query parameters - :type: dict - :return: list containing query results - :rtype: list - """ - self._pre_action_check("cypher") - params = params or {} - params.update({"self": self.element_id}) - return db.cypher_query(query, params) - - @hooks - def delete(self): - """ - Delete a node and its relationships - - :return: True - """ - self._pre_action_check("delete") - self.cypher( - f"MATCH (self) WHERE {db.get_id_method()}(self)=$self DETACH DELETE self" - ) - delattr(self, "element_id_property") - self.deleted = True - return True - - @classmethod - def get_or_create(cls, *props, **kwargs): - """ - Call to MERGE with parameters map. A new instance will be created and saved if does not already exist, - this is an atomic operation. - Parameters must contain all required properties, any non required properties with defaults will be generated. - - Note that the post_create hook isn't called after get_or_create - - :param props: Arguments to get_or_create as tuple of dict with property names and values to get or create - the entities with. - :type props: tuple - :param relationship: Optional, relationship to get/create on when new entity is created. - :param lazy: False by default, specify True to get nodes with id only without the parameters. - :rtype: list - """ - lazy = kwargs.get("lazy", False) - relationship = kwargs.get("relationship") - - # build merge query - get_or_create_params = [ - {"create": cls.deflate(p, skip_empty=True)} for p in props - ] - query, params = cls._build_merge_query( - get_or_create_params, relationship=relationship, lazy=lazy - ) - - if "streaming" in kwargs: - warnings.warn( - STREAMING_WARNING, - category=DeprecationWarning, - stacklevel=1, - ) - - # fetch and build instance for each result - results = db.cypher_query(query, params) - return [cls.inflate(r[0]) for r in results[0]] - - @classmethod - def inflate(cls, node): - """ - Inflate a raw neo4j_driver node to a neomodel node - :param node: - :return: node object - """ - # support lazy loading - if isinstance(node, str) or isinstance(node, int): - snode = cls() - snode.element_id_property = node - else: - node_properties = _get_node_properties(node) - props = {} - for key, prop in cls.__all_properties__: - # map property name from database to object property - db_property = prop.db_property or key - - if db_property in node_properties: - props[key] = prop.inflate(node_properties[db_property], node) - elif prop.has_default: - props[key] = prop.default_value() - else: - props[key] = None - - snode = cls(**props) - snode.element_id_property = node.element_id - - return snode - - @classmethod - def inherited_labels(cls): - """ - Return list of labels from nodes class hierarchy. - - :return: list - """ - return [ - scls.__label__ - for scls in cls.mro() - if hasattr(scls, "__label__") and not hasattr(scls, "__abstract_node__") - ] - - @classmethod - def inherited_optional_labels(cls): - """ - Return list of optional labels from nodes class hierarchy. - - :return: list - :rtype: list - """ - return [ - label - for scls in cls.mro() - for label in getattr(scls, "__optional_labels__", []) - if not hasattr(scls, "__abstract_node__") - ] - - def labels(self): - """ - Returns list of labels tied to the node from neo4j. - - :return: list of labels - :rtype: list - """ - self._pre_action_check("labels") - return self.cypher( - f"MATCH (n) WHERE {db.get_id_method()}(n)=$self " "RETURN labels(n)" - )[0][0][0] - - def _pre_action_check(self, action): - if hasattr(self, "deleted") and self.deleted: - raise ValueError( - f"{self.__class__.__name__}.{action}() attempted on deleted node" - ) - if not hasattr(self, "element_id"): - raise ValueError( - f"{self.__class__.__name__}.{action}() attempted on unsaved node" - ) - - def refresh(self): - """ - Reload the node from neo4j - """ - self._pre_action_check("refresh") - if hasattr(self, "element_id"): - request = self.cypher( - f"MATCH (n) WHERE {db.get_id_method()}(n)=$self RETURN n" - )[0] - if not request or not request[0]: - raise self.__class__.DoesNotExist("Can't refresh non existent node") - node = self.inflate(request[0][0]) - for key, val in node.__properties__.items(): - setattr(self, key, val) - else: - raise ValueError("Can't refresh unsaved node") - - @hooks - def save(self): - """ - Save the node to neo4j or raise an exception - - :return: the node instance - """ - - # create or update instance node - if hasattr(self, "element_id_property"): - # update - params = self.deflate(self.__properties__, self) - query = f"MATCH (n) WHERE {db.get_id_method()}(n)=$self\n" - - if params: - query += "SET " - query += ",\n".join([f"n.{key} = ${key}" for key in params]) - query += "\n" - if self.inherited_labels(): - query += "\n".join( - [f"SET n:`{label}`" for label in self.inherited_labels()] - ) - self.cypher(query, params) - elif hasattr(self, "deleted") and self.deleted: - raise ValueError( - f"{self.__class__.__name__}.save() attempted on deleted node" - ) - else: # create - created_node = self.create(self.__properties__)[0] - self.element_id_property = created_node.element_id - return self diff --git a/neomodel/integration/numpy.py b/neomodel/integration/numpy.py index a04508c4..14bae3df 100644 --- a/neomodel/integration/numpy.py +++ b/neomodel/integration/numpy.py @@ -7,7 +7,7 @@ Example: - >>> from neomodel import db + >>> from neomodel.async_ import db >>> from neomodel.integration.numpy import to_nparray >>> db.set_connection('bolt://neo4j:secret@localhost:7687') >>> df = to_nparray(db.cypher_query("MATCH (u:User) RETURN u.email AS email, u.name AS name")) diff --git a/neomodel/integration/pandas.py b/neomodel/integration/pandas.py index 1ad19871..2f809ade 100644 --- a/neomodel/integration/pandas.py +++ b/neomodel/integration/pandas.py @@ -7,7 +7,7 @@ Example: - >>> from neomodel import db + >>> from neomodel.async_ import db >>> from neomodel.integration.pandas import to_dataframe >>> db.set_connection('bolt://neo4j:secret@localhost:7687') >>> df = to_dataframe(db.cypher_query("MATCH (u:User) RETURN u.email AS email, u.name AS name")) diff --git a/neomodel/match_q.py b/neomodel/match_q.py index 7e76a23b..4e45588c 100644 --- a/neomodel/match_q.py +++ b/neomodel/match_q.py @@ -69,7 +69,11 @@ def _new_instance(cls, children=None, connector=None, negated=False): return obj def __str__(self): - return f"(NOT ({self.connector}: {', '.join(str(c) for c in self.children)}))" if self.negated else f"({self.connector}: {', '.join(str(c) for c in self.children)})" + return ( + f"(NOT ({self.connector}: {', '.join(str(c) for c in self.children)}))" + if self.negated + else f"({self.connector}: {', '.join(str(c) for c in self.children)})" + ) def __repr__(self): return f"<{self.__class__.__name__}: {self}>" diff --git a/neomodel/properties.py b/neomodel/properties.py index e28b4ead..cac92dfe 100644 --- a/neomodel/properties.py +++ b/neomodel/properties.py @@ -2,7 +2,6 @@ import json import re import sys -import types import uuid from datetime import date, datetime @@ -10,117 +9,12 @@ import pytz from neomodel import config -from neomodel.exceptions import DeflateError, InflateError, RequiredProperty +from neomodel.exceptions import DeflateError, InflateError if sys.version_info >= (3, 0): Unicode = str -def display_for(key): - def display_choice(self): - return getattr(self.__class__, key).choices[getattr(self, key)] - - return display_choice - - -class PropertyManager: - """ - Common methods for handling properties on node and relationship objects. - """ - - def __init__(self, **kwargs): - properties = getattr(self, "__all_properties__", None) - if properties is None: - properties = self.defined_properties(rels=False, aliases=False).items() - for name, property in properties: - if kwargs.get(name) is None: - if getattr(property, "has_default", False): - setattr(self, name, property.default_value()) - else: - setattr(self, name, None) - else: - setattr(self, name, kwargs[name]) - - if getattr(property, "choices", None): - setattr( - self, - f"get_{name}_display", - types.MethodType(display_for(name), self), - ) - - if name in kwargs: - del kwargs[name] - - aliases = getattr(self, "__all_aliases__", None) - if aliases is None: - aliases = self.defined_properties( - aliases=True, rels=False, properties=False - ).items() - for name, property in aliases: - if name in kwargs: - setattr(self, name, kwargs[name]) - del kwargs[name] - - # undefined properties (for magic @prop.setters etc) - for name, property in kwargs.items(): - setattr(self, name, property) - - @property - def __properties__(self): - from .relationship_manager import RelationshipManager - - return dict( - (name, value) - for name, value in vars(self).items() - if not name.startswith("_") - and not callable(value) - and not isinstance( - value, - ( - RelationshipManager, - AliasProperty, - ), - ) - ) - - @classmethod - def deflate(cls, properties, obj=None, skip_empty=False): - # deflate dict ready to be stored - deflated = {} - for name, property in cls.defined_properties(aliases=False, rels=False).items(): - db_property = property.db_property or name - if properties.get(name) is not None: - deflated[db_property] = property.deflate(properties[name], obj) - elif property.has_default: - deflated[db_property] = property.deflate(property.default_value(), obj) - elif property.required: - raise RequiredProperty(name, cls) - elif not skip_empty: - deflated[db_property] = None - return deflated - - @classmethod - def defined_properties(cls, aliases=True, properties=True, rels=True): - from .relationship_manager import RelationshipDefinition - - props = {} - for baseclass in reversed(cls.__mro__): - props.update( - dict( - (name, property) - for name, property in vars(baseclass).items() - if (aliases and isinstance(property, AliasProperty)) - or ( - properties - and isinstance(property, Property) - and not isinstance(property, AliasProperty) - ) - or (rels and isinstance(property, RelationshipDefinition)) - ) - ) - return props - - def validator(fn): fn_name = fn.func_name if hasattr(fn, "func_name") else fn.__name__ if fn_name == "inflate": @@ -467,7 +361,7 @@ def deflate(self, value): class DateTimeFormatProperty(Property): """ - Store a datetime by custome format + Store a datetime by custom format :param default_now: If ``True``, the creation time (Local) will be used as default. Defaults to ``False``. :param format: Date format string, default is %Y-%m-%d diff --git a/neomodel/scripts/neomodel_inspect_database.py b/neomodel/scripts/neomodel_inspect_database.py index a8cf4c0e..3147ebdf 100644 --- a/neomodel/scripts/neomodel_inspect_database.py +++ b/neomodel/scripts/neomodel_inspect_database.py @@ -1,7 +1,7 @@ """ .. _neomodel_inspect_database: -``_neomodel_inspect_database`` +``neomodel_inspect_database`` --------------------------- :: @@ -17,6 +17,8 @@ If a file is specified, the tool will write the class definitions to that file. If no file is specified, the tool will print the class definitions to stdout. + + Note : this script only has a synchronous mode. options: -h, --help show this help message and exit @@ -33,7 +35,7 @@ import textwrap from os import environ -from neomodel import db +from neomodel.sync_.core import db IMPORTS = [] diff --git a/neomodel/scripts/neomodel_install_labels.py b/neomodel/scripts/neomodel_install_labels.py index 8bd5119f..8aa7a73b 100755 --- a/neomodel/scripts/neomodel_install_labels.py +++ b/neomodel/scripts/neomodel_install_labels.py @@ -14,6 +14,8 @@ If a connection URL is not specified, the tool will look up the environment variable NEO4J_BOLT_URL. If that environment variable is not set, the tool will attempt to connect to the default URL bolt://neo4j:neo4j@localhost:7687 + + Note : this script only has a synchronous mode. positional arguments: @@ -32,7 +34,7 @@ from importlib import import_module from os import environ, path -from .. import db, install_all_labels +from neomodel.sync_.core import db def load_python_module_or_file(name): @@ -111,7 +113,7 @@ def main(): print(f"Connecting to {bolt_url}") db.set_connection(url=bolt_url) - install_all_labels() + db.install_all_labels() if __name__ == "__main__": diff --git a/neomodel/scripts/neomodel_remove_labels.py b/neomodel/scripts/neomodel_remove_labels.py index 1ad6cc34..79e79390 100755 --- a/neomodel/scripts/neomodel_remove_labels.py +++ b/neomodel/scripts/neomodel_remove_labels.py @@ -14,6 +14,8 @@ If a connection URL is not specified, the tool will look up the environment variable NEO4J_BOLT_URL. If that environment variable is not set, the tool will attempt to connect to the default URL bolt://neo4j:neo4j@localhost:7687 + + Note : this script only has a synchronous mode. options: -h, --help show this help message and exit @@ -27,7 +29,7 @@ from argparse import ArgumentParser, RawDescriptionHelpFormatter from os import environ -from .. import db, remove_all_labels +from neomodel.sync_.core import db def main(): @@ -63,7 +65,7 @@ def main(): print(f"Connecting to {bolt_url}") db.set_connection(url=bolt_url) - remove_all_labels() + db.remove_all_labels() if __name__ == "__main__": diff --git a/neomodel/sync_/__init__.py b/neomodel/sync_/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/neomodel/cardinality.py b/neomodel/sync_/cardinality.py similarity index 95% rename from neomodel/cardinality.py rename to neomodel/sync_/cardinality.py index 099bf578..716d173f 100644 --- a/neomodel/cardinality.py +++ b/neomodel/sync_/cardinality.py @@ -1,5 +1,5 @@ from neomodel.exceptions import AttemptedCardinalityViolation, CardinalityViolation -from neomodel.relationship_manager import ( # pylint:disable=unused-import +from neomodel.sync_.relationship_manager import ( # pylint:disable=unused-import RelationshipManager, ZeroOrMore, ) @@ -37,7 +37,7 @@ def connect(self, node, properties=None): :type: dict :return: True / rel instance """ - if len(self): + if super().__len__(): raise AttemptedCardinalityViolation( f"Node already has {self} can't connect more" ) @@ -130,6 +130,6 @@ def connect(self, node, properties=None): """ if not hasattr(self.source, "element_id") or self.source.element_id is None: raise ValueError("Node has not been saved cannot connect!") - if len(self): + if super().__len__(): raise AttemptedCardinalityViolation("Node already has one relationship") return super().connect(node, properties) diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py new file mode 100644 index 00000000..8778adb0 --- /dev/null +++ b/neomodel/sync_/core.py @@ -0,0 +1,1519 @@ +import logging +import os +import sys +import time +import warnings +from asyncio import iscoroutinefunction +from itertools import combinations +from threading import local +from typing import Optional, Sequence +from urllib.parse import quote, unquote, urlparse + +from neo4j import ( + DEFAULT_DATABASE, + Driver, + GraphDatabase, + Result, + Session, + Transaction, + basic_auth, +) +from neo4j.api import Bookmarks +from neo4j.exceptions import ClientError, ServiceUnavailable, SessionExpired +from neo4j.graph import Node, Path, Relationship + +from neomodel import config +from neomodel._async_compat.util import Util +from neomodel.exceptions import ( + ConstraintValidationFailed, + DoesNotExist, + FeatureNotSupported, + NodeClassAlreadyDefined, + NodeClassNotDefined, + RelationshipClassNotDefined, + UniqueProperty, +) +from neomodel.hooks import hooks +from neomodel.properties import Property +from neomodel.sync_.property_manager import PropertyManager +from neomodel.util import ( + _get_node_properties, + _UnsavedNode, + classproperty, + deprecated, + version_tag_to_integer, +) + +logger = logging.getLogger(__name__) + +RULE_ALREADY_EXISTS = "Neo.ClientError.Schema.EquivalentSchemaRuleAlreadyExists" +INDEX_ALREADY_EXISTS = "Neo.ClientError.Schema.IndexAlreadyExists" +CONSTRAINT_ALREADY_EXISTS = "Neo.ClientError.Schema.ConstraintAlreadyExists" +STREAMING_WARNING = "streaming is not supported by bolt, please remove the kwarg" +NOT_COROUTINE_ERROR = "The decorated function must be a coroutine" + + +# make sure the connection url has been set prior to executing the wrapped function +def ensure_connection(func): + """Decorator that ensures a connection is established before executing the decorated function. + + Args: + func (callable): The function to be decorated. + + Returns: + callable: The decorated function. + + """ + + def wrapper(self, *args, **kwargs): + # Sort out where to find url + if hasattr(self, "db"): + _db = self.db + else: + _db = self + + if not _db.driver: + if hasattr(config, "DATABASE_URL") and config.DATABASE_URL: + _db.set_connection(url=config.DATABASE_URL) + elif hasattr(config, "DRIVER") and config.DRIVER: + _db.set_connection(driver=config.DRIVER) + + return func(self, *args, **kwargs) + + return wrapper + + +class Database(local): + """ + A singleton object via which all operations from neomodel to the Neo4j backend are handled with. + """ + + _NODE_CLASS_REGISTRY = {} + + def __init__(self): + self._active_transaction = None + self.url = None + self.driver = None + self._session = None + self._pid = None + self._database_name = DEFAULT_DATABASE + self.protocol_version = None + self._database_version = None + self._database_edition = None + self.impersonated_user = None + + def set_connection(self, url: str = None, driver: Driver = None): + """ + Sets the connection up and relevant internal. This can be done using a Neo4j URL or a driver instance. + + Args: + url (str): Optionally, Neo4j URL in the form protocol://username:password@hostname:port/dbname. + When provided, a Neo4j driver instance will be created by neomodel. + + driver (neo4j.Driver): Optionally, a pre-created driver instance. + When provided, neomodel will not create a driver instance but use this one instead. + """ + if driver: + self.driver = driver + if hasattr(config, "DATABASE_NAME") and config.DATABASE_NAME: + self._database_name = config.DATABASE_NAME + elif url: + self._parse_driver_from_url(url=url) + + self._pid = os.getpid() + self._active_transaction = None + # Set to default database if it hasn't been set before + if self._database_name is None: + self._database_name = DEFAULT_DATABASE + + # Getting the information about the database version requires a connection to the database + self._database_version = None + self._database_edition = None + self._update_database_version() + + def _parse_driver_from_url(self, url: str) -> None: + """Parse the driver information from the given URL and initialize the driver. + + Args: + url (str): The URL to parse. + + Raises: + ValueError: If the URL format is not as expected. + + Returns: + None - Sets the driver and database_name as class properties + """ + p_start = url.replace(":", "", 1).find(":") + 2 + p_end = url.rfind("@") + password = url[p_start:p_end] + url = url.replace(password, quote(password)) + parsed_url = urlparse(url) + + valid_schemas = [ + "bolt", + "bolt+s", + "bolt+ssc", + "bolt+routing", + "neo4j", + "neo4j+s", + "neo4j+ssc", + ] + + if parsed_url.netloc.find("@") > -1 and parsed_url.scheme in valid_schemas: + credentials, hostname = parsed_url.netloc.rsplit("@", 1) + username, password = credentials.split(":") + password = unquote(password) + database_name = parsed_url.path.strip("/") + else: + raise ValueError( + f"Expecting url format: bolt://user:password@localhost:7687 got {url}" + ) + + options = { + "auth": basic_auth(username, password), + "connection_acquisition_timeout": config.CONNECTION_ACQUISITION_TIMEOUT, + "connection_timeout": config.CONNECTION_TIMEOUT, + "keep_alive": config.KEEP_ALIVE, + "max_connection_lifetime": config.MAX_CONNECTION_LIFETIME, + "max_connection_pool_size": config.MAX_CONNECTION_POOL_SIZE, + "max_transaction_retry_time": config.MAX_TRANSACTION_RETRY_TIME, + "resolver": config.RESOLVER, + "user_agent": config.USER_AGENT, + } + + if "+s" not in parsed_url.scheme: + options["encrypted"] = config.ENCRYPTED + options["trusted_certificates"] = config.TRUSTED_CERTIFICATES + + self.driver = GraphDatabase.driver( + parsed_url.scheme + "://" + hostname, **options + ) + self.url = url + # The database name can be provided through the url or the config + if database_name == "": + if hasattr(config, "DATABASE_NAME") and config.DATABASE_NAME: + self._database_name = config.DATABASE_NAME + else: + self._database_name = database_name + + def close_connection(self): + """ + Closes the currently open driver. + The driver should always be closed at the end of the application's lifecyle. + """ + self._database_version = None + self._database_edition = None + self._database_name = None + self.driver.close() + self.driver = None + + @property + def database_version(self): + if self._database_version is None: + self._update_database_version() + + return self._database_version + + @property + def database_edition(self): + if self._database_edition is None: + self._update_database_version() + + return self._database_edition + + @property + def transaction(self): + """ + Returns the current transaction object + """ + return TransactionProxy(self) + + @property + def write_transaction(self): + return TransactionProxy(self, access_mode="WRITE") + + @property + def read_transaction(self): + return TransactionProxy(self, access_mode="READ") + + def impersonate(self, user: str) -> "ImpersonationHandler": + """All queries executed within this context manager will be executed as impersonated user + + Args: + user (str): User to impersonate + + Returns: + ImpersonationHandler: Context manager to set/unset the user to impersonate + """ + db_edition = self.database_edition + if db_edition != "enterprise": + raise FeatureNotSupported( + "Impersonation is only available in Neo4j Enterprise edition" + ) + return ImpersonationHandler(self, impersonated_user=user) + + @ensure_connection + def begin(self, access_mode=None, **parameters): + """ + Begins a new transaction. Raises SystemError if a transaction is already active. + """ + if ( + hasattr(self, "_active_transaction") + and self._active_transaction is not None + ): + raise SystemError("Transaction in progress") + self._session: Session = self.driver.session( + default_access_mode=access_mode, + database=self._database_name, + impersonated_user=self.impersonated_user, + **parameters, + ) + self._active_transaction: Transaction = self._session.begin_transaction() + + @ensure_connection + def commit(self): + """ + Commits the current transaction and closes its session + + :return: last_bookmarks + """ + try: + self._active_transaction.commit() + last_bookmarks: Bookmarks = self._session.last_bookmarks() + finally: + # In case when something went wrong during + # committing changes to the database + # we have to close an active transaction and session. + self._active_transaction.close() + self._session.close() + self._active_transaction = None + self._session = None + + return last_bookmarks + + @ensure_connection + def rollback(self): + """ + Rolls back the current transaction and closes its session + """ + try: + self._active_transaction.rollback() + finally: + # In case when something went wrong during changes rollback, + # we have to close an active transaction and session + self._active_transaction.close() + self._session.close() + self._active_transaction = None + self._session = None + + def _update_database_version(self): + """ + Updates the database server information when it is required + """ + try: + results = self.cypher_query( + "CALL dbms.components() yield versions, edition return versions[0], edition" + ) + self._database_version = results[0][0][0] + self._database_edition = results[0][0][1] + except ServiceUnavailable: + # The database server is not running yet + pass + + def _object_resolution(self, object_to_resolve): + """ + Performs in place automatic object resolution on a result + returned by cypher_query. + + The function operates recursively in order to be able to resolve Nodes + within nested list structures and Path objects. Not meant to be called + directly, used primarily by _result_resolution. + + :param object_to_resolve: A result as returned by cypher_query. + :type Any: + + :return: An instantiated object. + """ + # Below is the original comment that came with the code extracted in + # this method. It is not very clear but I decided to keep it just in + # case + # + # + # For some reason, while the type of `a_result_attribute[1]` + # as reported by the neo4j driver is `Node` for Node-type data + # retrieved from the database. + # When the retrieved data are Relationship-Type, + # the returned type is `abc.[REL_LABEL]` which is however + # a descendant of Relationship. + # Consequently, the type checking was changed for both + # Node, Relationship objects + if isinstance(object_to_resolve, Node): + return self._NODE_CLASS_REGISTRY[ + frozenset(object_to_resolve.labels) + ].inflate(object_to_resolve) + + if isinstance(object_to_resolve, Relationship): + rel_type = frozenset([object_to_resolve.type]) + return self._NODE_CLASS_REGISTRY[rel_type].inflate(object_to_resolve) + + if isinstance(object_to_resolve, Path): + from neomodel.sync_.path import NeomodelPath + + return NeomodelPath(object_to_resolve) + + if isinstance(object_to_resolve, list): + return self._result_resolution([object_to_resolve]) + + return object_to_resolve + + def _result_resolution(self, result_list): + """ + Performs in place automatic object resolution on a set of results + returned by cypher_query. + + The function operates recursively in order to be able to resolve Nodes + within nested list structures. Not meant to be called directly, + used primarily by cypher_query. + + :param result_list: A list of results as returned by cypher_query. + :type list: + + :return: A list of instantiated objects. + """ + + # Object resolution occurs in-place + for a_result_item in enumerate(result_list): + for a_result_attribute in enumerate(a_result_item[1]): + try: + # Primitive types should remain primitive types, + # Nodes to be resolved to native objects + resolved_object = a_result_attribute[1] + + resolved_object = self._object_resolution(resolved_object) + + result_list[a_result_item[0]][ + a_result_attribute[0] + ] = resolved_object + + except KeyError as exc: + # Not being able to match the label set of a node with a known object results + # in a KeyError in the internal dictionary used for resolution. If it is impossible + # to match, then raise an exception with more details about the error. + if isinstance(a_result_attribute[1], Node): + raise NodeClassNotDefined( + a_result_attribute[1], self._NODE_CLASS_REGISTRY + ) from exc + + if isinstance(a_result_attribute[1], Relationship): + raise RelationshipClassNotDefined( + a_result_attribute[1], self._NODE_CLASS_REGISTRY + ) from exc + + return result_list + + @ensure_connection + def cypher_query( + self, + query, + params=None, + handle_unique=True, + retry_on_session_expire=False, + resolve_objects=False, + ): + """ + Runs a query on the database and returns a list of results and their headers. + + :param query: A CYPHER query + :type: str + :param params: Dictionary of parameters + :type: dict + :param handle_unique: Whether or not to raise UniqueProperty exception on Cypher's ConstraintValidation errors + :type: bool + :param retry_on_session_expire: Whether or not to attempt the same query again if the transaction has expired. + If you use neomodel with your own driver, you must catch SessionExpired exceptions yourself and retry with a new driver instance. + :type: bool + :param resolve_objects: Whether to attempt to resolve the returned nodes to data model objects automatically + :type: bool + + :return: A tuple containing a list of results and a tuple of headers. + """ + + if self._active_transaction: + # Use current session is a transaction is currently active + results, meta = self._run_cypher_query( + self._active_transaction, + query, + params, + handle_unique, + retry_on_session_expire, + resolve_objects, + ) + else: + # Otherwise create a new session in a with to dispose of it after it has been run + with self.driver.session( + database=self._database_name, impersonated_user=self.impersonated_user + ) as session: + results, meta = self._run_cypher_query( + session, + query, + params, + handle_unique, + retry_on_session_expire, + resolve_objects, + ) + + return results, meta + + def _run_cypher_query( + self, + session: Session, + query, + params, + handle_unique, + retry_on_session_expire, + resolve_objects, + ): + try: + # Retrieve the data + start = time.time() + response: Result = session.run(query, params) + results, meta = [list(r.values()) for r in response], response.keys() + end = time.time() + + if resolve_objects: + # Do any automatic resolution required + results = self._result_resolution(results) + + except ClientError as e: + if e.code == "Neo.ClientError.Schema.ConstraintValidationFailed": + if "already exists with label" in e.message and handle_unique: + raise UniqueProperty(e.message) from e + + raise ConstraintValidationFailed(e.message) from e + exc_info = sys.exc_info() + raise exc_info[1].with_traceback(exc_info[2]) + except SessionExpired: + if retry_on_session_expire: + self.set_connection(url=self.url) + return self.cypher_query( + query=query, + params=params, + handle_unique=handle_unique, + retry_on_session_expire=False, + ) + raise + + tte = end - start + if os.environ.get("NEOMODEL_CYPHER_DEBUG", False) and tte > float( + os.environ.get("NEOMODEL_SLOW_QUERIES", 0) + ): + logger.debug( + "query: " + + query + + "\nparams: " + + repr(params) + + f"\ntook: {tte:.2g}s\n" + ) + + return results, meta + + def get_id_method(self) -> str: + db_version = self.database_version + if db_version.startswith("4"): + return "id" + else: + return "elementId" + + def parse_element_id(self, element_id: str): + db_version = self.database_version + return int(element_id) if db_version.startswith("4") else element_id + + def list_indexes(self, exclude_token_lookup=False) -> Sequence[dict]: + """Returns all indexes existing in the database + + Arguments: + exclude_token_lookup[bool]: Exclude automatically create token lookup indexes + + Returns: + Sequence[dict]: List of dictionaries, each entry being an index definition + """ + indexes, meta_indexes = self.cypher_query("SHOW INDEXES") + indexes_as_dict = [dict(zip(meta_indexes, row)) for row in indexes] + + if exclude_token_lookup: + indexes_as_dict = [ + obj for obj in indexes_as_dict if obj["type"] != "LOOKUP" + ] + + return indexes_as_dict + + def list_constraints(self) -> Sequence[dict]: + """Returns all constraints existing in the database + + Returns: + Sequence[dict]: List of dictionaries, each entry being a constraint definition + """ + constraints, meta_constraints = self.cypher_query("SHOW CONSTRAINTS") + constraints_as_dict = [dict(zip(meta_constraints, row)) for row in constraints] + + return constraints_as_dict + + @ensure_connection + def version_is_higher_than(self, version_tag: str) -> bool: + """Returns true if the database version is higher or equal to a given tag + + Args: + version_tag (str): The version to compare against + + Returns: + bool: True if the database version is higher or equal to the given version + """ + db_version = self.database_version + return version_tag_to_integer(db_version) >= version_tag_to_integer(version_tag) + + @ensure_connection + def edition_is_enterprise(self) -> bool: + """Returns true if the database edition is enterprise + + Returns: + bool: True if the database edition is enterprise + """ + edition = self.database_edition + return edition == "enterprise" + + def change_neo4j_password(self, user, new_password): + self.cypher_query(f"ALTER USER {user} SET PASSWORD '{new_password}'") + + def clear_neo4j_database(self, clear_constraints=False, clear_indexes=False): + self.cypher_query( + """ + MATCH (a) + CALL { WITH a DETACH DELETE a } + IN TRANSACTIONS OF 5000 rows + """ + ) + if clear_constraints: + drop_constraints() + if clear_indexes: + drop_indexes() + + def drop_constraints(self, quiet=True, stdout=None): + """ + Discover and drop all constraints. + + :type: bool + :return: None + """ + if not stdout or stdout is None: + stdout = sys.stdout + + results, meta = self.cypher_query("SHOW CONSTRAINTS") + + results_as_dict = [dict(zip(meta, row)) for row in results] + for constraint in results_as_dict: + self.cypher_query("DROP CONSTRAINT " + constraint["name"]) + if not quiet: + stdout.write( + ( + " - Dropping unique constraint and index" + f" on label {constraint['labelsOrTypes'][0]}" + f" with property {constraint['properties'][0]}.\n" + ) + ) + if not quiet: + stdout.write("\n") + + def drop_indexes(self, quiet=True, stdout=None): + """ + Discover and drop all indexes, except the automatically created token lookup indexes. + + :type: bool + :return: None + """ + if not stdout or stdout is None: + stdout = sys.stdout + + indexes = self.list_indexes(exclude_token_lookup=True) + for index in indexes: + self.cypher_query("DROP INDEX " + index["name"]) + if not quiet: + stdout.write( + f' - Dropping index on labels {",".join(index["labelsOrTypes"])} with properties {",".join(index["properties"])}.\n' + ) + if not quiet: + stdout.write("\n") + + def remove_all_labels(self, stdout=None): + """ + Calls functions for dropping constraints and indexes. + + :param stdout: output stream + :return: None + """ + + if not stdout: + stdout = sys.stdout + + stdout.write("Dropping constraints...\n") + self.drop_constraints(quiet=False, stdout=stdout) + + stdout.write("Dropping indexes...\n") + self.drop_indexes(quiet=False, stdout=stdout) + + def install_all_labels(self, stdout=None): + """ + Discover all subclasses of StructuredNode in your application and execute install_labels on each. + Note: code must be loaded (imported) in order for a class to be discovered. + + :param stdout: output stream + :return: None + """ + + if not stdout or stdout is None: + stdout = sys.stdout + + def subsub(cls): # recursively return all subclasses + subclasses = cls.__subclasses__() + if not subclasses: # base case: no more subclasses + return [] + return subclasses + [g for s in cls.__subclasses__() for g in subsub(s)] + + stdout.write("Setting up indexes and constraints...\n\n") + + i = 0 + for cls in subsub(StructuredNode): + stdout.write(f"Found {cls.__module__}.{cls.__name__}\n") + install_labels(cls, quiet=False, stdout=stdout) + i += 1 + + if i: + stdout.write("\n") + + stdout.write(f"Finished {i} classes.\n") + + def install_labels(self, cls, quiet=True, stdout=None): + """ + Setup labels with indexes and constraints for a given class + + :param cls: StructuredNode class + :type: class + :param quiet: (default true) enable standard output + :param stdout: stdout stream + :type: bool + :return: None + """ + if not stdout or stdout is None: + stdout = sys.stdout + + if not hasattr(cls, "__label__"): + if not quiet: + stdout.write( + f" ! Skipping class {cls.__module__}.{cls.__name__} is abstract\n" + ) + return + + for name, property in cls.defined_properties(aliases=False, rels=False).items(): + self._install_node(cls, name, property, quiet, stdout) + + for _, relationship in cls.defined_properties( + aliases=False, rels=True, properties=False + ).items(): + self._install_relationship(cls, relationship, quiet, stdout) + + def _create_node_index(self, label: str, property_name: str, stdout): + try: + self.cypher_query( + f"CREATE INDEX index_{label}_{property_name} FOR (n:{label}) ON (n.{property_name}); " + ) + except ClientError as e: + if e.code in ( + RULE_ALREADY_EXISTS, + INDEX_ALREADY_EXISTS, + ): + stdout.write(f"{str(e)}\n") + else: + raise + + def _create_node_constraint(self, label: str, property_name: str, stdout): + try: + self.cypher_query( + f"""CREATE CONSTRAINT constraint_unique_{label}_{property_name} + FOR (n:{label}) REQUIRE n.{property_name} IS UNIQUE""" + ) + except ClientError as e: + if e.code in ( + RULE_ALREADY_EXISTS, + CONSTRAINT_ALREADY_EXISTS, + ): + stdout.write(f"{str(e)}\n") + else: + raise + + def _create_relationship_index( + self, relationship_type: str, property_name: str, stdout + ): + try: + self.cypher_query( + f"CREATE INDEX index_{relationship_type}_{property_name} FOR ()-[r:{relationship_type}]-() ON (r.{property_name}); " + ) + except ClientError as e: + if e.code in ( + RULE_ALREADY_EXISTS, + INDEX_ALREADY_EXISTS, + ): + stdout.write(f"{str(e)}\n") + else: + raise + + def _create_relationship_constraint( + self, relationship_type: str, property_name: str, stdout + ): + if self.version_is_higher_than("5.7"): + try: + self.cypher_query( + f"""CREATE CONSTRAINT constraint_unique_{relationship_type}_{property_name} + FOR ()-[r:{relationship_type}]-() REQUIRE r.{property_name} IS UNIQUE""" + ) + except ClientError as e: + if e.code in ( + RULE_ALREADY_EXISTS, + CONSTRAINT_ALREADY_EXISTS, + ): + stdout.write(f"{str(e)}\n") + else: + raise + else: + raise FeatureNotSupported( + f"Unique indexes on relationships are not supported in Neo4j version {self.database_version}. Please upgrade to Neo4j 5.7 or higher." + ) + + def _install_node(self, cls, name, property, quiet, stdout): + # Create indexes and constraints for node property + db_property = property.db_property or name + if property.index: + if not quiet: + stdout.write( + f" + Creating node index {name} on label {cls.__label__} for class {cls.__module__}.{cls.__name__}\n" + ) + self._create_node_index( + label=cls.__label__, property_name=db_property, stdout=stdout + ) + + elif property.unique_index: + if not quiet: + stdout.write( + f" + Creating node unique constraint for {name} on label {cls.__label__} for class {cls.__module__}.{cls.__name__}\n" + ) + self._create_node_constraint( + label=cls.__label__, property_name=db_property, stdout=stdout + ) + + def _install_relationship(self, cls, relationship, quiet, stdout): + # Create indexes and constraints for relationship property + relationship_cls = relationship.definition["model"] + if relationship_cls is not None: + relationship_type = relationship.definition["relation_type"] + for prop_name, property in relationship_cls.defined_properties( + aliases=False, rels=False + ).items(): + db_property = property.db_property or prop_name + if property.index: + if not quiet: + stdout.write( + f" + Creating relationship index {prop_name} on relationship type {relationship_type} for relationship model {cls.__module__}.{relationship_cls.__name__}\n" + ) + self._create_relationship_index( + relationship_type=relationship_type, + property_name=db_property, + stdout=stdout, + ) + elif property.unique_index: + if not quiet: + stdout.write( + f" + Creating relationship unique constraint for {prop_name} on relationship type {relationship_type} for relationship model {cls.__module__}.{relationship_cls.__name__}\n" + ) + self._create_relationship_constraint( + relationship_type=relationship_type, + property_name=db_property, + stdout=stdout, + ) + + +# Create a singleton instance of the database object +db = Database() + + +# Deprecated methods +def change_neo4j_password(db: Database, user, new_password): + deprecated( + """ + This method has been moved to the Database singleton (db for sync, db for async). + Please use db.change_neo4j_password(user, new_password) instead. + This direct call will be removed in an upcoming version. + """ + ) + db.change_neo4j_password(user, new_password) + + +def clear_neo4j_database(db: Database, clear_constraints=False, clear_indexes=False): + deprecated( + """ + This method has been moved to the Database singleton (db for sync, db for async). + Please use db.clear_neo4j_database(clear_constraints, clear_indexes) instead. + This direct call will be removed in an upcoming version. + """ + ) + db.clear_neo4j_database(clear_constraints, clear_indexes) + + +def drop_constraints(quiet=True, stdout=None): + deprecated( + """ + This method has been moved to the Database singleton (db for sync, db for async). + Please use db.drop_constraints(quiet, stdout) instead. + This direct call will be removed in an upcoming version. + """ + ) + db.drop_constraints(quiet, stdout) + + +def drop_indexes(quiet=True, stdout=None): + deprecated( + """ + This method has been moved to the Database singleton (db for sync, db for async). + Please use db.drop_indexes(quiet, stdout) instead. + This direct call will be removed in an upcoming version. + """ + ) + db.drop_indexes(quiet, stdout) + + +def remove_all_labels(stdout=None): + deprecated( + """ + This method has been moved to the Database singleton (db for sync, db for async). + Please use db.remove_all_labels(stdout) instead. + This direct call will be removed in an upcoming version. + """ + ) + db.remove_all_labels(stdout) + + +def install_labels(cls, quiet=True, stdout=None): + deprecated( + """ + This method has been moved to the Database singleton (db for sync, db for async). + Please use db.install_labels(cls, quiet, stdout) instead. + This direct call will be removed in an upcoming version. + """ + ) + db.install_labels(cls, quiet, stdout) + + +def install_all_labels(stdout=None): + deprecated( + """ + This method has been moved to the Database singleton (db for sync, db for async). + Please use db.install_all_labels(stdout) instead. + This direct call will be removed in an upcoming version. + """ + ) + db.install_all_labels(stdout) + + +class TransactionProxy: + bookmarks: Optional[Bookmarks] = None + + def __init__(self, db: Database, access_mode=None): + self.db = db + self.access_mode = access_mode + + @ensure_connection + def __enter__(self): + print("aenter called") + self.db.begin(access_mode=self.access_mode, bookmarks=self.bookmarks) + self.bookmarks = None + return self + + def __exit__(self, exc_type, exc_value, traceback): + print("aexit called") + if exc_value: + self.db.rollback() + + if ( + exc_type is ClientError + and exc_value.code == "Neo.ClientError.Schema.ConstraintValidationFailed" + ): + raise UniqueProperty(exc_value.message) + + if not exc_value: + self.last_bookmark = self.db.commit() + + def __call__(self, func): + if Util.is_async_code and not iscoroutinefunction(func): + raise TypeError(NOT_COROUTINE_ERROR) + + def wrapper(*args, **kwargs): + with self: + print("call called") + return func(*args, **kwargs) + + return wrapper + + @property + def with_bookmark(self): + return BookmarkingAsyncTransactionProxy(self.db, self.access_mode) + + +class BookmarkingAsyncTransactionProxy(TransactionProxy): + def __call__(self, func): + if Util.is_async_code and not iscoroutinefunction(func): + raise TypeError(NOT_COROUTINE_ERROR) + + def wrapper(*args, **kwargs): + self.bookmarks = kwargs.pop("bookmarks", None) + + with self: + result = func(*args, **kwargs) + self.last_bookmark = None + + return result, self.last_bookmark + + return wrapper + + +class ImpersonationHandler: + def __init__(self, db: Database, impersonated_user: str): + self.db = db + self.impersonated_user = impersonated_user + + def __enter__(self): + self.db.impersonated_user = self.impersonated_user + return self + + def __exit__(self, exception_type, exception_value, exception_traceback): + self.db.impersonated_user = None + + print("\nException type:", exception_type) + print("\nException value:", exception_value) + print("\nTraceback:", exception_traceback) + + def __call__(self, func): + def wrapper(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return wrapper + + +class NodeMeta(type): + def __new__(mcs, name, bases, namespace): + namespace["DoesNotExist"] = type(name + "DoesNotExist", (DoesNotExist,), {}) + cls = super().__new__(mcs, name, bases, namespace) + cls.DoesNotExist._model_class = cls + + if hasattr(cls, "__abstract_node__"): + delattr(cls, "__abstract_node__") + else: + if "deleted" in namespace: + raise ValueError( + "Property name 'deleted' is not allowed as it conflicts with neomodel internals." + ) + elif "id" in namespace: + raise ValueError( + """ + Property name 'id' is not allowed as it conflicts with neomodel internals. + Consider using 'uid' or 'identifier' as id is also a Neo4j internal. + """ + ) + elif "element_id" in namespace: + raise ValueError( + """ + Property name 'element_id' is not allowed as it conflicts with neomodel internals. + Consider using 'uid' or 'identifier' as element_id is also a Neo4j internal. + """ + ) + for key, value in ( + (x, y) for x, y in namespace.items() if isinstance(y, Property) + ): + value.name, value.owner = key, cls + if hasattr(value, "setup") and callable(value.setup): + value.setup() + + # cache various groups of properies + cls.__required_properties__ = tuple( + name + for name, property in cls.defined_properties( + aliases=False, rels=False + ).items() + if property.required or property.unique_index + ) + cls.__all_properties__ = tuple( + cls.defined_properties(aliases=False, rels=False).items() + ) + cls.__all_aliases__ = tuple( + cls.defined_properties(properties=False, rels=False).items() + ) + cls.__all_relationships__ = tuple( + cls.defined_properties(aliases=False, properties=False).items() + ) + + cls.__label__ = namespace.get("__label__", name) + cls.__optional_labels__ = namespace.get("__optional_labels__", []) + + build_class_registry(cls) + + return cls + + +def build_class_registry(cls): + base_label_set = frozenset(cls.inherited_labels()) + optional_label_set = set(cls.inherited_optional_labels()) + + # Construct all possible combinations of labels + optional labels + possible_label_combinations = [ + frozenset(set(x).union(base_label_set)) + for i in range(1, len(optional_label_set) + 1) + for x in combinations(optional_label_set, i) + ] + possible_label_combinations.append(base_label_set) + + for label_set in possible_label_combinations: + if label_set not in db._NODE_CLASS_REGISTRY: + db._NODE_CLASS_REGISTRY[label_set] = cls + else: + raise NodeClassAlreadyDefined(cls, db._NODE_CLASS_REGISTRY) + + +NodeBase = NodeMeta("NodeBase", (PropertyManager,), {"__abstract_node__": True}) + + +class StructuredNode(NodeBase): + """ + Base class for all node definitions to inherit from. + + If you want to create your own abstract classes set: + __abstract_node__ = True + """ + + # static properties + + __abstract_node__ = True + + # magic methods + + def __init__(self, *args, **kwargs): + if "deleted" in kwargs: + raise ValueError("deleted property is reserved for neomodel") + + for key, val in self.__all_relationships__: + self.__dict__[key] = val.build_manager(self, key) + + super().__init__(*args, **kwargs) + + def __eq__(self, other): + if not isinstance(other, (StructuredNode,)): + return False + if hasattr(self, "element_id") and hasattr(other, "element_id"): + return self.element_id == other.element_id + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def __repr__(self): + return f"<{self.__class__.__name__}: {self}>" + + def __str__(self): + return repr(self.__properties__) + + # dynamic properties + + @classproperty + def nodes(cls): + """ + Returns a NodeSet object representing all nodes of the classes label + :return: NodeSet + :rtype: NodeSet + """ + from neomodel.sync_.match import NodeSet + + return NodeSet(cls) + + @property + def element_id(self): + if hasattr(self, "element_id_property"): + return self.element_id_property + return None + + # Version 4.4 support - id is deprecated in version 5.x + @property + def id(self): + try: + return int(self.element_id_property) + except (TypeError, ValueError): + raise ValueError( + "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()." + ) + + # methods + + @classmethod + def _build_merge_query( + cls, merge_params, update_existing=False, lazy=False, relationship=None + ): + """ + Get a tuple of a CYPHER query and a params dict for the specified MERGE query. + + :param merge_params: The target node match parameters, each node must have a "create" key and optional "update". + :type merge_params: list of dict + :param update_existing: True to update properties of existing nodes, default False to keep existing values. + :type update_existing: bool + :rtype: tuple + """ + query_params = dict(merge_params=merge_params) + n_merge_labels = ":".join(cls.inherited_labels()) + n_merge_prm = ", ".join( + ( + f"{getattr(cls, p).db_property or p}: params.create.{getattr(cls, p).db_property or p}" + for p in cls.__required_properties__ + ) + ) + n_merge = f"n:{n_merge_labels} {{{n_merge_prm}}}" + if relationship is None: + # create "simple" unwind query + query = f"UNWIND $merge_params as params\n MERGE ({n_merge})\n " + else: + # validate relationship + if not isinstance(relationship.source, StructuredNode): + raise ValueError( + f"relationship source [{repr(relationship.source)}] is not a StructuredNode" + ) + relation_type = relationship.definition.get("relation_type") + if not relation_type: + raise ValueError( + "No relation_type is specified on provided relationship" + ) + + from neomodel.sync_.match import _rel_helper + + query_params["source_id"] = db.parse_element_id( + relationship.source.element_id + ) + query = f"MATCH (source:{relationship.source.__label__}) WHERE {db.get_id_method()}(source) = $source_id\n " + query += "WITH source\n UNWIND $merge_params as params \n " + query += "MERGE " + query += _rel_helper( + lhs="source", + rhs=n_merge, + ident=None, + relation_type=relation_type, + direction=relationship.definition["direction"], + ) + + query += "ON CREATE SET n = params.create\n " + # if update_existing, write properties on match as well + if update_existing is True: + query += "ON MATCH SET n += params.update\n" + + # close query + if lazy: + query += f"RETURN {db.get_id_method()}(n)" + else: + query += "RETURN n" + + return query, query_params + + @classmethod + def create(cls, *props, **kwargs): + """ + Call to CREATE with parameters map. A new instance will be created and saved. + + :param props: dict of properties to create the nodes. + :type props: tuple + :param lazy: False by default, specify True to get nodes with id only without the parameters. + :type: bool + :rtype: list + """ + + if "streaming" in kwargs: + warnings.warn( + STREAMING_WARNING, + category=DeprecationWarning, + stacklevel=1, + ) + + lazy = kwargs.get("lazy", False) + # create mapped query + query = f"CREATE (n:{':'.join(cls.inherited_labels())} $create_params)" + + # close query + if lazy: + query += f" RETURN {db.get_id_method()}(n)" + else: + query += " RETURN n" + + results = [] + for item in [ + cls.deflate(p, obj=_UnsavedNode(), skip_empty=True) for p in props + ]: + node, _ = db.cypher_query(query, {"create_params": item}) + results.extend(node[0]) + + nodes = [cls.inflate(node) for node in results] + + if not lazy and hasattr(cls, "post_create"): + for node in nodes: + node.post_create() + + return nodes + + @classmethod + def create_or_update(cls, *props, **kwargs): + """ + Call to MERGE with parameters map. A new instance will be created and saved if does not already exists, + this is an atomic operation. If an instance already exists all optional properties specified will be updated. + + Note that the post_create hook isn't called after create_or_update + + :param props: List of dict arguments to get or create the entities with. + :type props: tuple + :param relationship: Optional, relationship to get/create on when new entity is created. + :param lazy: False by default, specify True to get nodes with id only without the parameters. + :rtype: list + """ + lazy = kwargs.get("lazy", False) + relationship = kwargs.get("relationship") + + # build merge query, make sure to update only explicitly specified properties + create_or_update_params = [] + for specified, deflated in [ + (p, cls.deflate(p, skip_empty=True)) for p in props + ]: + create_or_update_params.append( + { + "create": deflated, + "update": dict( + (k, v) for k, v in deflated.items() if k in specified + ), + } + ) + query, params = cls._build_merge_query( + create_or_update_params, + update_existing=True, + relationship=relationship, + lazy=lazy, + ) + + if "streaming" in kwargs: + warnings.warn( + STREAMING_WARNING, + category=DeprecationWarning, + stacklevel=1, + ) + + # fetch and build instance for each result + results = db.cypher_query(query, params) + return [cls.inflate(r[0]) for r in results[0]] + + def cypher(self, query, params=None): + """ + Execute a cypher query with the param 'self' pre-populated with the nodes neo4j id. + + :param query: cypher query string + :type: string + :param params: query parameters + :type: dict + :return: list containing query results + :rtype: list + """ + self._pre_action_check("cypher") + params = params or {} + element_id = db.parse_element_id(self.element_id) + params.update({"self": element_id}) + return db.cypher_query(query, params) + + @hooks + def delete(self): + """ + Delete a node and its relationships + + :return: True + """ + self._pre_action_check("delete") + self.cypher( + f"MATCH (self) WHERE {db.get_id_method()}(self)=$self DETACH DELETE self" + ) + delattr(self, "element_id_property") + self.deleted = True + return True + + @classmethod + def get_or_create(cls, *props, **kwargs): + """ + Call to MERGE with parameters map. A new instance will be created and saved if does not already exist, + this is an atomic operation. + Parameters must contain all required properties, any non required properties with defaults will be generated. + + Note that the post_create hook isn't called after get_or_create + + :param props: Arguments to get_or_create as tuple of dict with property names and values to get or create + the entities with. + :type props: tuple + :param relationship: Optional, relationship to get/create on when new entity is created. + :param lazy: False by default, specify True to get nodes with id only without the parameters. + :rtype: list + """ + lazy = kwargs.get("lazy", False) + relationship = kwargs.get("relationship") + + # build merge query + get_or_create_params = [ + {"create": cls.deflate(p, skip_empty=True)} for p in props + ] + query, params = cls._build_merge_query( + get_or_create_params, relationship=relationship, lazy=lazy + ) + + if "streaming" in kwargs: + warnings.warn( + STREAMING_WARNING, + category=DeprecationWarning, + stacklevel=1, + ) + + # fetch and build instance for each result + results = db.cypher_query(query, params) + return [cls.inflate(r[0]) for r in results[0]] + + @classmethod + def inflate(cls, node): + """ + Inflate a raw neo4j_driver node to a neomodel node + :param node: + :return: node object + """ + # support lazy loading + if isinstance(node, str) or isinstance(node, int): + snode = cls() + snode.element_id_property = node + else: + node_properties = _get_node_properties(node) + props = {} + for key, prop in cls.__all_properties__: + # map property name from database to object property + db_property = prop.db_property or key + + if db_property in node_properties: + props[key] = prop.inflate(node_properties[db_property], node) + elif prop.has_default: + props[key] = prop.default_value() + else: + props[key] = None + + snode = cls(**props) + snode.element_id_property = node.element_id + + return snode + + @classmethod + def inherited_labels(cls): + """ + Return list of labels from nodes class hierarchy. + + :return: list + """ + return [ + scls.__label__ + for scls in cls.mro() + if hasattr(scls, "__label__") and not hasattr(scls, "__abstract_node__") + ] + + @classmethod + def inherited_optional_labels(cls): + """ + Return list of optional labels from nodes class hierarchy. + + :return: list + :rtype: list + """ + return [ + label + for scls in cls.mro() + for label in getattr(scls, "__optional_labels__", []) + if not hasattr(scls, "__abstract_node__") + ] + + def labels(self): + """ + Returns list of labels tied to the node from neo4j. + + :return: list of labels + :rtype: list + """ + self._pre_action_check("labels") + result = self.cypher( + f"MATCH (n) WHERE {db.get_id_method()}(n)=$self " "RETURN labels(n)" + ) + return result[0][0][0] + + def _pre_action_check(self, action): + if hasattr(self, "deleted") and self.deleted: + raise ValueError( + f"{self.__class__.__name__}.{action}() attempted on deleted node" + ) + if not hasattr(self, "element_id"): + raise ValueError( + f"{self.__class__.__name__}.{action}() attempted on unsaved node" + ) + + def refresh(self): + """ + Reload the node from neo4j + """ + self._pre_action_check("refresh") + if hasattr(self, "element_id"): + results = self.cypher( + f"MATCH (n) WHERE {db.get_id_method()}(n)=$self RETURN n" + ) + request = results[0] + if not request or not request[0]: + raise self.__class__.DoesNotExist("Can't refresh non existent node") + node = self.inflate(request[0][0]) + for key, val in node.__properties__.items(): + setattr(self, key, val) + else: + raise ValueError("Can't refresh unsaved node") + + @hooks + def save(self): + """ + Save the node to neo4j or raise an exception + + :return: the node instance + """ + + # create or update instance node + if hasattr(self, "element_id_property"): + # update + params = self.deflate(self.__properties__, self) + query = f"MATCH (n) WHERE {db.get_id_method()}(n)=$self\n" + + if params: + query += "SET " + query += ",\n".join([f"n.{key} = ${key}" for key in params]) + query += "\n" + if self.inherited_labels(): + query += "\n".join( + [f"SET n:`{label}`" for label in self.inherited_labels()] + ) + self.cypher(query, params) + elif hasattr(self, "deleted") and self.deleted: + raise ValueError( + f"{self.__class__.__name__}.save() attempted on deleted node" + ) + else: # create + result = self.create(self.__properties__) + created_node = result[0] + self.element_id_property = created_node.element_id + return self diff --git a/neomodel/match.py b/neomodel/sync_/match.py similarity index 95% rename from neomodel/match.py rename to neomodel/sync_/match.py index fb47f568..fad8b05f 100644 --- a/neomodel/match.py +++ b/neomodel/sync_/match.py @@ -4,12 +4,11 @@ from dataclasses import dataclass from typing import Optional -from .core import StructuredNode, db -from .exceptions import MultipleNodesReturned -from .match_q import Q, QBase -from .properties import AliasProperty - -OUTGOING, INCOMING, EITHER = 1, -1, 0 +from neomodel.exceptions import MultipleNodesReturned +from neomodel.match_q import Q, QBase +from neomodel.properties import AliasProperty +from neomodel.sync_.core import StructuredNode, db +from neomodel.util import INCOMING, OUTGOING def _rel_helper( @@ -505,7 +504,7 @@ def build_node(self, node): _node_lookup = f"MATCH ({ident}) WHERE {db.get_id_method()}({ident})=${place_holder} WITH {ident}" self._ast.lookup = _node_lookup - self._query_params[place_holder] = node.element_id + self._query_params[place_holder] = db.parse_element_id(node.element_id) self._ast.return_clause = ident self._ast.result_class = node.__class__ @@ -732,24 +731,42 @@ def all(self, lazy=False): :return: list of nodes :rtype: list """ - return self.query_cls(self).build_ast()._execute(lazy) + ast = self.query_cls(self).build_ast() + return ast._execute(lazy) def __iter__(self): - return (i for i in self.query_cls(self).build_ast()._execute()) + ast = self.query_cls(self).build_ast() + for i in ast._execute(): + yield i def __len__(self): - return self.query_cls(self).build_ast()._count() + ast = self.query_cls(self).build_ast() + return ast._count() def __bool__(self): - return self.query_cls(self).build_ast()._count() > 0 + """ + Override for __bool__ dunder method. + :return: True if the set contains any nodes, False otherwise + :rtype: bool + """ + ast = self.query_cls(self).build_ast() + _count = ast._count() + return _count > 0 def __nonzero__(self): - return self.query_cls(self).build_ast()._count() > 0 + """ + Override for __bool__ dunder method. + :return: True if the set contains any node, False otherwise + :rtype: bool + """ + return self.__bool__() def __contains__(self, obj): if isinstance(obj, StructuredNode): if hasattr(obj, "element_id") and obj.element_id is not None: - return self.query_cls(self).build_ast()._contains(obj.element_id) + ast = self.query_cls(self).build_ast() + obj_element_id = db.parse_element_id(obj.element_id) + return ast._contains(obj_element_id) raise ValueError("Unsaved node: " + repr(obj)) raise ValueError("Expecting StructuredNode instance") @@ -770,7 +787,9 @@ def __getitem__(self, key): self.skip = key self.limit = 1 - return self.query_cls(self).build_ast()._execute()[0] + ast = self.query_cls(self).build_ast() + _items = ast._execute() + return _items[0] return None @@ -810,11 +829,15 @@ def __init__(self, source): self.relations_to_fetch: list = [] + def __await__(self): + return self.all().__await__() + def _get(self, limit=None, lazy=False, **kwargs): self.filter(**kwargs) if limit: self.limit = limit - return self.query_cls(self).build_ast()._execute(lazy) + ast = self.query_cls(self).build_ast() + return ast._execute(lazy) def get(self, lazy=False, **kwargs): """ @@ -849,7 +872,7 @@ def first(self, **kwargs): :param kwargs: same syntax as `filter()` :return: node """ - result = result = self._get(limit=1, **kwargs) + result = self._get(limit=1, **kwargs) if result: return result[0] else: @@ -986,6 +1009,9 @@ class Traversal(BaseSet): :type defintion: :class:`dict` """ + def __await__(self): + return self.all().__await__() + def __init__(self, source, name, definition): """ Create a traversal diff --git a/neomodel/path.py b/neomodel/sync_/path.py similarity index 88% rename from neomodel/path.py rename to neomodel/sync_/path.py index 5f063d11..62a49fe7 100644 --- a/neomodel/path.py +++ b/neomodel/sync_/path.py @@ -1,14 +1,14 @@ from neo4j.graph import Path -from .core import db -from .relationship import StructuredRel -from .exceptions import RelationshipClassNotDefined + +from neomodel.sync_.core import db +from neomodel.sync_.relationship import StructuredRel class NeomodelPath(Path): """ Represents paths within neomodel. - This object is instantiated when you include whole paths in your ``cypher_query()`` + This object is instantiated when you include whole paths in your ``cypher_query()`` result sets and turn ``resolve_objects`` to True. That is, any query of the form: @@ -16,7 +16,7 @@ class NeomodelPath(Path): MATCH p=(:SOME_NODE_LABELS)-[:SOME_REL_LABELS]-(:SOME_OTHER_NODE_LABELS) return p - ``NeomodelPath`` are simple objects that reference their nodes and relationships, each of which is already + ``NeomodelPath`` are simple objects that reference their nodes and relationships, each of which is already resolved to their neomodel objects if such mapping is possible. @@ -25,8 +25,9 @@ class NeomodelPath(Path): :type nodes: List[StructuredNode] :type relationships: List[StructuredRel] """ + def __init__(self, a_neopath): - self._nodes=[] + self._nodes = [] self._relationships = [] for a_node in a_neopath.nodes: @@ -42,6 +43,7 @@ def __init__(self, a_neopath): else: new_rel = StructuredRel.inflate(a_relationship) self._relationships.append(new_rel) + @property def nodes(self): return self._nodes @@ -49,5 +51,3 @@ def nodes(self): @property def relationships(self): return self._relationships - - diff --git a/neomodel/sync_/property_manager.py b/neomodel/sync_/property_manager.py new file mode 100644 index 00000000..85452f0b --- /dev/null +++ b/neomodel/sync_/property_manager.py @@ -0,0 +1,109 @@ +import types + +from neomodel.exceptions import RequiredProperty +from neomodel.properties import AliasProperty, Property + + +def display_for(key): + def display_choice(self): + return getattr(self.__class__, key).choices[getattr(self, key)] + + return display_choice + + +class PropertyManager: + """ + Common methods for handling properties on node and relationship objects. + """ + + def __init__(self, **kwargs): + properties = getattr(self, "__all_properties__", None) + if properties is None: + properties = self.defined_properties(rels=False, aliases=False).items() + for name, property in properties: + if kwargs.get(name) is None: + if getattr(property, "has_default", False): + setattr(self, name, property.default_value()) + else: + setattr(self, name, None) + else: + setattr(self, name, kwargs[name]) + + if getattr(property, "choices", None): + setattr( + self, + f"get_{name}_display", + types.MethodType(display_for(name), self), + ) + + if name in kwargs: + del kwargs[name] + + aliases = getattr(self, "__all_aliases__", None) + if aliases is None: + aliases = self.defined_properties( + aliases=True, rels=False, properties=False + ).items() + for name, property in aliases: + if name in kwargs: + setattr(self, name, kwargs[name]) + del kwargs[name] + + # undefined properties (for magic @prop.setters etc) + for name, property in kwargs.items(): + setattr(self, name, property) + + @property + def __properties__(self): + from neomodel.sync_.relationship_manager import RelationshipManager + + return dict( + (name, value) + for name, value in vars(self).items() + if not name.startswith("_") + and not callable(value) + and not isinstance( + value, + ( + RelationshipManager, + AliasProperty, + ), + ) + ) + + @classmethod + def deflate(cls, properties, obj=None, skip_empty=False): + # deflate dict ready to be stored + deflated = {} + for name, property in cls.defined_properties(aliases=False, rels=False).items(): + db_property = property.db_property or name + if properties.get(name) is not None: + deflated[db_property] = property.deflate(properties[name], obj) + elif property.has_default: + deflated[db_property] = property.deflate(property.default_value(), obj) + elif property.required: + raise RequiredProperty(name, cls) + elif not skip_empty: + deflated[db_property] = None + return deflated + + @classmethod + def defined_properties(cls, aliases=True, properties=True, rels=True): + from neomodel.sync_.relationship_manager import RelationshipDefinition + + props = {} + for baseclass in reversed(cls.__mro__): + props.update( + dict( + (name, property) + for name, property in vars(baseclass).items() + if (aliases and isinstance(property, AliasProperty)) + or ( + properties + and isinstance(property, Property) + and not isinstance(property, AliasProperty) + ) + or (rels and isinstance(property, RelationshipDefinition)) + ) + ) + return props diff --git a/neomodel/relationship.py b/neomodel/sync_/relationship.py similarity index 71% rename from neomodel/relationship.py rename to neomodel/sync_/relationship.py index 8df56c47..0a199575 100644 --- a/neomodel/relationship.py +++ b/neomodel/sync_/relationship.py @@ -1,8 +1,9 @@ -import warnings +from neomodel.hooks import hooks +from neomodel.properties import Property +from neomodel.sync_.core import db +from neomodel.sync_.property_manager import PropertyManager -from .core import db -from .hooks import hooks -from .properties import Property, PropertyManager +ELEMENT_ID_MIGRATION_NOTICE = "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()." class RelationshipMeta(type): @@ -50,57 +51,42 @@ def __init__(self, *args, **kwargs): @property def element_id(self): - return ( - int(self.element_id_property) - if db.database_version.startswith("4") - else self.element_id_property - ) + if hasattr(self, "element_id_property"): + return self.element_id_property @property def _start_node_element_id(self): - return ( - int(self._start_node_element_id_property) - if db.database_version.startswith("4") - else self._start_node_element_id_property - ) + if hasattr(self, "_start_node_element_id_property"): + return self._start_node_element_id_property @property def _end_node_element_id(self): - return ( - int(self._end_node_element_id_property) - if db.database_version.startswith("4") - else self._end_node_element_id_property - ) + if hasattr(self, "_end_node_element_id_property"): + return self._end_node_element_id_property # Version 4.4 support - id is deprecated in version 5.x @property def id(self): try: return int(self.element_id_property) - except (TypeError, ValueError): - raise ValueError( - "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()." - ) + except (TypeError, ValueError) as exc: + raise ValueError(ELEMENT_ID_MIGRATION_NOTICE) from exc # Version 4.4 support - id is deprecated in version 5.x @property def _start_node_id(self): try: return int(self._start_node_element_id_property) - except (TypeError, ValueError): - raise ValueError( - "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()." - ) + except (TypeError, ValueError) as exc: + raise ValueError(ELEMENT_ID_MIGRATION_NOTICE) from exc # Version 4.4 support - id is deprecated in version 5.x @property def _end_node_id(self): try: return int(self._end_node_element_id_property) - except (TypeError, ValueError): - raise ValueError( - "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()." - ) + except (TypeError, ValueError) as exc: + raise ValueError(ELEMENT_ID_MIGRATION_NOTICE) from exc @hooks def save(self): @@ -112,7 +98,7 @@ def save(self): props = self.deflate(self.__properties__) query = f"MATCH ()-[r]->() WHERE {db.get_id_method()}(r)=$self " query += "".join([f" SET r.{key} = ${key}" for key in props]) - props["self"] = self.element_id + props["self"] = db.parse_element_id(self.element_id) db.cypher_query(query, props) @@ -124,16 +110,16 @@ def start_node(self): :return: StructuredNode """ - test = db.cypher_query( + results = db.cypher_query( f""" MATCH (aNode) WHERE {db.get_id_method()}(aNode)=$start_node_element_id RETURN aNode """, - {"start_node_element_id": self._start_node_element_id}, + {"start_node_element_id": db.parse_element_id(self._start_node_element_id)}, resolve_objects=True, ) - return test[0][0][0] + return results[0][0][0] def end_node(self): """ @@ -141,15 +127,16 @@ def end_node(self): :return: StructuredNode """ - return db.cypher_query( + results = db.cypher_query( f""" MATCH (aNode) WHERE {db.get_id_method()}(aNode)=$end_node_element_id RETURN aNode """, - {"end_node_element_id": self._end_node_element_id}, + {"end_node_element_id": db.parse_element_id(self._end_node_element_id)}, resolve_objects=True, - )[0][0][0] + ) + return results[0][0][0] @classmethod def inflate(cls, rel): diff --git a/neomodel/relationship_manager.py b/neomodel/sync_/relationship_manager.py similarity index 93% rename from neomodel/relationship_manager.py rename to neomodel/sync_/relationship_manager.py index 1e9cf79e..8b8ae96a 100644 --- a/neomodel/relationship_manager.py +++ b/neomodel/sync_/relationship_manager.py @@ -3,19 +3,17 @@ import sys from importlib import import_module -from .core import db -from .exceptions import NotConnected, RelationshipClassRedefined -from .match import ( +from neomodel.exceptions import NotConnected, RelationshipClassRedefined +from neomodel.sync_.core import db +from neomodel.sync_.match import NodeSet, Traversal, _rel_helper, _rel_merge_helper +from neomodel.sync_.relationship import StructuredRel +from neomodel.util import ( EITHER, INCOMING, OUTGOING, - NodeSet, - Traversal, - _rel_helper, - _rel_merge_helper, + _get_node_properties, + enumerate_traceback, ) -from .relationship import StructuredRel -from .util import _get_node_properties, enumerate_traceback # basestring python 3.x fallback try: @@ -66,6 +64,9 @@ def __str__(self): return f"{self.description} in {direction} direction of type {self.definition['relation_type']} on node ({self.source.element_id}) of class '{self.source_class.__name__}'" + def __await__(self): + return self.all().__await__() + def _check_node(self, obj): """check for valid node i.e correct class and is saved""" if not issubclass(type(obj), self.definition["node_class"]): @@ -124,13 +125,14 @@ def connect(self, node, properties=None): "MERGE" + new_rel ) - params["them"] = node.element_id + params["them"] = db.parse_element_id(node.element_id) if not rel_model: self.source.cypher(q, params) return True - rel_ = self.source.cypher(q + " RETURN r", params)[0][0][0] + results = self.source.cypher(q + " RETURN r", params) + rel_ = results[0][0][0] rel_instance = self._set_start_end_cls(rel_model.inflate(rel_), node) if hasattr(rel_instance, "post_save"): @@ -166,7 +168,8 @@ def relationship(self, node): + my_rel + f" WHERE {db.get_id_method()}(them)=$them and {db.get_id_method()}(us)=$self RETURN r LIMIT 1" ) - rels = self.source.cypher(q, {"them": node.element_id})[0] + results = self.source.cypher(q, {"them": db.parse_element_id(node.element_id)}) + rels = results[0] if not rels: return @@ -186,7 +189,8 @@ def all_relationships(self, node): my_rel = _rel_helper(lhs="us", rhs="them", ident="r", **self.definition) q = f"MATCH {my_rel} WHERE {db.get_id_method()}(them)=$them and {db.get_id_method()}(us)=$self RETURN r " - rels = self.source.cypher(q, {"them": node.element_id})[0] + results = self.source.cypher(q, {"them": db.parse_element_id(node.element_id)}) + rels = results[0] if not rels: return [] @@ -223,12 +227,14 @@ def reconnect(self, old_node, new_node): old_rel = _rel_helper(lhs="us", rhs="old", ident="r", **self.definition) # get list of properties on the existing rel + old_node_element_id = db.parse_element_id(old_node.element_id) + new_node_element_id = db.parse_element_id(new_node.element_id) result, _ = self.source.cypher( f""" MATCH (us), (old) WHERE {db.get_id_method()}(us)=$self and {db.get_id_method()}(old)=$old MATCH {old_rel} RETURN r """, - {"old": old_node.element_id}, + {"old": old_node_element_id}, ) if result: node_properties = _get_node_properties(result[0][0]) @@ -249,7 +255,7 @@ def reconnect(self, old_node, new_node): q += "".join([f" SET r2.{prop} = r.{prop}" for prop in existing_properties]) q += " WITH r DELETE r" - self.source.cypher(q, {"old": old_node.element_id, "new": new_node.element_id}) + self.source.cypher(q, {"old": old_node_element_id, "new": new_node_element_id}) @check_source def disconnect(self, node): @@ -264,7 +270,7 @@ def disconnect(self, node): MATCH (a), (b) WHERE {db.get_id_method()}(a)=$self and {db.get_id_method()}(b)=$them MATCH {rel} DELETE r """ - self.source.cypher(q, {"them": node.element_id}) + self.source.cypher(q, {"them": db.parse_element_id(node.element_id)}) @check_source def disconnect_all(self): @@ -345,7 +351,8 @@ def single(self): :return: StructuredNode """ try: - return self[0] + rels = self + return rels[0] except IndexError: pass diff --git a/neomodel/util.py b/neomodel/util.py index 74e88250..1b88e407 100644 --- a/neomodel/util.py +++ b/neomodel/util.py @@ -1,630 +1,6 @@ -import logging -import os -import sys -import time import warnings -from threading import local -from typing import Optional, Sequence -from urllib.parse import quote, unquote, urlparse -from neo4j import DEFAULT_DATABASE, Driver, GraphDatabase, basic_auth -from neo4j.api import Bookmarks -from neo4j.exceptions import ClientError, ServiceUnavailable, SessionExpired -from neo4j.graph import Node, Path, Relationship - -from neomodel import config, core -from neomodel.exceptions import ( - ConstraintValidationFailed, - FeatureNotSupported, - NodeClassNotDefined, - RelationshipClassNotDefined, - UniqueProperty, -) - -logger = logging.getLogger(__name__) - - -# make sure the connection url has been set prior to executing the wrapped function -def ensure_connection(func): - def wrapper(self, *args, **kwargs): - # Sort out where to find url - if hasattr(self, "db"): - _db = self.db - else: - _db = self - - if not _db.driver: - if hasattr(config, "DRIVER") and config.DRIVER: - _db.set_connection(driver=config.DRIVER) - elif config.DATABASE_URL: - _db.set_connection(url=config.DATABASE_URL) - - return func(self, *args, **kwargs) - - return wrapper - - -def change_neo4j_password(db, user, new_password): - db.cypher_query(f"ALTER USER {user} SET PASSWORD '{new_password}'") - - -def clear_neo4j_database(db, clear_constraints=False, clear_indexes=False): - db.cypher_query( - """ - MATCH (a) - CALL { WITH a DETACH DELETE a } - IN TRANSACTIONS OF 5000 rows - """ - ) - if clear_constraints: - core.drop_constraints() - if clear_indexes: - core.drop_indexes() - - -class Database(local): - """ - A singleton object via which all operations from neomodel to the Neo4j backend are handled with. - """ - - _NODE_CLASS_REGISTRY = {} - - def __init__(self): - self._active_transaction = None - self.url = None - self.driver = None - self._session = None - self._pid = None - self._database_name = DEFAULT_DATABASE - self.protocol_version = None - self._database_version = None - self._database_edition = None - self.impersonated_user = None - - def set_connection(self, url: str = None, driver: Driver = None): - """ - Sets the connection up and relevant internal. This can be done using a Neo4j URL or a driver instance. - - Args: - url (str): Optionally, Neo4j URL in the form protocol://username:password@hostname:port/dbname. - When provided, a Neo4j driver instance will be created by neomodel. - - driver (neo4j.Driver): Optionally, a pre-created driver instance. - When provided, neomodel will not create a driver instance but use this one instead. - """ - if driver: - self.driver = driver - if hasattr(config, "DATABASE_NAME") and config.DATABASE_NAME: - self._database_name = config.DATABASE_NAME - elif url: - self._parse_driver_from_url(url=url) - - self._pid = os.getpid() - self._active_transaction = None - # Set to default database if it hasn't been set before - if self._database_name is None: - self._database_name = DEFAULT_DATABASE - - # Getting the information about the database version requires a connection to the database - self._database_version = None - self._database_edition = None - self._update_database_version() - - def _parse_driver_from_url(self, url: str) -> None: - """Parse the driver information from the given URL and initialize the driver. - - Args: - url (str): The URL to parse. - - Raises: - ValueError: If the URL format is not as expected. - - Returns: - None - Sets the driver and database_name as class properties - """ - p_start = url.replace(":", "", 1).find(":") + 2 - p_end = url.rfind("@") - password = url[p_start:p_end] - url = url.replace(password, quote(password)) - parsed_url = urlparse(url) - - valid_schemas = [ - "bolt", - "bolt+s", - "bolt+ssc", - "bolt+routing", - "neo4j", - "neo4j+s", - "neo4j+ssc", - ] - - if parsed_url.netloc.find("@") > -1 and parsed_url.scheme in valid_schemas: - credentials, hostname = parsed_url.netloc.rsplit("@", 1) - username, password = credentials.split(":") - password = unquote(password) - database_name = parsed_url.path.strip("/") - else: - raise ValueError( - f"Expecting url format: bolt://user:password@localhost:7687 got {url}" - ) - - options = { - "auth": basic_auth(username, password), - "connection_acquisition_timeout": config.CONNECTION_ACQUISITION_TIMEOUT, - "connection_timeout": config.CONNECTION_TIMEOUT, - "keep_alive": config.KEEP_ALIVE, - "max_connection_lifetime": config.MAX_CONNECTION_LIFETIME, - "max_connection_pool_size": config.MAX_CONNECTION_POOL_SIZE, - "max_transaction_retry_time": config.MAX_TRANSACTION_RETRY_TIME, - "resolver": config.RESOLVER, - "user_agent": config.USER_AGENT, - } - - if "+s" not in parsed_url.scheme: - options["encrypted"] = config.ENCRYPTED - options["trusted_certificates"] = config.TRUSTED_CERTIFICATES - - self.driver = GraphDatabase.driver( - parsed_url.scheme + "://" + hostname, **options - ) - self.url = url - # The database name can be provided through the url or the config - if database_name == "": - if hasattr(config, "DATABASE_NAME") and config.DATABASE_NAME: - self._database_name = config.DATABASE_NAME - else: - self._database_name = database_name - - def close_connection(self): - """ - Closes the currently open driver. - The driver should always be closed at the end of the application's lifecyle. - """ - self._database_version = None - self._database_edition = None - self._database_name = None - self.driver.close() - self.driver = None - - @property - def database_version(self): - if self._database_version is None: - self._update_database_version() - - return self._database_version - - @property - def database_edition(self): - if self._database_edition is None: - self._update_database_version() - - return self._database_edition - - @property - def transaction(self): - """ - Returns the current transaction object - """ - return TransactionProxy(self) - - @property - def write_transaction(self): - return TransactionProxy(self, access_mode="WRITE") - - @property - def read_transaction(self): - return TransactionProxy(self, access_mode="READ") - - def impersonate(self, user: str) -> "ImpersonationHandler": - """All queries executed within this context manager will be executed as impersonated user - - Args: - user (str): User to impersonate - - Returns: - ImpersonationHandler: Context manager to set/unset the user to impersonate - """ - if self.database_edition != "enterprise": - raise FeatureNotSupported( - "Impersonation is only available in Neo4j Enterprise edition" - ) - return ImpersonationHandler(self, impersonated_user=user) - - @ensure_connection - def begin(self, access_mode=None, **parameters): - """ - Begins a new transaction. Raises SystemError if a transaction is already active. - """ - if ( - hasattr(self, "_active_transaction") - and self._active_transaction is not None - ): - raise SystemError("Transaction in progress") - self._session = self.driver.session( - default_access_mode=access_mode, - database=self._database_name, - impersonated_user=self.impersonated_user, - **parameters, - ) - self._active_transaction = self._session.begin_transaction() - - @ensure_connection - def commit(self): - """ - Commits the current transaction and closes its session - - :return: last_bookmarks - """ - try: - self._active_transaction.commit() - last_bookmarks = self._session.last_bookmarks() - finally: - # In case when something went wrong during - # committing changes to the database - # we have to close an active transaction and session. - self._active_transaction.close() - self._session.close() - self._active_transaction = None - self._session = None - - return last_bookmarks - - @ensure_connection - def rollback(self): - """ - Rolls back the current transaction and closes its session - """ - try: - self._active_transaction.rollback() - finally: - # In case when something went wrong during changes rollback, - # we have to close an active transaction and session - self._active_transaction.close() - self._session.close() - self._active_transaction = None - self._session = None - - def _update_database_version(self): - """ - Updates the database server information when it is required - """ - try: - results = self.cypher_query( - "CALL dbms.components() yield versions, edition return versions[0], edition" - ) - self._database_version = results[0][0][0] - self._database_edition = results[0][0][1] - except ServiceUnavailable: - # The database server is not running yet - pass - - def _object_resolution(self, object_to_resolve): - """ - Performs in place automatic object resolution on a result - returned by cypher_query. - - The function operates recursively in order to be able to resolve Nodes - within nested list structures and Path objects. Not meant to be called - directly, used primarily by _result_resolution. - - :param object_to_resolve: A result as returned by cypher_query. - :type Any: - - :return: An instantiated object. - """ - # Below is the original comment that came with the code extracted in - # this method. It is not very clear but I decided to keep it just in - # case - # - # - # For some reason, while the type of `a_result_attribute[1]` - # as reported by the neo4j driver is `Node` for Node-type data - # retrieved from the database. - # When the retrieved data are Relationship-Type, - # the returned type is `abc.[REL_LABEL]` which is however - # a descendant of Relationship. - # Consequently, the type checking was changed for both - # Node, Relationship objects - if isinstance(object_to_resolve, Node): - return self._NODE_CLASS_REGISTRY[ - frozenset(object_to_resolve.labels) - ].inflate(object_to_resolve) - - if isinstance(object_to_resolve, Relationship): - rel_type = frozenset([object_to_resolve.type]) - return self._NODE_CLASS_REGISTRY[rel_type].inflate(object_to_resolve) - - if isinstance(object_to_resolve, Path): - from .path import NeomodelPath - - return NeomodelPath(object_to_resolve) - - if isinstance(object_to_resolve, list): - return self._result_resolution([object_to_resolve]) - - return object_to_resolve - - def _result_resolution(self, result_list): - """ - Performs in place automatic object resolution on a set of results - returned by cypher_query. - - The function operates recursively in order to be able to resolve Nodes - within nested list structures. Not meant to be called directly, - used primarily by cypher_query. - - :param result_list: A list of results as returned by cypher_query. - :type list: - - :return: A list of instantiated objects. - """ - - # Object resolution occurs in-place - for a_result_item in enumerate(result_list): - for a_result_attribute in enumerate(a_result_item[1]): - try: - # Primitive types should remain primitive types, - # Nodes to be resolved to native objects - resolved_object = a_result_attribute[1] - - resolved_object = self._object_resolution(resolved_object) - - result_list[a_result_item[0]][ - a_result_attribute[0] - ] = resolved_object - - except KeyError as exc: - # Not being able to match the label set of a node with a known object results - # in a KeyError in the internal dictionary used for resolution. If it is impossible - # to match, then raise an exception with more details about the error. - if isinstance(a_result_attribute[1], Node): - raise NodeClassNotDefined( - a_result_attribute[1], self._NODE_CLASS_REGISTRY - ) from exc - - if isinstance(a_result_attribute[1], Relationship): - raise RelationshipClassNotDefined( - a_result_attribute[1], self._NODE_CLASS_REGISTRY - ) from exc - - return result_list - - @ensure_connection - def cypher_query( - self, - query, - params=None, - handle_unique=True, - retry_on_session_expire=False, - resolve_objects=False, - ): - """ - Runs a query on the database and returns a list of results and their headers. - - :param query: A CYPHER query - :type: str - :param params: Dictionary of parameters - :type: dict - :param handle_unique: Whether or not to raise UniqueProperty exception on Cypher's ConstraintValidation errors - :type: bool - :param retry_on_session_expire: Whether or not to attempt the same query again if the transaction has expired. - If you use neomodel with your own driver, you must catch SessionExpired exceptions yourself and retry with a new driver instance. - :type: bool - :param resolve_objects: Whether to attempt to resolve the returned nodes to data model objects automatically - :type: bool - """ - - if self._active_transaction: - # Use current session is a transaction is currently active - results, meta = self._run_cypher_query( - self._active_transaction, - query, - params, - handle_unique, - retry_on_session_expire, - resolve_objects, - ) - else: - # Otherwise create a new session in a with to dispose of it after it has been run - with self.driver.session( - database=self._database_name, impersonated_user=self.impersonated_user - ) as session: - results, meta = self._run_cypher_query( - session, - query, - params, - handle_unique, - retry_on_session_expire, - resolve_objects, - ) - - return results, meta - - def _run_cypher_query( - self, - session, - query, - params, - handle_unique, - retry_on_session_expire, - resolve_objects, - ): - try: - # Retrieve the data - start = time.time() - response = session.run(query, params) - results, meta = [list(r.values()) for r in response], response.keys() - end = time.time() - - if resolve_objects: - # Do any automatic resolution required - results = self._result_resolution(results) - - except ClientError as e: - if e.code == "Neo.ClientError.Schema.ConstraintValidationFailed": - if "already exists with label" in e.message and handle_unique: - raise UniqueProperty(e.message) from e - - raise ConstraintValidationFailed(e.message) from e - exc_info = sys.exc_info() - raise exc_info[1].with_traceback(exc_info[2]) - except SessionExpired: - if retry_on_session_expire: - self.set_connection(url=self.url) - return self.cypher_query( - query=query, - params=params, - handle_unique=handle_unique, - retry_on_session_expire=False, - ) - raise - - tte = end - start - if os.environ.get("NEOMODEL_CYPHER_DEBUG", False) and tte > float( - os.environ.get("NEOMODEL_SLOW_QUERIES", 0) - ): - logger.debug( - "query: " - + query - + "\nparams: " - + repr(params) - + f"\ntook: {tte:.2g}s\n" - ) - - return results, meta - - def get_id_method(self) -> str: - if self.database_version.startswith("4"): - return "id" - else: - return "elementId" - - def list_indexes(self, exclude_token_lookup=False) -> Sequence[dict]: - """Returns all indexes existing in the database - - Arguments: - exclude_token_lookup[bool]: Exclude automatically create token lookup indexes - - Returns: - Sequence[dict]: List of dictionaries, each entry being an index definition - """ - indexes, meta_indexes = self.cypher_query("SHOW INDEXES") - indexes_as_dict = [dict(zip(meta_indexes, row)) for row in indexes] - - if exclude_token_lookup: - indexes_as_dict = [ - obj for obj in indexes_as_dict if obj["type"] != "LOOKUP" - ] - - return indexes_as_dict - - def list_constraints(self) -> Sequence[dict]: - """Returns all constraints existing in the database - - Returns: - Sequence[dict]: List of dictionaries, each entry being a constraint definition - """ - constraints, meta_constraints = self.cypher_query("SHOW CONSTRAINTS") - constraints_as_dict = [dict(zip(meta_constraints, row)) for row in constraints] - - return constraints_as_dict - - def version_is_higher_than(self, version_tag: str) -> bool: - """Returns true if the database version is higher or equal to a given tag - - Args: - version_tag (str): The version to compare against - - Returns: - bool: True if the database version is higher or equal to the given version - """ - return version_tag_to_integer(self.database_version) >= version_tag_to_integer( - version_tag - ) - - def edition_is_enterprise(self) -> bool: - """Returns true if the database edition is enterprise - - Returns: - bool: True if the database edition is enterprise - """ - return self.database_edition == "enterprise" - - -class TransactionProxy: - bookmarks: Optional[Bookmarks] = None - - def __init__(self, db, access_mode=None): - self.db = db - self.access_mode = access_mode - - @ensure_connection - def __enter__(self): - self.db.begin(access_mode=self.access_mode, bookmarks=self.bookmarks) - self.bookmarks = None - return self - - def __exit__(self, exc_type, exc_value, traceback): - if exc_value: - self.db.rollback() - - if ( - exc_type is ClientError - and exc_value.code == "Neo.ClientError.Schema.ConstraintValidationFailed" - ): - raise UniqueProperty(exc_value.message) - - if not exc_value: - self.last_bookmark = self.db.commit() - - def __call__(self, func): - def wrapper(*args, **kwargs): - with self: - return func(*args, **kwargs) - - return wrapper - - @property - def with_bookmark(self): - return BookmarkingTransactionProxy(self.db, self.access_mode) - - -class ImpersonationHandler: - def __init__(self, db, impersonated_user: str): - self.db = db - self.impersonated_user = impersonated_user - - def __enter__(self): - self.db.impersonated_user = self.impersonated_user - return self - - def __exit__(self, exception_type, exception_value, exception_traceback): - self.db.impersonated_user = None - - print("\nException type:", exception_type) - print("\nException value:", exception_value) - print("\nTraceback:", exception_traceback) - - def __call__(self, func): - def wrapper(*args, **kwargs): - with self: - return func(*args, **kwargs) - - return wrapper - - -class BookmarkingTransactionProxy(TransactionProxy): - def __call__(self, func): - def wrapper(*args, **kwargs): - self.bookmarks = kwargs.pop("bookmarks", None) - - with self: - result = func(*args, **kwargs) - self.last_bookmark = None - - return result, self.last_bookmark - - return wrapper +OUTGOING, INCOMING, EITHER = 1, -1, 0 def deprecated(message): diff --git a/pyproject.toml b/pyproject.toml index f15c28b8..acc36850 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,6 @@ authors = [ maintainers = [ {name = "Marius Conjeaud", email = "marius.conjeaud@outlook.com"}, {name = "Athanasios Anastasiou", email = "athanastasiou@gmail.com"}, - {name = "Cristina Escalante"}, ] description = "An object mapper for the neo4j graph database." readme = "README.md" @@ -35,7 +34,9 @@ changelog = "https://github.com/neo4j-contrib/neomodel/releases" [project.optional-dependencies] dev = [ + "unasync", "pytest>=7.1", + "pytest-asyncio", "pytest-cov>=4.0", "pre-commit", "black", @@ -61,7 +62,7 @@ testpaths = "test" [tool.isort] profile = 'black' -src_paths = ['neomodel'] +src_paths = ['neomodel','test'] [tool.pylint.'MESSAGES CONTROL'] disable = 'missing-module-docstring,redefined-builtin,missing-class-docstring,missing-function-docstring,consider-using-f-string,line-too-long' diff --git a/requirements-dev.txt b/requirements-dev.txt index 2e7a31f0..bf3fa116 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,7 @@ # neomodel -e .[pandas,numpy] +unasync>=0.5.0 pytest>=7.1 pytest-cov>=4.0 pre-commit diff --git a/test/_async_compat/__init__.py b/test/_async_compat/__init__.py new file mode 100644 index 00000000..342678c3 --- /dev/null +++ b/test/_async_compat/__init__.py @@ -0,0 +1,17 @@ +from .mark_decorator import ( + AsyncTestDecorators, + TestDecorators, + mark_async_session_auto_fixture, + mark_async_test, + mark_sync_session_auto_fixture, + mark_sync_test, +) + +__all__ = [ + "AsyncTestDecorators", + "mark_async_test", + "mark_sync_test", + "TestDecorators", + "mark_async_session_auto_fixture", + "mark_sync_session_auto_fixture", +] diff --git a/test/_async_compat/mark_decorator.py b/test/_async_compat/mark_decorator.py new file mode 100644 index 00000000..a8c5eead --- /dev/null +++ b/test/_async_compat/mark_decorator.py @@ -0,0 +1,21 @@ +import pytest +import pytest_asyncio + +mark_async_test = pytest.mark.asyncio +mark_async_session_auto_fixture = pytest_asyncio.fixture(scope="session", autouse=True) +mark_sync_session_auto_fixture = pytest.fixture(scope="session", autouse=True) + + +def mark_sync_test(f): + return f + + +class AsyncTestDecorators: + mark_async_only_test = mark_async_test + + +class TestDecorators: + @staticmethod + def mark_async_only_test(f): + skip_decorator = pytest.mark.skip("Async only test") + return skip_decorator(f) diff --git a/test/async_/__init__.py b/test/async_/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/async_/conftest.py b/test/async_/conftest.py new file mode 100644 index 00000000..493ff12c --- /dev/null +++ b/test/async_/conftest.py @@ -0,0 +1,60 @@ +import asyncio +import os +import warnings +from test._async_compat import mark_async_session_auto_fixture + +import pytest + +from neomodel import adb, config + + +@mark_async_session_auto_fixture +async def setup_neo4j_session(request, event_loop): + """ + Provides initial connection to the database and sets up the rest of the test suite + + :param request: The request object. Please see `_ + :type Request object: For more information please see `_ + """ + + warnings.simplefilter("default") + + config.DATABASE_URL = os.environ.get( + "NEO4J_BOLT_URL", "bolt://neo4j:foobarbaz@localhost:7687" + ) + + # Clear the database if required + database_is_populated, _ = await adb.cypher_query( + "MATCH (a) return count(a)>0 as database_is_populated" + ) + if database_is_populated[0][0] and not request.config.getoption("resetdb"): + raise SystemError( + "Please note: The database seems to be populated.\n\tEither delete all nodes and edges manually, or set the --resetdb parameter when calling pytest\n\n\tpytest --resetdb." + ) + + await adb.clear_neo4j_database(clear_constraints=True, clear_indexes=True) + + await adb.install_all_labels() + + await adb.cypher_query( + "CREATE OR REPLACE USER troygreene SET PASSWORD 'foobarbaz' CHANGE NOT REQUIRED" + ) + db_edition = await adb.database_edition + if db_edition == "enterprise": + await adb.cypher_query("GRANT ROLE publisher TO troygreene") + await adb.cypher_query("GRANT IMPERSONATE (troygreene) ON DBMS TO admin") + + +@mark_async_session_auto_fixture +async def cleanup(event_loop): + yield + await adb.close_connection() + + +@pytest.fixture(scope="session") +def event_loop(): + """Overrides pytest default function scoped event loop""" + policy = asyncio.get_event_loop_policy() + loop = policy.new_event_loop() + yield loop + loop.close() diff --git a/test/async_/test_alias.py b/test/async_/test_alias.py new file mode 100644 index 00000000..b2b9a3a8 --- /dev/null +++ b/test/async_/test_alias.py @@ -0,0 +1,34 @@ +from test._async_compat import mark_async_test + +from neomodel import AliasProperty, AsyncStructuredNode, StringProperty + + +class MagicProperty(AliasProperty): + def setup(self): + self.owner.setup_hook_called = True + + +class AliasTestNode(AsyncStructuredNode): + name = StringProperty(unique_index=True) + full_name = AliasProperty(to="name") + long_name = MagicProperty(to="name") + + +@mark_async_test +async def test_property_setup_hook(): + timmy = await AliasTestNode(long_name="timmy").save() + assert AliasTestNode.setup_hook_called + assert timmy.name == "timmy" + + +@mark_async_test +async def test_alias(): + jim = await AliasTestNode(full_name="Jim").save() + assert jim.name == "Jim" + assert jim.full_name == "Jim" + assert "full_name" not in AliasTestNode.deflate(jim.__properties__) + jim = await AliasTestNode.nodes.get(full_name="Jim") + assert jim + assert jim.name == "Jim" + assert jim.full_name == "Jim" + assert "full_name" not in AliasTestNode.deflate(jim.__properties__) diff --git a/test/async_/test_batch.py b/test/async_/test_batch.py new file mode 100644 index 00000000..653dce0d --- /dev/null +++ b/test/async_/test_batch.py @@ -0,0 +1,138 @@ +from test._async_compat import mark_async_test + +from pytest import raises + +from neomodel import ( + AsyncRelationshipFrom, + AsyncRelationshipTo, + AsyncStructuredNode, + IntegerProperty, + StringProperty, + UniqueIdProperty, + config, +) +from neomodel._async_compat.util import AsyncUtil +from neomodel.exceptions import DeflateError, UniqueProperty + +config.AUTO_INSTALL_LABELS = True + + +class UniqueUser(AsyncStructuredNode): + uid = UniqueIdProperty() + name = StringProperty() + age = IntegerProperty() + + +@mark_async_test +async def test_unique_id_property_batch(): + users = await UniqueUser.create( + {"name": "bob", "age": 2}, {"name": "ben", "age": 3} + ) + + assert users[0].uid != users[1].uid + + users = await UniqueUser.get_or_create( + {"uid": users[0].uid}, {"name": "bill", "age": 4} + ) + + assert users[0].name == "bob" + assert users[1].uid + + +class Customer(AsyncStructuredNode): + email = StringProperty(unique_index=True, required=True) + age = IntegerProperty(index=True) + + +@mark_async_test +async def test_batch_create(): + users = await Customer.create( + {"email": "jim1@aol.com", "age": 11}, + {"email": "jim2@aol.com", "age": 7}, + {"email": "jim3@aol.com", "age": 9}, + {"email": "jim4@aol.com", "age": 7}, + {"email": "jim5@aol.com", "age": 99}, + ) + assert len(users) == 5 + assert users[0].age == 11 + assert users[1].age == 7 + assert users[1].email == "jim2@aol.com" + assert await Customer.nodes.get(email="jim1@aol.com") + + +@mark_async_test +async def test_batch_create_or_update(): + users = await Customer.create_or_update( + {"email": "merge1@aol.com", "age": 11}, + {"email": "merge2@aol.com"}, + {"email": "merge3@aol.com", "age": 1}, + {"email": "merge2@aol.com", "age": 2}, + ) + assert len(users) == 4 + assert users[1] == users[3] + merge_1: Customer = await Customer.nodes.get(email="merge1@aol.com") + assert merge_1.age == 11 + + more_users = await Customer.create_or_update( + {"email": "merge1@aol.com", "age": 22}, + {"email": "merge4@aol.com", "age": None}, + ) + assert len(more_users) == 2 + assert users[0] == more_users[0] + merge_1 = await Customer.nodes.get(email="merge1@aol.com") + assert merge_1.age == 22 + + +@mark_async_test +async def test_batch_validation(): + # test validation in batch create + with raises(DeflateError): + await Customer.create( + {"email": "jim1@aol.com", "age": "x"}, + ) + + +@mark_async_test +async def test_batch_index_violation(): + for u in await Customer.nodes: + await u.delete() + + users = await Customer.create( + {"email": "jim6@aol.com", "age": 3}, + ) + assert users + with raises(UniqueProperty): + await Customer.create( + {"email": "jim6@aol.com", "age": 3}, + {"email": "jim7@aol.com", "age": 5}, + ) + + # not found + if AsyncUtil.is_async_code: + assert not await Customer.nodes.filter(email="jim7@aol.com").check_bool() + else: + assert not Customer.nodes.filter(email="jim7@aol.com") + + +class Dog(AsyncStructuredNode): + name = StringProperty(required=True) + owner = AsyncRelationshipTo("Person", "owner") + + +class Person(AsyncStructuredNode): + name = StringProperty(unique_index=True) + pets = AsyncRelationshipFrom("Dog", "owner") + + +@mark_async_test +async def test_get_or_create_with_rel(): + create_bob = await Person.get_or_create({"name": "Bob"}) + bob = create_bob[0] + bobs_gizmo = await Dog.get_or_create({"name": "Gizmo"}, relationship=bob.pets) + + create_tim = await Person.get_or_create({"name": "Tim"}) + tim = create_tim[0] + tims_gizmo = await Dog.get_or_create({"name": "Gizmo"}, relationship=tim.pets) + + # not the same gizmo + assert bobs_gizmo[0] != tims_gizmo[0] diff --git a/test/async_/test_cardinality.py b/test/async_/test_cardinality.py new file mode 100644 index 00000000..4ce02ad4 --- /dev/null +++ b/test/async_/test_cardinality.py @@ -0,0 +1,187 @@ +from test._async_compat import mark_async_test + +from pytest import raises + +from neomodel import ( + AsyncOne, + AsyncOneOrMore, + AsyncRelationshipTo, + AsyncStructuredNode, + AsyncZeroOrMore, + AsyncZeroOrOne, + AttemptedCardinalityViolation, + CardinalityViolation, + IntegerProperty, + StringProperty, + adb, +) + + +class HairDryer(AsyncStructuredNode): + version = IntegerProperty() + + +class ScrewDriver(AsyncStructuredNode): + version = IntegerProperty() + + +class Car(AsyncStructuredNode): + version = IntegerProperty() + + +class Monkey(AsyncStructuredNode): + name = StringProperty() + dryers = AsyncRelationshipTo("HairDryer", "OWNS_DRYER", cardinality=AsyncZeroOrMore) + driver = AsyncRelationshipTo( + "ScrewDriver", "HAS_SCREWDRIVER", cardinality=AsyncZeroOrOne + ) + car = AsyncRelationshipTo("Car", "HAS_CAR", cardinality=AsyncOneOrMore) + toothbrush = AsyncRelationshipTo( + "ToothBrush", "HAS_TOOTHBRUSH", cardinality=AsyncOne + ) + + +class ToothBrush(AsyncStructuredNode): + name = StringProperty() + + +@mark_async_test +async def test_cardinality_zero_or_more(): + m = await Monkey(name="tim").save() + assert await m.dryers.all() == [] + single_dryer = await m.driver.single() + assert single_dryer is None + h = await HairDryer(version=1).save() + + await m.dryers.connect(h) + assert len(await m.dryers.all()) == 1 + single_dryer = await m.dryers.single() + assert single_dryer.version == 1 + + await m.dryers.disconnect(h) + assert await m.dryers.all() == [] + single_dryer = await m.driver.single() + assert single_dryer is None + + h2 = await HairDryer(version=2).save() + await m.dryers.connect(h) + await m.dryers.connect(h2) + await m.dryers.disconnect_all() + assert await m.dryers.all() == [] + single_dryer = await m.driver.single() + assert single_dryer is None + + +@mark_async_test +async def test_cardinality_zero_or_one(): + m = await Monkey(name="bob").save() + assert await m.driver.all() == [] + single_driver = await m.driver.single() + assert await m.driver.single() is None + h = await ScrewDriver(version=1).save() + + await m.driver.connect(h) + assert len(await m.driver.all()) == 1 + single_driver = await m.driver.single() + assert single_driver.version == 1 + + j = await ScrewDriver(version=2).save() + with raises(AttemptedCardinalityViolation): + await m.driver.connect(j) + + await m.driver.reconnect(h, j) + single_driver = await m.driver.single() + assert single_driver.version == 2 + + # Forcing creation of a second ToothBrush to go around + # AttemptedCardinalityViolation + await adb.cypher_query( + """ + MATCH (m:Monkey WHERE m.name="bob") + CREATE (s:ScrewDriver {version:3}) + WITH m, s + CREATE (m)-[:HAS_SCREWDRIVER]->(s) + """ + ) + with raises( + CardinalityViolation, match=r"CardinalityViolation: Expected: .*, got: 2." + ): + await m.driver.all() + + +@mark_async_test +async def test_cardinality_one_or_more(): + m = await Monkey(name="jerry").save() + + with raises(CardinalityViolation): + await m.car.all() + + with raises(CardinalityViolation): + await m.car.single() + + c = await Car(version=2).save() + await m.car.connect(c) + single_car = await m.car.single() + assert single_car.version == 2 + + cars = await m.car.all() + assert len(cars) == 1 + + with raises(AttemptedCardinalityViolation): + await m.car.disconnect(c) + + d = await Car(version=3).save() + await m.car.connect(d) + cars = await m.car.all() + assert len(cars) == 2 + + await m.car.disconnect(d) + cars = await m.car.all() + assert len(cars) == 1 + + +@mark_async_test +async def test_cardinality_one(): + m = await Monkey(name="jerry").save() + + with raises( + CardinalityViolation, match=r"CardinalityViolation: Expected: .*, got: none." + ): + await m.toothbrush.all() + + with raises(CardinalityViolation): + await m.toothbrush.single() + + b = await ToothBrush(name="Jim").save() + await m.toothbrush.connect(b) + single_toothbrush = await m.toothbrush.single() + assert single_toothbrush.name == "Jim" + + x = await ToothBrush(name="Jim").save() + with raises(AttemptedCardinalityViolation): + await m.toothbrush.connect(x) + + with raises(AttemptedCardinalityViolation): + await m.toothbrush.disconnect(b) + + with raises(AttemptedCardinalityViolation): + await m.toothbrush.disconnect_all() + + # Forcing creation of a second ToothBrush to go around + # AttemptedCardinalityViolation + await adb.cypher_query( + """ + MATCH (m:Monkey WHERE m.name="jerry") + CREATE (t:ToothBrush {name:"Jim"}) + WITH m, t + CREATE (m)-[:HAS_TOOTHBRUSH]->(t) + """ + ) + with raises( + CardinalityViolation, match=r"CardinalityViolation: Expected: .*, got: 2." + ): + await m.toothbrush.all() + + jp = Monkey(name="Jean-Pierre") + with raises(ValueError, match="Node has not been saved cannot connect!"): + await jp.toothbrush.connect(b) diff --git a/test/async_/test_connection.py b/test/async_/test_connection.py new file mode 100644 index 00000000..a2eded7d --- /dev/null +++ b/test/async_/test_connection.py @@ -0,0 +1,148 @@ +import os +from test._async_compat import mark_async_test +from test.conftest import NEO4J_PASSWORD, NEO4J_URL, NEO4J_USERNAME + +import pytest +from neo4j import AsyncDriver, AsyncGraphDatabase +from neo4j.debug import watch + +from neomodel import AsyncStructuredNode, StringProperty, adb, config + + +@mark_async_test +@pytest.fixture(autouse=True) +async def setup_teardown(): + yield + # Teardown actions after tests have run + # Reconnect to initial URL for potential subsequent tests + await adb.close_connection() + await adb.set_connection(url=config.DATABASE_URL) + + +@pytest.fixture(autouse=True, scope="session") +def neo4j_logging(): + with watch("neo4j"): + yield + + +@mark_async_test +async def get_current_database_name() -> str: + """ + Fetches the name of the currently active database from the Neo4j database. + + Returns: + - str: The name of the current database. + """ + results, meta = await adb.cypher_query("CALL db.info") + results_as_dict = [dict(zip(meta, row)) for row in results] + + return results_as_dict[0]["name"] + + +class Pastry(AsyncStructuredNode): + name = StringProperty(unique_index=True) + + +@mark_async_test +async def test_set_connection_driver_works(): + # Verify that current connection is up + assert await Pastry(name="Chocolatine").save() + await adb.close_connection() + + # Test connection using a driver + await adb.set_connection( + driver=AsyncGraphDatabase().driver( + NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD) + ) + ) + assert await Pastry(name="Croissant").save() + + +@mark_async_test +async def test_config_driver_works(): + # Verify that current connection is up + assert await Pastry(name="Chausson aux pommes").save() + await adb.close_connection() + + # Test connection using a driver defined in config + driver: AsyncDriver = AsyncGraphDatabase().driver( + NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD) + ) + + config.DRIVER = driver + assert await Pastry(name="Grignette").save() + + # Clear config + # No need to close connection - pytest teardown will do it + config.DRIVER = None + + +@mark_async_test +async def test_connect_to_non_default_database(): + if not await adb.edition_is_enterprise(): + pytest.skip("Skipping test for community edition - no multi database in CE") + database_name = "pastries" + await adb.cypher_query(f"CREATE DATABASE {database_name} IF NOT EXISTS") + await adb.close_connection() + + # Set database name in url - for url init only + await adb.set_connection(url=f"{config.DATABASE_URL}/{database_name}") + assert await get_current_database_name() == "pastries" + + await adb.close_connection() + + # Set database name in config - for both url and driver init + config.DATABASE_NAME = database_name + + # url init + await adb.set_connection(url=config.DATABASE_URL) + assert await get_current_database_name() == "pastries" + + await adb.close_connection() + + # driver init + await adb.set_connection( + driver=AsyncGraphDatabase().driver( + NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD) + ) + ) + assert await get_current_database_name() == "pastries" + + # Clear config + # No need to close connection - pytest teardown will do it + config.DATABASE_NAME = None + + +@mark_async_test +@pytest.mark.parametrize( + "url", ["bolt://user:password", "http://user:password@localhost:7687"] +) +async def test_wrong_url_format(url): + with pytest.raises( + ValueError, + match=rf"Expecting url format: bolt://user:password@localhost:7687 got {url}", + ): + await adb.set_connection(url=url) + + +@mark_async_test +@pytest.mark.parametrize("protocol", ["neo4j+s", "neo4j+ssc", "bolt+s", "bolt+ssc"]) +async def test_connect_to_aura(protocol): + cypher_return = "hello world" + default_cypher_query = f"RETURN '{cypher_return}'" + await adb.close_connection() + + await _set_connection(protocol=protocol) + result, _ = await adb.cypher_query(default_cypher_query) + + assert len(result) > 0 + assert result[0][0] == cypher_return + + +async def _set_connection(protocol): + AURA_TEST_DB_USER = os.environ["AURA_TEST_DB_USER"] + AURA_TEST_DB_PASSWORD = os.environ["AURA_TEST_DB_PASSWORD"] + AURA_TEST_DB_HOSTNAME = os.environ["AURA_TEST_DB_HOSTNAME"] + + database_url = f"{protocol}://{AURA_TEST_DB_USER}:{AURA_TEST_DB_PASSWORD}@{AURA_TEST_DB_HOSTNAME}" + await adb.set_connection(url=database_url) diff --git a/test/async_/test_contrib/__init__.py b/test/async_/test_contrib/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/async_/test_contrib/test_semi_structured.py b/test/async_/test_contrib/test_semi_structured.py new file mode 100644 index 00000000..3b88fcad --- /dev/null +++ b/test/async_/test_contrib/test_semi_structured.py @@ -0,0 +1,35 @@ +from test._async_compat import mark_async_test + +from neomodel import IntegerProperty, StringProperty +from neomodel.contrib import AsyncSemiStructuredNode + + +class UserProf(AsyncSemiStructuredNode): + email = StringProperty(unique_index=True, required=True) + age = IntegerProperty(index=True) + + +class Dummy(AsyncSemiStructuredNode): + pass + + +@mark_async_test +async def test_to_save_to_model_with_required_only(): + u = UserProf(email="dummy@test.com") + assert await u.save() + + +@mark_async_test +async def test_save_to_model_with_extras(): + u = UserProf(email="jim@test.com", age=3, bar=99) + u.foo = True + assert await u.save() + u = await UserProf.nodes.get(age=3) + assert u.foo is True + assert u.bar == 99 + + +@mark_async_test +async def test_save_empty_model(): + dummy = Dummy() + assert await dummy.save() diff --git a/test/test_contrib/test_spatial_datatypes.py b/test/async_/test_contrib/test_spatial_datatypes.py similarity index 100% rename from test/test_contrib/test_spatial_datatypes.py rename to test/async_/test_contrib/test_spatial_datatypes.py diff --git a/test/async_/test_contrib/test_spatial_properties.py b/test/async_/test_contrib/test_spatial_properties.py new file mode 100644 index 00000000..a6103639 --- /dev/null +++ b/test/async_/test_contrib/test_spatial_properties.py @@ -0,0 +1,291 @@ +""" +Provides a test case for issue 374 - "Support for Point property type". + +For more information please see: https://github.com/neo4j-contrib/neomodel/issues/374 +""" + +import random +from test._async_compat import mark_async_test + +import neo4j.spatial +import pytest + +import neomodel +import neomodel.contrib.spatial_properties + +from .test_spatial_datatypes import ( + basic_type_assertions, + check_and_skip_neo4j_least_version, +) + + +def test_spatial_point_property(): + """ + Tests that specific modes of instantiation fail as expected. + + :return: + """ + + # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test. + check_and_skip_neo4j_least_version( + 340, "This version does not support spatial data types." + ) + + with pytest.raises(ValueError, match=r"Invalid CRS\(None\)"): + a_point_property = neomodel.contrib.spatial_properties.PointProperty() + + with pytest.raises(ValueError, match=r"Invalid CRS\(crs_isaak\)"): + a_point_property = neomodel.contrib.spatial_properties.PointProperty( + crs="crs_isaak" + ) + + with pytest.raises(TypeError, match="Invalid default value"): + a_point_property = neomodel.contrib.spatial_properties.PointProperty( + default=(0.0, 0.0), crs="cartesian" + ) + + +def test_inflate(): + """ + Tests that the marshalling from neo4j to neomodel data types works as expected. + + :return: + """ + + # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test. + check_and_skip_neo4j_least_version( + 340, "This version does not support spatial data types." + ) + + # The test is repeatable enough to try and standardise it. The same test is repeated with the assertions in + # `basic_type_assertions` and different messages to be able to localise the exception. + # + # Array of points to inflate and messages when things go wrong + values_from_db = [ + ( + neo4j.spatial.CartesianPoint((0.0, 0.0)), + "Expected Neomodel 2d cartesian point when inflating 2d cartesian neo4j point", + ), + ( + neo4j.spatial.CartesianPoint((0.0, 0.0, 0.0)), + "Expected Neomodel 3d cartesian point when inflating 3d cartesian neo4j point", + ), + ( + neo4j.spatial.WGS84Point((0.0, 0.0)), + "Expected Neomodel 2d geographical point when inflating 2d geographical neo4j point", + ), + ( + neo4j.spatial.WGS84Point((0.0, 0.0, 0.0)), + "Expected Neomodel 3d geographical point inflating 3d geographical neo4j point", + ), + ] + + # Run the above tests + for a_value in values_from_db: + expected_point = neomodel.contrib.spatial_properties.NeomodelPoint( + tuple(a_value[0]), + crs=neomodel.contrib.spatial_properties.SRID_TO_CRS[a_value[0].srid], + ) + inflated_point = neomodel.contrib.spatial_properties.PointProperty( + crs=neomodel.contrib.spatial_properties.SRID_TO_CRS[a_value[0].srid] + ).inflate(a_value[0]) + basic_type_assertions( + expected_point, + inflated_point, + "{}, received {}".format(a_value[1], inflated_point), + ) + + +def test_deflate(): + """ + Tests that the marshalling from neomodel to neo4j data types works as expected + :return: + """ + # Please see inline comments in `test_inflate`. This test function is 90% to that one with very minor differences. + # + + # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test. + check_and_skip_neo4j_least_version( + 340, "This version does not support spatial data types." + ) + + CRS_TO_SRID = dict( + [ + (value, key) + for key, value in neomodel.contrib.spatial_properties.SRID_TO_CRS.items() + ] + ) + # Values to construct and expect during deflation + values_from_neomodel = [ + ( + neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 0.0), crs="cartesian" + ), + "Expected Neo4J 2d cartesian point when deflating Neomodel 2d cartesian point", + ), + ( + neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 0.0, 0.0), crs="cartesian-3d" + ), + "Expected Neo4J 3d cartesian point when deflating Neomodel 3d cartesian point", + ), + ( + neomodel.contrib.spatial_properties.NeomodelPoint((0.0, 0.0), crs="wgs-84"), + "Expected Neo4J 2d geographical point when deflating Neomodel 2d geographical point", + ), + ( + neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 0.0, 0.0), crs="wgs-84-3d" + ), + "Expected Neo4J 3d geographical point when deflating Neomodel 3d geographical point", + ), + ] + + # Run the above tests. + for a_value in values_from_neomodel: + expected_point = neo4j.spatial.Point(tuple(a_value[0].coords[0])) + expected_point.srid = CRS_TO_SRID[a_value[0].crs] + deflated_point = neomodel.contrib.spatial_properties.PointProperty( + crs=a_value[0].crs + ).deflate(a_value[0]) + basic_type_assertions( + expected_point, + deflated_point, + "{}, received {}".format(a_value[1], deflated_point), + check_neo4j_points=True, + ) + + +@mark_async_test +async def test_default_value(): + """ + Tests that the default value passing mechanism works as expected with NeomodelPoint values. + :return: + """ + + def get_some_point(): + return neomodel.contrib.spatial_properties.NeomodelPoint( + (random.random(), random.random()) + ) + + class LocalisableEntity(neomodel.AsyncStructuredNode): + """ + A very simple entity to try out the default value assignment. + """ + + identifier = neomodel.UniqueIdProperty() + location = neomodel.contrib.spatial_properties.PointProperty( + crs="cartesian", default=get_some_point + ) + + # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test. + check_and_skip_neo4j_least_version( + 340, "This version does not support spatial data types." + ) + + # Save an object + an_object = await LocalisableEntity().save() + coords = an_object.location.coords[0] + # Retrieve it + retrieved_object = await LocalisableEntity.nodes.get( + identifier=an_object.identifier + ) + # Check against an independently created value + assert ( + retrieved_object.location + == neomodel.contrib.spatial_properties.NeomodelPoint(coords) + ), ("Default value assignment failed.") + + +@mark_async_test +async def test_array_of_points(): + """ + Tests that Arrays of Points work as expected. + + :return: + """ + + class AnotherLocalisableEntity(neomodel.AsyncStructuredNode): + """ + A very simple entity with an array of locations + """ + + identifier = neomodel.UniqueIdProperty() + locations = neomodel.ArrayProperty( + neomodel.contrib.spatial_properties.PointProperty(crs="cartesian") + ) + + # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test. + check_and_skip_neo4j_least_version( + 340, "This version does not support spatial data types." + ) + + an_object = await AnotherLocalisableEntity( + locations=[ + neomodel.contrib.spatial_properties.NeomodelPoint((0.0, 0.0)), + neomodel.contrib.spatial_properties.NeomodelPoint((1.0, 0.0)), + ] + ).save() + + retrieved_object = await AnotherLocalisableEntity.nodes.get( + identifier=an_object.identifier + ) + + assert ( + type(retrieved_object.locations) is list + ), "Array of Points definition failed." + assert retrieved_object.locations == [ + neomodel.contrib.spatial_properties.NeomodelPoint((0.0, 0.0)), + neomodel.contrib.spatial_properties.NeomodelPoint((1.0, 0.0)), + ], "Array of Points incorrect values." + + +@mark_async_test +async def test_simple_storage_retrieval(): + """ + Performs a simple Create, Retrieve via .save(), .get() which, due to the way Q objects operate, tests the + __copy__, __deepcopy__ operations of NeomodelPoint. + :return: + """ + + class TestStorageRetrievalProperty(neomodel.AsyncStructuredNode): + uid = neomodel.UniqueIdProperty() + description = neomodel.StringProperty() + location = neomodel.contrib.spatial_properties.PointProperty(crs="cartesian") + + # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test. + check_and_skip_neo4j_least_version( + 340, "This version does not support spatial data types." + ) + + a_restaurant = await TestStorageRetrievalProperty( + description="Milliways", + location=neomodel.contrib.spatial_properties.NeomodelPoint((0, 0)), + ).save() + + a_property = await TestStorageRetrievalProperty.nodes.get( + location=neomodel.contrib.spatial_properties.NeomodelPoint((0, 0)) + ) + + assert a_restaurant.description == a_property.description + + +def test_equality_with_other_objects(): + """ + Performs equality tests and ensures tha ``NeomodelPoint`` can be compared with ShapelyPoint and NeomodelPoint only. + """ + try: + import shapely.geometry + from shapely import __version__ + except ImportError: + pytest.skip("Shapely module not present") + + if int("".join(__version__.split(".")[0:3])) < 200: + pytest.skip(f"Shapely 2.0 not present (Current version is {__version__}") + + assert neomodel.contrib.spatial_properties.NeomodelPoint( + (0, 0) + ) == neomodel.contrib.spatial_properties.NeomodelPoint(x=0, y=0) + assert neomodel.contrib.spatial_properties.NeomodelPoint( + (0, 0) + ) == shapely.geometry.Point((0, 0)) diff --git a/test/async_/test_cypher.py b/test/async_/test_cypher.py new file mode 100644 index 00000000..c078c8d5 --- /dev/null +++ b/test/async_/test_cypher.py @@ -0,0 +1,165 @@ +import builtins +from test._async_compat import mark_async_test + +import pytest +from neo4j.exceptions import ClientError as CypherError +from numpy import ndarray +from pandas import DataFrame, Series + +from neomodel import AsyncStructuredNode, StringProperty, adb +from neomodel._async_compat.util import AsyncUtil + + +class User2(AsyncStructuredNode): + name = StringProperty() + email = StringProperty() + + +class UserPandas(AsyncStructuredNode): + name = StringProperty() + email = StringProperty() + + +class UserNP(AsyncStructuredNode): + name = StringProperty() + email = StringProperty() + + +@pytest.fixture +def hide_available_pkg(monkeypatch, request): + import_orig = builtins.__import__ + + def mocked_import(name, *args, **kwargs): + if name == request.param: + raise ImportError() + return import_orig(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", mocked_import) + + +@mark_async_test +async def test_cypher(): + """ + test result format is backward compatible with earlier versions of neomodel + """ + + jim = await User2(email="jim1@test.com").save() + data, meta = await jim.cypher( + f"MATCH (a) WHERE {await adb.get_id_method()}(a)=$self RETURN a.email" + ) + assert data[0][0] == "jim1@test.com" + assert "a.email" in meta + + data, meta = await jim.cypher( + f""" + MATCH (a) WHERE {await adb.get_id_method()}(a)=$self + MATCH (a)<-[:USER2]-(b) + RETURN a, b, 3 + """ + ) + assert "a" in meta and "b" in meta + + +@mark_async_test +async def test_cypher_syntax_error(): + jim = await User2(email="jim1@test.com").save() + try: + await jim.cypher( + f"MATCH a WHERE {await adb.get_id_method()}(a)={{self}} RETURN xx" + ) + except CypherError as e: + assert hasattr(e, "message") + assert hasattr(e, "code") + else: + assert False, "CypherError not raised." + + +@mark_async_test +@pytest.mark.parametrize("hide_available_pkg", ["pandas"], indirect=True) +async def test_pandas_not_installed(hide_available_pkg): + # We run only the async version, because this fails on second run + # because import error is thrown only when pandas.py is imported + if not AsyncUtil.is_async_code: + pytest.skip("This test is async only") + with pytest.raises(ImportError): + with pytest.warns( + UserWarning, + match="The neomodel.integration.pandas module expects pandas to be installed", + ): + from neomodel.integration.pandas import to_dataframe + + _ = to_dataframe(await adb.cypher_query("MATCH (a) RETURN a.name AS name")) + + +@mark_async_test +async def test_pandas_integration(): + from neomodel.integration.pandas import to_dataframe, to_series + + jimla = await UserPandas(email="jimla@test.com", name="jimla").save() + jimlo = await UserPandas(email="jimlo@test.com", name="jimlo").save() + + # Test to_dataframe + df = to_dataframe( + await adb.cypher_query( + "MATCH (a:UserPandas) RETURN a.name AS name, a.email AS email" + ) + ) + + assert isinstance(df, DataFrame) + assert df.shape == (2, 2) + assert df["name"].tolist() == ["jimla", "jimlo"] + + # Also test passing an index and dtype to to_dataframe + df = to_dataframe( + await adb.cypher_query( + "MATCH (a:UserPandas) RETURN a.name AS name, a.email AS email" + ), + index=df["email"], + dtype=str, + ) + + assert df.index.inferred_type == "string" + + # Next test to_series + series = to_series( + await adb.cypher_query("MATCH (a:UserPandas) RETURN a.name AS name") + ) + + assert isinstance(series, Series) + assert series.shape == (2,) + assert df["name"].tolist() == ["jimla", "jimlo"] + + +@mark_async_test +@pytest.mark.parametrize("hide_available_pkg", ["numpy"], indirect=True) +async def test_numpy_not_installed(hide_available_pkg): + # We run only the async version, because this fails on second run + # because import error is thrown only when numpy.py is imported + if not AsyncUtil.is_async_code: + pytest.skip("This test is async only") + with pytest.raises(ImportError): + with pytest.warns( + UserWarning, + match="The neomodel.integration.numpy module expects numpy to be installed", + ): + from neomodel.integration.numpy import to_ndarray + + _ = to_ndarray(await adb.cypher_query("MATCH (a) RETURN a.name AS name")) + + +@mark_async_test +async def test_numpy_integration(): + from neomodel.integration.numpy import to_ndarray + + jimly = await UserNP(email="jimly@test.com", name="jimly").save() + jimlu = await UserNP(email="jimlu@test.com", name="jimlu").save() + + array = to_ndarray( + await adb.cypher_query( + "MATCH (a:UserNP) RETURN a.name AS name, a.email AS email ORDER BY name" + ) + ) + + assert isinstance(array, ndarray) + assert array.shape == (2, 2) + assert array[0][0] == "jimlu" diff --git a/test/async_/test_database_management.py b/test/async_/test_database_management.py new file mode 100644 index 00000000..5159642a --- /dev/null +++ b/test/async_/test_database_management.py @@ -0,0 +1,81 @@ +from test._async_compat import mark_async_test + +import pytest +from neo4j.exceptions import AuthError + +from neomodel import ( + AsyncRelationshipTo, + AsyncStructuredNode, + AsyncStructuredRel, + IntegerProperty, + StringProperty, + adb, +) + + +class City(AsyncStructuredNode): + name = StringProperty() + + +class InCity(AsyncStructuredRel): + creation_year = IntegerProperty(index=True) + + +class Venue(AsyncStructuredNode): + name = StringProperty(unique_index=True) + creator = StringProperty(index=True) + in_city = AsyncRelationshipTo(City, relation_type="IN", model=InCity) + + +@mark_async_test +async def test_clear_database(): + venue = await Venue(name="Royal Albert Hall", creator="Queen Victoria").save() + city = await City(name="London").save() + await venue.in_city.connect(city) + + # Clear only the data + await adb.clear_neo4j_database() + database_is_populated, _ = await adb.cypher_query( + "MATCH (a) return count(a)>0 as database_is_populated" + ) + + assert database_is_populated[0][0] is False + + await adb.install_all_labels() + indexes = await adb.list_indexes(exclude_token_lookup=True) + constraints = await adb.list_constraints() + assert len(indexes) > 0 + assert len(constraints) > 0 + + # Clear constraints and indexes too + await adb.clear_neo4j_database(clear_constraints=True, clear_indexes=True) + + indexes = await adb.list_indexes(exclude_token_lookup=True) + constraints = await adb.list_constraints() + assert len(indexes) == 0 + assert len(constraints) == 0 + + +@mark_async_test +async def test_change_password(): + prev_password = "foobarbaz" + new_password = "newpassword" + prev_url = f"bolt://neo4j:{prev_password}@localhost:7687" + new_url = f"bolt://neo4j:{new_password}@localhost:7687" + + await adb.change_neo4j_password("neo4j", new_password) + await adb.close_connection() + + await adb.set_connection(url=new_url) + await adb.close_connection() + + with pytest.raises(AuthError): + await adb.set_connection(url=prev_url) + + await adb.close_connection() + + await adb.set_connection(url=new_url) + await adb.change_neo4j_password("neo4j", prev_password) + await adb.close_connection() + + await adb.set_connection(url=prev_url) diff --git a/test/async_/test_dbms_awareness.py b/test/async_/test_dbms_awareness.py new file mode 100644 index 00000000..66a8fc5f --- /dev/null +++ b/test/async_/test_dbms_awareness.py @@ -0,0 +1,37 @@ +from test._async_compat import mark_async_test + +import pytest + +from neomodel import adb +from neomodel.util import version_tag_to_integer + + +@mark_async_test +async def test_version_awareness(): + db_version = await adb.database_version + if db_version != "5.7.0": + pytest.skip("Testing a specific database version") + assert db_version == "5.7.0" + assert await adb.version_is_higher_than("5.7") + assert await adb.version_is_higher_than("5.6.0") + assert await adb.version_is_higher_than("5") + assert await adb.version_is_higher_than("4") + + assert not await adb.version_is_higher_than("5.8") + + +@mark_async_test +async def test_edition_awareness(): + db_edition = await adb.database_edition + if db_edition == "enterprise": + assert await adb.edition_is_enterprise() + else: + assert not await adb.edition_is_enterprise() + + +def test_version_tag_to_integer(): + assert version_tag_to_integer("5.7.1") == 50701 + assert version_tag_to_integer("5.1") == 50100 + assert version_tag_to_integer("5") == 50000 + assert version_tag_to_integer("5.14.1") == 51401 + assert version_tag_to_integer("5.14-aura") == 51400 diff --git a/test/async_/test_driver_options.py b/test/async_/test_driver_options.py new file mode 100644 index 00000000..df378f98 --- /dev/null +++ b/test/async_/test_driver_options.py @@ -0,0 +1,52 @@ +from test._async_compat import mark_async_test + +import pytest +from neo4j.exceptions import ClientError +from pytest import raises + +from neomodel import adb +from neomodel.exceptions import FeatureNotSupported + + +@mark_async_test +async def test_impersonate(): + if not await adb.edition_is_enterprise(): + pytest.skip("Skipping test for community edition") + with await adb.impersonate(user="troygreene"): + results, _ = await adb.cypher_query("RETURN 'Doo Wacko !'") + assert results[0][0] == "Doo Wacko !" + + +@mark_async_test +async def test_impersonate_unauthorized(): + if not await adb.edition_is_enterprise(): + pytest.skip("Skipping test for community edition") + with await adb.impersonate(user="unknownuser"): + with raises(ClientError): + _ = await adb.cypher_query("RETURN 'Gabagool'") + + +@mark_async_test +async def test_impersonate_multiple_transactions(): + if not await adb.edition_is_enterprise(): + pytest.skip("Skipping test for community edition") + with await adb.impersonate(user="troygreene"): + async with adb.transaction: + results, _ = await adb.cypher_query("RETURN 'Doo Wacko !'") + assert results[0][0] == "Doo Wacko !" + + async with adb.transaction: + results, _ = await adb.cypher_query("SHOW CURRENT USER") + assert results[0][0] == "troygreene" + + results, _ = await adb.cypher_query("SHOW CURRENT USER") + assert results[0][0] == "neo4j" + + +@mark_async_test +async def test_impersonate_community(): + if await adb.edition_is_enterprise(): + pytest.skip("Skipping test for enterprise edition") + with raises(FeatureNotSupported): + with await adb.impersonate(user="troygreene"): + _ = await adb.cypher_query("RETURN 'Gabagoogoo'") diff --git a/test/async_/test_exceptions.py b/test/async_/test_exceptions.py new file mode 100644 index 00000000..e948f76c --- /dev/null +++ b/test/async_/test_exceptions.py @@ -0,0 +1,31 @@ +import pickle +from test._async_compat import mark_async_test + +from neomodel import AsyncStructuredNode, DoesNotExist, StringProperty + + +class EPerson(AsyncStructuredNode): + name = StringProperty(unique_index=True) + + +@mark_async_test +async def test_object_does_not_exist(): + try: + await EPerson.nodes.get(name="johnny") + except EPerson.DoesNotExist as e: + pickle_instance = pickle.dumps(e) + assert pickle_instance + assert pickle.loads(pickle_instance) + assert isinstance(pickle.loads(pickle_instance), DoesNotExist) + else: + assert False, "Person.DoesNotExist not raised." + + +def test_pickle_does_not_exist(): + try: + raise EPerson.DoesNotExist("My Test Message") + except EPerson.DoesNotExist as e: + pickle_instance = pickle.dumps(e) + assert pickle_instance + assert pickle.loads(pickle_instance) + assert isinstance(pickle.loads(pickle_instance), DoesNotExist) diff --git a/test/async_/test_hooks.py b/test/async_/test_hooks.py new file mode 100644 index 00000000..09a87403 --- /dev/null +++ b/test/async_/test_hooks.py @@ -0,0 +1,35 @@ +from test._async_compat import mark_async_test + +from neomodel import AsyncStructuredNode, StringProperty + +HOOKS_CALLED = {} + + +class HookTest(AsyncStructuredNode): + name = StringProperty() + + def post_create(self): + HOOKS_CALLED["post_create"] = 1 + + def pre_save(self): + HOOKS_CALLED["pre_save"] = 1 + + def post_save(self): + HOOKS_CALLED["post_save"] = 1 + + def pre_delete(self): + HOOKS_CALLED["pre_delete"] = 1 + + def post_delete(self): + HOOKS_CALLED["post_delete"] = 1 + + +@mark_async_test +async def test_hooks(): + ht = await HookTest(name="k").save() + await ht.delete() + assert "pre_save" in HOOKS_CALLED + assert "post_save" in HOOKS_CALLED + assert "post_create" in HOOKS_CALLED + assert "pre_delete" in HOOKS_CALLED + assert "post_delete" in HOOKS_CALLED diff --git a/test/async_/test_indexing.py b/test/async_/test_indexing.py new file mode 100644 index 00000000..9e0d8f37 --- /dev/null +++ b/test/async_/test_indexing.py @@ -0,0 +1,88 @@ +from test._async_compat import mark_async_test + +import pytest +from pytest import raises + +from neomodel import ( + AsyncStructuredNode, + IntegerProperty, + StringProperty, + UniqueProperty, + adb, +) +from neomodel.exceptions import ConstraintValidationFailed + + +class Human(AsyncStructuredNode): + name = StringProperty(unique_index=True) + age = IntegerProperty(index=True) + + +@mark_async_test +async def test_unique_error(): + await adb.install_labels(Human) + await Human(name="j1m", age=13).save() + try: + await Human(name="j1m", age=14).save() + except UniqueProperty as e: + assert str(e).find("j1m") + assert str(e).find("name") + else: + assert False, "UniqueProperty not raised." + + +@mark_async_test +async def test_existence_constraint_error(): + if not await adb.edition_is_enterprise(): + pytest.skip("Skipping test for community edition") + await adb.cypher_query( + "CREATE CONSTRAINT test_existence_constraint FOR (n:Human) REQUIRE n.age IS NOT NULL" + ) + with raises(ConstraintValidationFailed, match=r"must have the property"): + await Human(name="Scarlett").save() + + await adb.cypher_query("DROP CONSTRAINT test_existence_constraint") + + +@mark_async_test +async def test_optional_properties_dont_get_indexed(): + await Human(name="99", age=99).save() + h = await Human.nodes.get(age=99) + assert h + assert h.name == "99" + + await Human(name="98", age=98).save() + h = await Human.nodes.get(age=98) + assert h + assert h.name == "98" + + +@mark_async_test +async def test_escaped_chars(): + _name = "sarah:test" + await Human(name=_name, age=3).save() + r = await Human.nodes.filter(name=_name) + assert r[0].name == _name + + +@mark_async_test +async def test_does_not_exist(): + with raises(Human.DoesNotExist): + await Human.nodes.get(name="XXXX") + + +@mark_async_test +async def test_custom_label_name(): + class Giraffe(AsyncStructuredNode): + __label__ = "Giraffes" + name = StringProperty(unique_index=True) + + jim = await Giraffe(name="timothy").save() + node = await Giraffe.nodes.get(name="timothy") + assert node.name == jim.name + + class SpecialGiraffe(Giraffe): + power = StringProperty() + + # custom labels aren't inherited + assert SpecialGiraffe.__label__ == "SpecialGiraffe" diff --git a/test/async_/test_issue112.py b/test/async_/test_issue112.py new file mode 100644 index 00000000..8ba1ce03 --- /dev/null +++ b/test/async_/test_issue112.py @@ -0,0 +1,19 @@ +from test._async_compat import mark_async_test + +from neomodel import AsyncRelationshipTo, AsyncStructuredNode + + +class SomeModel(AsyncStructuredNode): + test = AsyncRelationshipTo("SomeModel", "SELF") + + +@mark_async_test +async def test_len_relationship(): + t1 = await SomeModel().save() + t2 = await SomeModel().save() + + await t1.test.connect(t2) + l = len(await t1.test.all()) + + assert l + assert l == 1 diff --git a/test/async_/test_issue283.py b/test/async_/test_issue283.py new file mode 100644 index 00000000..8106a796 --- /dev/null +++ b/test/async_/test_issue283.py @@ -0,0 +1,524 @@ +""" +Provides a test case for issue 283 - "Inheritance breaks". + +The issue is outlined here: https://github.com/neo4j-contrib/neomodel/issues/283 +More information about the same issue at: +https://github.com/aanastasiou/neomodelInheritanceTest + +The following example uses a recursive relationship for economy, but the +idea remains the same: "Instantiate the correct type of node at the end of +a relationship as specified by the model" +""" +import random +from test._async_compat import mark_async_test + +import pytest + +from neomodel import ( + AsyncRelationshipTo, + AsyncStructuredNode, + AsyncStructuredRel, + DateTimeProperty, + FloatProperty, + RelationshipClassNotDefined, + RelationshipClassRedefined, + StringProperty, + UniqueIdProperty, + adb, +) +from neomodel.exceptions import NodeClassAlreadyDefined, NodeClassNotDefined + +try: + basestring +except NameError: + basestring = str + + +# Set up a very simple model for the tests +class PersonalRelationship(AsyncStructuredRel): + """ + A very simple relationship between two basePersons that simply records + the date at which an acquaintance was established. + This relationship should be carried over to anything that inherits from + basePerson without any further effort. + """ + + on_date = DateTimeProperty(default_now=True) + + +class BasePerson(AsyncStructuredNode): + """ + Base class for defining some basic sort of an actor. + """ + + name = StringProperty(required=True, unique_index=True) + friends_with = AsyncRelationshipTo( + "BasePerson", "FRIENDS_WITH", model=PersonalRelationship + ) + + +class TechnicalPerson(BasePerson): + """ + A Technical person specialises BasePerson by adding their expertise. + """ + + expertise = StringProperty(required=True) + + +class PilotPerson(BasePerson): + """ + A pilot person specialises BasePerson by adding the type of airplane they + can operate. + """ + + airplane = StringProperty(required=True) + + +class BaseOtherPerson(AsyncStructuredNode): + """ + An obviously "wrong" class of actor to befriend BasePersons with. + """ + + car_color = StringProperty(required=True) + + +class SomePerson(BaseOtherPerson): + """ + Concrete class that simply derives from BaseOtherPerson. + """ + + pass + + +# Test cases +@mark_async_test +async def test_automatic_result_resolution(): + """ + Node objects at the end of relationships are instantiated to their + corresponding Python object. + """ + + # Create a few entities + A = ( + await TechnicalPerson.get_or_create( + {"name": "Grumpy", "expertise": "Grumpiness"} + ) + )[0] + B = ( + await TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"}) + )[0] + C = ( + await TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"}) + )[0] + + # Add connections + await A.friends_with.connect(B) + await B.friends_with.connect(C) + await C.friends_with.connect(A) + + test = await A.friends_with + + # If A is friends with B, then A's friends_with objects should be + # TechnicalPerson (!NOT basePerson!) + assert type((await A.friends_with)[0]) is TechnicalPerson + + await A.delete() + await B.delete() + await C.delete() + + +@mark_async_test +async def test_recursive_automatic_result_resolution(): + """ + Node objects are instantiated to native Python objects, both at the top + level of returned results and in the case where they are returned within + lists. + """ + + # Create a few entities + A = ( + await TechnicalPerson.get_or_create( + {"name": "Grumpier", "expertise": "Grumpiness"} + ) + )[0] + B = ( + await TechnicalPerson.get_or_create( + {"name": "Happier", "expertise": "Grumpiness"} + ) + )[0] + C = ( + await TechnicalPerson.get_or_create( + {"name": "Sleepier", "expertise": "Pillows"} + ) + )[0] + D = ( + await TechnicalPerson.get_or_create( + {"name": "Sneezier", "expertise": "Pillows"} + ) + )[0] + + # Retrieve mixed results, both at the top level and nested + L, _ = await adb.cypher_query( + "MATCH (a:TechnicalPerson) " + "WHERE a.expertise='Grumpiness' " + "WITH collect(a) as Alpha " + "MATCH (b:TechnicalPerson) " + "WHERE b.expertise='Pillows' " + "WITH Alpha, collect(b) as Beta " + "RETURN [Alpha, [Beta, [Beta, ['Banana', " + "Alpha]]]]", + resolve_objects=True, + ) + + # Assert that a Node returned deep in a nested list structure is of the + # correct type + assert type(L[0][0][0][1][0][0][0][0]) is TechnicalPerson + # Assert that primitive data types remain primitive data types + assert issubclass(type(L[0][0][0][1][0][1][0][1][0][0]), basestring) + + await A.delete() + await B.delete() + await C.delete() + await D.delete() + + +@mark_async_test +async def test_validation_with_inheritance_from_db(): + """ + Objects descending from the specified class of a relationship's end-node are + also perfectly valid to appear as end-node values too + """ + + # Create a few entities + # Technical Persons + A = ( + await TechnicalPerson.get_or_create( + {"name": "Grumpy", "expertise": "Grumpiness"} + ) + )[0] + B = ( + await TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"}) + )[0] + C = ( + await TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"}) + )[0] + + # Pilot Persons + D = ( + await PilotPerson.get_or_create( + {"name": "Porco Rosso", "airplane": "Savoia-Marchetti"} + ) + )[0] + E = ( + await PilotPerson.get_or_create( + {"name": "Jack Dalton", "airplane": "Beechcraft Model 18"} + ) + )[0] + + # TechnicalPersons can befriend PilotPersons and vice-versa and that's fine + + # TechnicalPersons befriend Technical Persons + await A.friends_with.connect(B) + await B.friends_with.connect(C) + await C.friends_with.connect(A) + + # Pilot Persons befriend Pilot Persons + await D.friends_with.connect(E) + + # Technical Persons befriend Pilot Persons + await A.friends_with.connect(D) + await E.friends_with.connect(C) + + # This now means that friends_with of a TechnicalPerson can + # either be TechnicalPerson or Pilot Person (!NOT basePerson!) + + assert (type((await A.friends_with)[0]) is TechnicalPerson) or ( + type((await A.friends_with)[0]) is PilotPerson + ) + assert (type((await A.friends_with)[1]) is TechnicalPerson) or ( + type((await A.friends_with)[1]) is PilotPerson + ) + assert type((await D.friends_with)[0]) is PilotPerson + + await A.delete() + await B.delete() + await C.delete() + await D.delete() + await E.delete() + + +@mark_async_test +async def test_validation_enforcement_to_db(): + """ + If a connection between wrong types is attempted, raise an exception + """ + + # Create a few entities + # Technical Persons + A = ( + await TechnicalPerson.get_or_create( + {"name": "Grumpy", "expertise": "Grumpiness"} + ) + )[0] + B = ( + await TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"}) + )[0] + C = ( + await TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"}) + )[0] + + # Pilot Persons + D = ( + await PilotPerson.get_or_create( + {"name": "Porco Rosso", "airplane": "Savoia-Marchetti"} + ) + )[0] + E = ( + await PilotPerson.get_or_create( + {"name": "Jack Dalton", "airplane": "Beechcraft Model 18"} + ) + )[0] + + # Some Person + F = await SomePerson(car_color="Blue").save() + + # TechnicalPersons can befriend PilotPersons and vice-versa and that's fine + await A.friends_with.connect(B) + await B.friends_with.connect(C) + await C.friends_with.connect(A) + await D.friends_with.connect(E) + await A.friends_with.connect(D) + await E.friends_with.connect(C) + + # Trying to befriend a Technical Person with Some Person should raise an + # exception + with pytest.raises(ValueError): + await A.friends_with.connect(F) + + await A.delete() + await B.delete() + await C.delete() + await D.delete() + await E.delete() + await F.delete() + + +@mark_async_test +async def test_failed_result_resolution(): + """ + A Neo4j driver node FROM the database contains labels that are unaware to + neomodel's Database class. This condition raises ClassDefinitionNotFound + exception. + """ + + class RandomPerson(BasePerson): + randomness = FloatProperty(default=random.random) + + # A Technical Person... + A = ( + await TechnicalPerson.get_or_create( + {"name": "Grumpy", "expertise": "Grumpiness"} + ) + )[0] + + # A Random Person... + B = (await RandomPerson.get_or_create({"name": "Mad Hatter"}))[0] + + await A.friends_with.connect(B) + + # Simulate the condition where the definition of class RandomPerson is not + # known yet. + del adb._NODE_CLASS_REGISTRY[frozenset(["RandomPerson", "BasePerson"])] + + # Now try to instantiate a RandomPerson + A = ( + await TechnicalPerson.get_or_create( + {"name": "Grumpy", "expertise": "Grumpiness"} + ) + )[0] + with pytest.raises( + NodeClassNotDefined, + match=r"Node with labels .* does not resolve to any of the known objects.*", + ): + friends = await A.friends_with.all() + for some_friend in friends: + print(some_friend.name) + + await A.delete() + await B.delete() + + +@mark_async_test +async def test_node_label_mismatch(): + """ + A Neo4j driver node FROM the database contains a superset of the known + labels. + """ + + class SuperTechnicalPerson(TechnicalPerson): + superness = FloatProperty(default=1.0) + + class UltraTechnicalPerson(SuperTechnicalPerson): + ultraness = FloatProperty(default=3.1415928) + + # Create a TechnicalPerson... + A = ( + await TechnicalPerson.get_or_create( + {"name": "Grumpy", "expertise": "Grumpiness"} + ) + )[0] + # ...that is connected to an UltraTechnicalPerson + F = await UltraTechnicalPerson( + name="Chewbaka", expertise="Aarrr wgh ggwaaah" + ).save() + await A.friends_with.connect(F) + + # Forget about the UltraTechnicalPerson + del adb._NODE_CLASS_REGISTRY[ + frozenset( + [ + "UltraTechnicalPerson", + "SuperTechnicalPerson", + "TechnicalPerson", + "BasePerson", + ] + ) + ] + + # Recall a TechnicalPerson and enumerate its friends. + # One of them is UltraTechnicalPerson which would be returned as a valid + # node to a friends_with query but is currently unknown to the node class registry. + A = ( + await TechnicalPerson.get_or_create( + {"name": "Grumpy", "expertise": "Grumpiness"} + ) + )[0] + with pytest.raises(NodeClassNotDefined): + friends = await A.friends_with.all() + for some_friend in friends: + print(some_friend.name) + + +def test_attempted_class_redefinition(): + """ + A StructuredNode class is attempted to be redefined. + """ + + def redefine_class_locally(): + # Since this test has already set up a class hierarchy in its global scope, we will try to redefine + # SomePerson here. + # The internal structure of the SomePerson entity does not matter at all here. + class SomePerson(BaseOtherPerson): + uid = UniqueIdProperty() + + with pytest.raises( + NodeClassAlreadyDefined, + match=r"Class .* with labels .* already defined:.*", + ): + redefine_class_locally() + + +@mark_async_test +async def test_relationship_result_resolution(): + """ + A query returning a "Relationship" object can now instantiate it to a data model class + """ + # Test specific data + A = await PilotPerson(name="Zantford Granville", airplane="Gee Bee Model R").save() + B = await PilotPerson(name="Thomas Granville", airplane="Gee Bee Model R").save() + C = await PilotPerson(name="Robert Granville", airplane="Gee Bee Model R").save() + D = await PilotPerson(name="Mark Granville", airplane="Gee Bee Model R").save() + E = await PilotPerson(name="Edward Granville", airplane="Gee Bee Model R").save() + + await A.friends_with.connect(B) + await B.friends_with.connect(C) + await C.friends_with.connect(D) + await D.friends_with.connect(E) + + query_data = await adb.cypher_query( + "MATCH (a:PilotPerson)-[r:FRIENDS_WITH]->(b:PilotPerson) " + "WHERE a.airplane='Gee Bee Model R' and b.airplane='Gee Bee Model R' " + "RETURN DISTINCT r", + resolve_objects=True, + ) + + # The relationship here should be properly instantiated to a `PersonalRelationship` object. + assert isinstance(query_data[0][0][0], PersonalRelationship) + + +@mark_async_test +async def test_properly_inherited_relationship(): + """ + A relationship class extends an existing relationship model that must extended the same previously associated + relationship label. + """ + + # Extends an existing relationship by adding the "relationship_strength" attribute. + # `ExtendedPersonalRelationship` will now substitute `PersonalRelationship` EVERYWHERE in the system. + class ExtendedPersonalRelationship(PersonalRelationship): + relationship_strength = FloatProperty(default=random.random) + + # Extends SomePerson, establishes "enriched" relationships with any BaseOtherPerson + class ExtendedSomePerson(SomePerson): + friends_with = AsyncRelationshipTo( + "BaseOtherPerson", + "FRIENDS_WITH", + model=ExtendedPersonalRelationship, + ) + + # Test specific data + A = await ExtendedSomePerson(name="Michael Knight", car_color="Black").save() + B = await ExtendedSomePerson(name="Luke Duke", car_color="Orange").save() + C = await ExtendedSomePerson(name="Michael Schumacher", car_color="Red").save() + + await A.friends_with.connect(B) + await A.friends_with.connect(C) + + query_data = await adb.cypher_query( + "MATCH (:ExtendedSomePerson)-[r:FRIENDS_WITH]->(:ExtendedSomePerson) " + "RETURN DISTINCT r", + resolve_objects=True, + ) + + assert isinstance(query_data[0][0][0], ExtendedPersonalRelationship) + + +def test_improperly_inherited_relationship(): + """ + Attempting to re-define an existing relationship with a completely unrelated class. + :return: + """ + + class NewRelationship(AsyncStructuredRel): + profile_match_factor = FloatProperty() + + with pytest.raises( + RelationshipClassRedefined, + match=r"Relationship of type .* redefined as .*", + ): + + class NewSomePerson(SomePerson): + friends_with = AsyncRelationshipTo( + "BaseOtherPerson", "FRIENDS_WITH", model=NewRelationship + ) + + +@mark_async_test +async def test_resolve_inexistent_relationship(): + """ + Attempting to resolve an inexistent relationship should raise an exception + :return: + """ + + # Forget about the FRIENDS_WITH Relationship. + del adb._NODE_CLASS_REGISTRY[frozenset(["FRIENDS_WITH"])] + + with pytest.raises( + RelationshipClassNotDefined, + match=r"Relationship of type .* does not resolve to any of the known objects.*", + ): + query_data = await adb.cypher_query( + "MATCH (:ExtendedSomePerson)-[r:FRIENDS_WITH]->(:ExtendedSomePerson) " + "RETURN DISTINCT r", + resolve_objects=True, + ) diff --git a/test/async_/test_issue600.py b/test/async_/test_issue600.py new file mode 100644 index 00000000..5f66f39e --- /dev/null +++ b/test/async_/test_issue600.py @@ -0,0 +1,87 @@ +""" +Provides a test case for issue 600 - "Pull request #592 cause an error in case of relationship inharitance". + +The issue is outlined here: https://github.com/neo4j-contrib/neomodel/issues/600 +""" + +from test._async_compat import mark_async_test + +from neomodel import AsyncRelationship, AsyncStructuredNode, AsyncStructuredRel + +try: + basestring +except NameError: + basestring = str + + +class Class1(AsyncStructuredRel): + pass + + +class SubClass1(Class1): + pass + + +class SubClass2(Class1): + pass + + +class RelationshipDefinerSecondSibling(AsyncStructuredNode): + rel_1 = AsyncRelationship( + "RelationshipDefinerSecondSibling", "SOME_REL_LABEL", model=Class1 + ) + rel_2 = AsyncRelationship( + "RelationshipDefinerSecondSibling", "SOME_REL_LABEL", model=SubClass1 + ) + rel_3 = AsyncRelationship( + "RelationshipDefinerSecondSibling", "SOME_REL_LABEL", model=SubClass2 + ) + + +class RelationshipDefinerParentLast(AsyncStructuredNode): + rel_2 = AsyncRelationship( + "RelationshipDefinerParentLast", "SOME_REL_LABEL", model=SubClass1 + ) + rel_3 = AsyncRelationship( + "RelationshipDefinerParentLast", "SOME_REL_LABEL", model=SubClass2 + ) + rel_1 = AsyncRelationship( + "RelationshipDefinerParentLast", "SOME_REL_LABEL", model=Class1 + ) + + +# Test cases +@mark_async_test +async def test_relationship_definer_second_sibling(): + # Create a few entities + A = (await RelationshipDefinerSecondSibling.get_or_create({}))[0] + B = (await RelationshipDefinerSecondSibling.get_or_create({}))[0] + C = (await RelationshipDefinerSecondSibling.get_or_create({}))[0] + + # Add connections + await A.rel_1.connect(B) + await B.rel_2.connect(C) + await C.rel_3.connect(A) + + # Clean up + await A.delete() + await B.delete() + await C.delete() + + +@mark_async_test +async def test_relationship_definer_parent_last(): + # Create a few entities + A = (await RelationshipDefinerParentLast.get_or_create({}))[0] + B = (await RelationshipDefinerParentLast.get_or_create({}))[0] + C = (await RelationshipDefinerParentLast.get_or_create({}))[0] + + # Add connections + await A.rel_1.connect(B) + await B.rel_2.connect(C) + await C.rel_3.connect(A) + + # Clean up + await A.delete() + await B.delete() + await C.delete() diff --git a/test/async_/test_label_drop.py b/test/async_/test_label_drop.py new file mode 100644 index 00000000..3d64050b --- /dev/null +++ b/test/async_/test_label_drop.py @@ -0,0 +1,47 @@ +from test._async_compat import mark_async_test + +from neo4j.exceptions import ClientError + +from neomodel import AsyncStructuredNode, StringProperty, adb + + +class ConstraintAndIndex(AsyncStructuredNode): + name = StringProperty(unique_index=True) + last_name = StringProperty(index=True) + + +@mark_async_test +async def test_drop_labels(): + await adb.install_labels(ConstraintAndIndex) + constraints_before = await adb.list_constraints() + indexes_before = await adb.list_indexes(exclude_token_lookup=True) + + assert len(constraints_before) > 0 + assert len(indexes_before) > 0 + + await adb.remove_all_labels() + + constraints = await adb.list_constraints() + indexes = await adb.list_indexes(exclude_token_lookup=True) + + assert len(constraints) == 0 + assert len(indexes) == 0 + + # Recreating all old constraints and indexes + for constraint in constraints_before: + constraint_type_clause = "UNIQUE" + if constraint["type"] == "NODE_PROPERTY_EXISTENCE": + constraint_type_clause = "NOT NULL" + elif constraint["type"] == "NODE_KEY": + constraint_type_clause = "NODE KEY" + + await adb.cypher_query( + f'CREATE CONSTRAINT {constraint["name"]} FOR (n:{constraint["labelsOrTypes"][0]}) REQUIRE n.{constraint["properties"][0]} IS {constraint_type_clause}' + ) + for index in indexes_before: + try: + await adb.cypher_query( + f'CREATE INDEX {index["name"]} FOR (n:{index["labelsOrTypes"][0]}) ON (n.{index["properties"][0]})' + ) + except ClientError: + pass diff --git a/test/async_/test_label_install.py b/test/async_/test_label_install.py new file mode 100644 index 00000000..2d710a19 --- /dev/null +++ b/test/async_/test_label_install.py @@ -0,0 +1,193 @@ +from test._async_compat import mark_async_test + +import pytest + +from neomodel import ( + AsyncRelationshipTo, + AsyncStructuredNode, + AsyncStructuredRel, + StringProperty, + UniqueIdProperty, + adb, +) +from neomodel.exceptions import ConstraintValidationFailed, FeatureNotSupported + + +class NodeWithIndex(AsyncStructuredNode): + name = StringProperty(index=True) + + +class NodeWithConstraint(AsyncStructuredNode): + name = StringProperty(unique_index=True) + + +class NodeWithRelationship(AsyncStructuredNode): + ... + + +class IndexedRelationship(AsyncStructuredRel): + indexed_rel_prop = StringProperty(index=True) + + +class OtherNodeWithRelationship(AsyncStructuredNode): + has_rel = AsyncRelationshipTo( + NodeWithRelationship, "INDEXED_REL", model=IndexedRelationship + ) + + +class AbstractNode(AsyncStructuredNode): + __abstract_node__ = True + name = StringProperty(unique_index=True) + + +class SomeNotUniqueNode(AsyncStructuredNode): + id_ = UniqueIdProperty(db_property="id") + + +@mark_async_test +async def test_install_all(): + await adb.drop_constraints() + await adb.drop_indexes() + await adb.install_labels(AbstractNode) + # run install all labels + await adb.install_all_labels() + + indexes = await adb.list_indexes() + index_names = [index["name"] for index in indexes] + assert "index_INDEXED_REL_indexed_rel_prop" in index_names + + constraints = await adb.list_constraints() + constraint_names = [constraint["name"] for constraint in constraints] + assert "constraint_unique_NodeWithConstraint_name" in constraint_names + assert "constraint_unique_SomeNotUniqueNode_id" in constraint_names + + # remove constraint for above test + await _drop_constraints_for_label_and_property("NoConstraintsSetup", "name") + + +@mark_async_test +async def test_install_label_twice(capsys): + await adb.drop_constraints() + await adb.drop_indexes() + expected_std_out = ( + "{code: Neo.ClientError.Schema.EquivalentSchemaRuleAlreadyExists}" + ) + await adb.install_labels(AbstractNode) + await adb.install_labels(AbstractNode) + + await adb.install_labels(NodeWithIndex) + await adb.install_labels(NodeWithIndex, quiet=False) + captured = capsys.readouterr() + assert expected_std_out in captured.out + + await adb.install_labels(NodeWithConstraint) + await adb.install_labels(NodeWithConstraint, quiet=False) + captured = capsys.readouterr() + assert expected_std_out in captured.out + + await adb.install_labels(OtherNodeWithRelationship) + await adb.install_labels(OtherNodeWithRelationship, quiet=False) + captured = capsys.readouterr() + assert expected_std_out in captured.out + + if await adb.version_is_higher_than("5.7"): + + class UniqueIndexRelationship(AsyncStructuredRel): + unique_index_rel_prop = StringProperty(unique_index=True) + + class OtherNodeWithUniqueIndexRelationship(AsyncStructuredNode): + has_rel = AsyncRelationshipTo( + NodeWithRelationship, "UNIQUE_INDEX_REL", model=UniqueIndexRelationship + ) + + await adb.install_labels(OtherNodeWithUniqueIndexRelationship) + await adb.install_labels(OtherNodeWithUniqueIndexRelationship, quiet=False) + captured = capsys.readouterr() + assert expected_std_out in captured.out + + +@mark_async_test +async def test_install_labels_db_property(capsys): + await adb.drop_constraints() + await adb.install_labels(SomeNotUniqueNode, quiet=False) + captured = capsys.readouterr() + assert "id" in captured.out + # make sure that the id_ constraint doesn't exist + constraint_names = await _drop_constraints_for_label_and_property( + "SomeNotUniqueNode", "id_" + ) + assert constraint_names == [] + # make sure the id constraint exists and can be removed + await _drop_constraints_for_label_and_property("SomeNotUniqueNode", "id") + + +@mark_async_test +async def test_relationship_unique_index_not_supported(): + if await adb.version_is_higher_than("5.7"): + pytest.skip("Not supported before 5.7") + + class UniqueIndexRelationship(AsyncStructuredRel): + name = StringProperty(unique_index=True) + + class TargetNodeForUniqueIndexRelationship(AsyncStructuredNode): + pass + + with pytest.raises( + FeatureNotSupported, match=r".*Please upgrade to Neo4j 5.7 or higher" + ): + + class NodeWithUniqueIndexRelationship(AsyncStructuredNode): + has_rel = AsyncRelationshipTo( + TargetNodeForUniqueIndexRelationship, + "UNIQUE_INDEX_REL", + model=UniqueIndexRelationship, + ) + + await adb.install_labels(NodeWithUniqueIndexRelationship) + + +@mark_async_test +async def test_relationship_unique_index(): + if not await adb.version_is_higher_than("5.7"): + pytest.skip("Not supported before 5.7") + + class UniqueIndexRelationshipBis(AsyncStructuredRel): + name = StringProperty(unique_index=True) + + class TargetNodeForUniqueIndexRelationship(AsyncStructuredNode): + pass + + class NodeWithUniqueIndexRelationship(AsyncStructuredNode): + has_rel = AsyncRelationshipTo( + TargetNodeForUniqueIndexRelationship, + "UNIQUE_INDEX_REL_BIS", + model=UniqueIndexRelationshipBis, + ) + + await adb.install_labels(NodeWithUniqueIndexRelationship) + node1 = await NodeWithUniqueIndexRelationship().save() + node2 = await TargetNodeForUniqueIndexRelationship().save() + node3 = await TargetNodeForUniqueIndexRelationship().save() + rel1 = await node1.has_rel.connect(node2, {"name": "rel1"}) + + with pytest.raises( + ConstraintValidationFailed, + match=r".*already exists with type `UNIQUE_INDEX_REL_BIS` and property `name`.*", + ): + rel2 = await node1.has_rel.connect(node3, {"name": "rel1"}) + + +async def _drop_constraints_for_label_and_property( + label: str = None, property: str = None +): + results, meta = await adb.cypher_query("SHOW CONSTRAINTS") + results_as_dict = [dict(zip(meta, row)) for row in results] + constraint_names = [ + constraint + for constraint in results_as_dict + if constraint["labelsOrTypes"] == label and constraint["properties"] == property + ] + for constraint_name in constraint_names: + await adb.cypher_query(f"DROP CONSTRAINT {constraint_name}") + + return constraint_names diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py new file mode 100644 index 00000000..87c1ca8e --- /dev/null +++ b/test/async_/test_match_api.py @@ -0,0 +1,569 @@ +from datetime import datetime +from test._async_compat import mark_async_test + +from pytest import raises + +from neomodel import ( + INCOMING, + AsyncRelationshipFrom, + AsyncRelationshipTo, + AsyncStructuredNode, + AsyncStructuredRel, + DateTimeProperty, + IntegerProperty, + Q, + StringProperty, +) +from neomodel._async_compat.util import AsyncUtil +from neomodel.async_.match import ( + AsyncNodeSet, + AsyncQueryBuilder, + AsyncTraversal, + Optional, +) +from neomodel.exceptions import MultipleNodesReturned + + +class SupplierRel(AsyncStructuredRel): + since = DateTimeProperty(default=datetime.now) + courier = StringProperty() + + +class Supplier(AsyncStructuredNode): + name = StringProperty() + delivery_cost = IntegerProperty() + coffees = AsyncRelationshipTo("Coffee", "COFFEE SUPPLIERS") + + +class Species(AsyncStructuredNode): + name = StringProperty() + coffees = AsyncRelationshipFrom( + "Coffee", "COFFEE SPECIES", model=AsyncStructuredRel + ) + + +class Coffee(AsyncStructuredNode): + name = StringProperty(unique_index=True) + price = IntegerProperty() + suppliers = AsyncRelationshipFrom(Supplier, "COFFEE SUPPLIERS", model=SupplierRel) + species = AsyncRelationshipTo(Species, "COFFEE SPECIES", model=AsyncStructuredRel) + id_ = IntegerProperty() + + +class Extension(AsyncStructuredNode): + extension = AsyncRelationshipTo("Extension", "extension") + + +@mark_async_test +async def test_filter_exclude_via_labels(): + await Coffee(name="Java", price=99).save() + + node_set = AsyncNodeSet(Coffee) + qb = await AsyncQueryBuilder(node_set).build_ast() + + results = await qb._execute() + + assert "(coffee:Coffee)" in qb._ast.match + assert qb._ast.result_class + assert len(results) == 1 + assert isinstance(results[0], Coffee) + assert results[0].name == "Java" + + # with filter and exclude + await Coffee(name="Kenco", price=3).save() + node_set = node_set.filter(price__gt=2).exclude(price__gt=6, name="Java") + qb = await AsyncQueryBuilder(node_set).build_ast() + + results = await qb._execute() + assert "(coffee:Coffee)" in qb._ast.match + assert "NOT" in qb._ast.where[0] + assert len(results) == 1 + assert results[0].name == "Kenco" + + +@mark_async_test +async def test_simple_has_via_label(): + nescafe = await Coffee(name="Nescafe", price=99).save() + tesco = await Supplier(name="Tesco", delivery_cost=2).save() + await nescafe.suppliers.connect(tesco) + + ns = AsyncNodeSet(Coffee).has(suppliers=True) + qb = await AsyncQueryBuilder(ns).build_ast() + results = await qb._execute() + assert "COFFEE SUPPLIERS" in qb._ast.where[0] + assert len(results) == 1 + assert results[0].name == "Nescafe" + + await Coffee(name="nespresso", price=99).save() + ns = AsyncNodeSet(Coffee).has(suppliers=False) + qb = await AsyncQueryBuilder(ns).build_ast() + results = await qb._execute() + assert len(results) > 0 + assert "NOT" in qb._ast.where[0] + + +@mark_async_test +async def test_get(): + await Coffee(name="1", price=3).save() + assert await Coffee.nodes.get(name="1") + + with raises(Coffee.DoesNotExist): + await Coffee.nodes.get(name="2") + + await Coffee(name="2", price=3).save() + + with raises(MultipleNodesReturned): + await Coffee.nodes.get(price=3) + + +@mark_async_test +async def test_simple_traverse_with_filter(): + nescafe = await Coffee(name="Nescafe2", price=99).save() + tesco = await Supplier(name="Sainsburys", delivery_cost=2).save() + await nescafe.suppliers.connect(tesco) + + qb = AsyncQueryBuilder( + AsyncNodeSet(source=nescafe).suppliers.match(since__lt=datetime.now()) + ) + + _ast = await qb.build_ast() + results = await _ast._execute() + + assert qb._ast.lookup + assert qb._ast.match + assert qb._ast.return_clause.startswith("suppliers") + assert len(results) == 1 + assert results[0].name == "Sainsburys" + + +@mark_async_test +async def test_double_traverse(): + nescafe = await Coffee(name="Nescafe plus", price=99).save() + tesco = await Supplier(name="Asda", delivery_cost=2).save() + await nescafe.suppliers.connect(tesco) + await tesco.coffees.connect(await Coffee(name="Decafe", price=2).save()) + + ns = AsyncNodeSet(AsyncNodeSet(source=nescafe).suppliers.match()).coffees.match() + qb = await AsyncQueryBuilder(ns).build_ast() + + results = await qb._execute() + assert len(results) == 2 + assert results[0].name == "Decafe" + assert results[1].name == "Nescafe plus" + + +@mark_async_test +async def test_count(): + await Coffee(name="Nescafe Gold", price=99).save() + ast = await AsyncQueryBuilder(AsyncNodeSet(source=Coffee)).build_ast() + count = await ast._count() + assert count > 0 + + await Coffee(name="Kawa", price=27).save() + node_set = AsyncNodeSet(source=Coffee) + node_set.skip = 1 + node_set.limit = 1 + ast = await AsyncQueryBuilder(node_set).build_ast() + count = await ast._count() + assert count == 1 + + +@mark_async_test +async def test_len_and_iter_and_bool(): + iterations = 0 + + await Coffee(name="Icelands finest").save() + + for c in await Coffee.nodes: + iterations += 1 + await c.delete() + + assert iterations > 0 + + assert len(await Coffee.nodes) == 0 + + +@mark_async_test +async def test_slice(): + for c in await Coffee.nodes: + await c.delete() + + await Coffee(name="Icelands finest").save() + await Coffee(name="Britains finest").save() + await Coffee(name="Japans finest").save() + + # Branching tests because async needs extra brackets + if AsyncUtil.is_async_code: + assert len(list((await Coffee.nodes)[1:])) == 2 + assert len(list((await Coffee.nodes)[:1])) == 1 + assert isinstance((await Coffee.nodes)[1], Coffee) + assert isinstance((await Coffee.nodes)[0], Coffee) + assert len(list((await Coffee.nodes)[1:2])) == 1 + else: + assert len(list(Coffee.nodes[1:])) == 2 + assert len(list(Coffee.nodes[:1])) == 1 + assert isinstance(Coffee.nodes[1], Coffee) + assert isinstance(Coffee.nodes[0], Coffee) + assert len(list(Coffee.nodes[1:2])) == 1 + + +@mark_async_test +async def test_issue_208(): + # calls to match persist across queries. + + b = await Coffee(name="basics").save() + l = await Supplier(name="lidl").save() + a = await Supplier(name="aldi").save() + + await b.suppliers.connect(l, {"courier": "fedex"}) + await b.suppliers.connect(a, {"courier": "dhl"}) + + assert len(await b.suppliers.match(courier="fedex")) + assert len(await b.suppliers.match(courier="dhl")) + + +@mark_async_test +async def test_issue_589(): + node1 = await Extension().save() + node2 = await Extension().save() + assert node2 not in await node1.extension + await node1.extension.connect(node2) + assert node2 in await node1.extension + + +@mark_async_test +async def test_contains(): + expensive = await Coffee(price=1000, name="Pricey").save() + asda = await Coffee(name="Asda", price=1).save() + + assert expensive in await Coffee.nodes.filter(price__gt=999) + assert asda not in await Coffee.nodes.filter(price__gt=999) + + # bad value raises + with raises(ValueError, match=r"Expecting StructuredNode instance"): + if AsyncUtil.is_async_code: + assert await Coffee.nodes.check_contains(2) + else: + assert 2 in Coffee.nodes + + # unsaved + with raises(ValueError, match=r"Unsaved node"): + if AsyncUtil.is_async_code: + assert await Coffee.nodes.check_contains(Coffee()) + else: + assert Coffee() in Coffee.nodes + + +@mark_async_test +async def test_order_by(): + for c in await Coffee.nodes: + await c.delete() + + c1 = await Coffee(name="Icelands finest", price=5).save() + c2 = await Coffee(name="Britains finest", price=10).save() + c3 = await Coffee(name="Japans finest", price=35).save() + + if AsyncUtil.is_async_code: + assert ((await Coffee.nodes.order_by("price"))[0]).price == 5 + assert ((await Coffee.nodes.order_by("-price"))[0]).price == 35 + else: + assert (Coffee.nodes.order_by("price")[0]).price == 5 + assert (Coffee.nodes.order_by("-price")[0]).price == 35 + + ns = Coffee.nodes.order_by("-price") + qb = await AsyncQueryBuilder(ns).build_ast() + assert qb._ast.order_by + ns = ns.order_by(None) + qb = await AsyncQueryBuilder(ns).build_ast() + assert not qb._ast.order_by + ns = ns.order_by("?") + qb = await AsyncQueryBuilder(ns).build_ast() + assert qb._ast.with_clause == "coffee, rand() as r" + assert qb._ast.order_by == "r" + + with raises( + ValueError, + match=r".*Neo4j internals like id or element_id are not allowed for use in this operation.", + ): + await Coffee.nodes.order_by("id") + + # Test order by on a relationship + l = await Supplier(name="lidl2").save() + await l.coffees.connect(c1) + await l.coffees.connect(c2) + await l.coffees.connect(c3) + + ordered_n = [n for n in await l.coffees.order_by("name")] + assert ordered_n[0] == c2 + assert ordered_n[1] == c1 + assert ordered_n[2] == c3 + + +@mark_async_test +async def test_extra_filters(): + for c in await Coffee.nodes: + await c.delete() + + c1 = await Coffee(name="Icelands finest", price=5, id_=1).save() + c2 = await Coffee(name="Britains finest", price=10, id_=2).save() + c3 = await Coffee(name="Japans finest", price=35, id_=3).save() + c4 = await Coffee(name="US extra-fine", price=None, id_=4).save() + + coffees_5_10 = await Coffee.nodes.filter(price__in=[10, 5]) + assert len(coffees_5_10) == 2, "unexpected number of results" + assert c1 in coffees_5_10, "doesnt contain 5 price coffee" + assert c2 in coffees_5_10, "doesnt contain 10 price coffee" + + finest_coffees = await Coffee.nodes.filter(name__iendswith=" Finest") + assert len(finest_coffees) == 3, "unexpected number of results" + assert c1 in finest_coffees, "doesnt contain 1st finest coffee" + assert c2 in finest_coffees, "doesnt contain 2nd finest coffee" + assert c3 in finest_coffees, "doesnt contain 3rd finest coffee" + + unpriced_coffees = await Coffee.nodes.filter(price__isnull=True) + assert len(unpriced_coffees) == 1, "unexpected number of results" + assert c4 in unpriced_coffees, "doesnt contain unpriced coffee" + + coffees_with_id_gte_3 = await Coffee.nodes.filter(id___gte=3) + assert len(coffees_with_id_gte_3) == 2, "unexpected number of results" + assert c3 in coffees_with_id_gte_3 + assert c4 in coffees_with_id_gte_3 + + with raises( + ValueError, + match=r".*Neo4j internals like id or element_id are not allowed for use in this operation.", + ): + await Coffee.nodes.filter(elementId="4:xxx:111").all() + + +def test_traversal_definition_keys_are_valid(): + muckefuck = Coffee(name="Mukkefuck", price=1) + + with raises(ValueError): + AsyncTraversal( + muckefuck, + "a_name", + { + "node_class": Supplier, + "direction": INCOMING, + "relationship_type": "KNOWS", + "model": None, + }, + ) + + AsyncTraversal( + muckefuck, + "a_name", + { + "node_class": Supplier, + "direction": INCOMING, + "relation_type": "KNOWS", + "model": None, + }, + ) + + +@mark_async_test +async def test_empty_filters(): + """Test this case: + ``` + SomeModel.nodes.filter().filter(Q(arg1=val1)).all() + SomeModel.nodes.exclude().exclude(Q(arg1=val1)).all() + SomeModel.nodes.filter().filter(arg1=val1).all() + ``` + In django_rest_framework filter uses such as lazy function and + ``get_queryset`` function in ``GenericAPIView`` should returns + ``NodeSet`` object. + """ + + for c in await Coffee.nodes: + await c.delete() + + c1 = await Coffee(name="Super", price=5, id_=1).save() + c2 = await Coffee(name="Puper", price=10, id_=2).save() + + empty_filter = Coffee.nodes.filter() + + all_coffees = await empty_filter.all() + assert len(all_coffees) == 2, "unexpected number of results" + + filter_empty_filter = empty_filter.filter(price=5) + assert len(await filter_empty_filter.all()) == 1, "unexpected number of results" + assert ( + c1 in await filter_empty_filter.all() + ), "doesnt contain c1 in ``filter_empty_filter``" + + filter_q_empty_filter = empty_filter.filter(Q(price=5)) + assert len(await filter_empty_filter.all()) == 1, "unexpected number of results" + assert ( + c1 in await filter_empty_filter.all() + ), "doesnt contain c1 in ``filter_empty_filter``" + + +@mark_async_test +async def test_q_filters(): + # Test where no children and self.connector != conn ? + for c in await Coffee.nodes: + await c.delete() + + c1 = await Coffee(name="Icelands finest", price=5, id_=1).save() + c2 = await Coffee(name="Britains finest", price=10, id_=2).save() + c3 = await Coffee(name="Japans finest", price=35, id_=3).save() + c4 = await Coffee(name="US extra-fine", price=None, id_=4).save() + c5 = await Coffee(name="Latte", price=35, id_=5).save() + c6 = await Coffee(name="Cappuccino", price=35, id_=6).save() + + coffees_5_10 = await Coffee.nodes.filter(Q(price=10) | Q(price=5)).all() + assert len(coffees_5_10) == 2, "unexpected number of results" + assert c1 in coffees_5_10, "doesnt contain 5 price coffee" + assert c2 in coffees_5_10, "doesnt contain 10 price coffee" + + coffees_5_6 = ( + await Coffee.nodes.filter(Q(name="Latte") | Q(name="Cappuccino")) + .filter(price=35) + .all() + ) + assert len(coffees_5_6) == 2, "unexpected number of results" + assert c5 in coffees_5_6, "doesnt contain 5 coffee" + assert c6 in coffees_5_6, "doesnt contain 6 coffee" + + coffees_5_6 = ( + await Coffee.nodes.filter(price=35) + .filter(Q(name="Latte") | Q(name="Cappuccino")) + .all() + ) + assert len(coffees_5_6) == 2, "unexpected number of results" + assert c5 in coffees_5_6, "doesnt contain 5 coffee" + assert c6 in coffees_5_6, "doesnt contain 6 coffee" + + finest_coffees = await Coffee.nodes.filter(name__iendswith=" Finest").all() + assert len(finest_coffees) == 3, "unexpected number of results" + assert c1 in finest_coffees, "doesnt contain 1st finest coffee" + assert c2 in finest_coffees, "doesnt contain 2nd finest coffee" + assert c3 in finest_coffees, "doesnt contain 3rd finest coffee" + + unpriced_coffees = await Coffee.nodes.filter(Q(price__isnull=True)).all() + assert len(unpriced_coffees) == 1, "unexpected number of results" + assert c4 in unpriced_coffees, "doesnt contain unpriced coffee" + + coffees_with_id_gte_3 = await Coffee.nodes.filter(Q(id___gte=3)).all() + assert len(coffees_with_id_gte_3) == 4, "unexpected number of results" + assert c3 in coffees_with_id_gte_3 + assert c4 in coffees_with_id_gte_3 + assert c5 in coffees_with_id_gte_3 + assert c6 in coffees_with_id_gte_3 + + coffees_5_not_japans = await Coffee.nodes.filter( + Q(price__gt=5) & ~Q(name="Japans finest") + ).all() + assert c3 not in coffees_5_not_japans + + empty_Q_condition = await Coffee.nodes.filter(Q(price=5) | Q()).all() + assert ( + len(empty_Q_condition) == 1 + ), "undefined Q leading to unexpected number of results" + assert c1 in empty_Q_condition + + combined_coffees = await Coffee.nodes.filter( + Q(price=35), Q(name="Latte") | Q(name="Cappuccino") + ).all() + assert len(combined_coffees) == 2 + assert c5 in combined_coffees + assert c6 in combined_coffees + assert c3 not in combined_coffees + + class QQ: + pass + + with raises(TypeError): + wrong_Q = await Coffee.nodes.filter(Q(price=5) | QQ()).all() + + +def test_qbase(): + test_print_out = str(Q(price=5) | Q(price=10)) + test_repr = repr(Q(price=5) | Q(price=10)) + assert test_print_out == "(OR: ('price', 5), ('price', 10))" + assert test_repr == "" + + assert ("price", 5) in (Q(price=5) | Q(price=10)) + + test_hash = set([Q(price_lt=30) | ~Q(price=5), Q(price_lt=30) | ~Q(price=5)]) + assert len(test_hash) == 1 + + +@mark_async_test +async def test_traversal_filter_left_hand_statement(): + nescafe = await Coffee(name="Nescafe2", price=99).save() + nescafe_gold = await Coffee(name="Nescafe gold", price=11).save() + + tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() + biedronka = await Supplier(name="Biedronka", delivery_cost=5).save() + lidl = await Supplier(name="Lidl", delivery_cost=3).save() + + await nescafe.suppliers.connect(tesco) + await nescafe_gold.suppliers.connect(biedronka) + await nescafe_gold.suppliers.connect(lidl) + + lidl_supplier = ( + await AsyncNodeSet(Coffee.nodes.filter(price=11).suppliers) + .filter(delivery_cost=3) + .all() + ) + + assert lidl in lidl_supplier + + +@mark_async_test +async def test_fetch_relations(): + arabica = await Species(name="Arabica").save() + robusta = await Species(name="Robusta").save() + nescafe = await Coffee(name="Nescafe 1000", price=99).save() + nescafe_gold = await Coffee(name="Nescafe 1001", price=11).save() + + tesco = await Supplier(name="Sainsburys", delivery_cost=3).save() + await nescafe.suppliers.connect(tesco) + await nescafe_gold.suppliers.connect(tesco) + await nescafe.species.connect(arabica) + + result = ( + await Supplier.nodes.filter(name="Sainsburys") + .fetch_relations("coffees__species") + .all() + ) + assert arabica in result[0] + assert robusta not in result[0] + assert tesco in result[0] + assert nescafe in result[0] + assert nescafe_gold not in result[0] + + result = ( + await Species.nodes.filter(name="Robusta") + .fetch_relations(Optional("coffees__suppliers")) + .all() + ) + assert result[0][0] is None + + if AsyncUtil.is_async_code: + count = ( + await Supplier.nodes.filter(name="Sainsburys") + .fetch_relations("coffees__species") + .get_len() + ) + assert count == 1 + + assert ( + await Supplier.nodes.fetch_relations("coffees__species") + .filter(name="Sainsburys") + .check_contains(tesco) + ) + else: + count = len( + Supplier.nodes.filter(name="Sainsburys") + .fetch_relations("coffees__species") + .all() + ) + assert count == 1 + + assert tesco in Supplier.nodes.fetch_relations("coffees__species").filter( + name="Sainsburys" + ) diff --git a/test/async_/test_migration_neo4j_5.py b/test/async_/test_migration_neo4j_5.py new file mode 100644 index 00000000..48c0e8c4 --- /dev/null +++ b/test/async_/test_migration_neo4j_5.py @@ -0,0 +1,80 @@ +from test._async_compat import mark_async_test + +import pytest + +from neomodel import ( + AsyncRelationshipTo, + AsyncStructuredNode, + AsyncStructuredRel, + IntegerProperty, + StringProperty, + adb, +) + + +class Album(AsyncStructuredNode): + name = StringProperty() + + +class Released(AsyncStructuredRel): + year = IntegerProperty() + + +class Band(AsyncStructuredNode): + name = StringProperty() + released = AsyncRelationshipTo(Album, relation_type="RELEASED", model=Released) + + +@mark_async_test +async def test_read_elements_id(): + the_hives = await Band(name="The Hives").save() + lex_hives = await Album(name="Lex Hives").save() + released_rel = await the_hives.released.connect(lex_hives) + + # Validate element_id properties + assert lex_hives.element_id == (await the_hives.released.single()).element_id + assert released_rel._start_node_element_id == the_hives.element_id + assert released_rel._end_node_element_id == lex_hives.element_id + + # Validate id properties + # Behaviour is dependent on Neo4j version + db_version = await adb.database_version + if db_version.startswith("4"): + # Nodes' ids + assert lex_hives.id == int(lex_hives.element_id) + assert lex_hives.id == (await the_hives.released.single()).id + # Relationships' ids + assert isinstance(released_rel.element_id, str) + assert int(released_rel.element_id) == released_rel.id + assert released_rel._start_node_id == int(the_hives.element_id) + assert released_rel._end_node_id == int(lex_hives.element_id) + else: + # Nodes' ids + expected_error_type = ValueError + expected_error_message = "id is deprecated in Neo4j version 5, please migrate to element_id\. If you use the id in a Cypher query, replace id\(\) by elementId\(\)\." + assert isinstance(lex_hives.element_id, str) + with pytest.raises( + expected_error_type, + match=expected_error_message, + ): + lex_hives.id + + # Relationships' ids + assert isinstance(released_rel.element_id, str) + assert isinstance(released_rel._start_node_element_id, str) + assert isinstance(released_rel._end_node_element_id, str) + with pytest.raises( + expected_error_type, + match=expected_error_message, + ): + released_rel.id + with pytest.raises( + expected_error_type, + match=expected_error_message, + ): + released_rel._start_node_id + with pytest.raises( + expected_error_type, + match=expected_error_message, + ): + released_rel._end_node_id diff --git a/test/async_/test_models.py b/test/async_/test_models.py new file mode 100644 index 00000000..b9bb2e44 --- /dev/null +++ b/test/async_/test_models.py @@ -0,0 +1,360 @@ +from __future__ import print_function + +from datetime import datetime +from test._async_compat import mark_async_test + +from pytest import raises + +from neomodel import ( + AsyncStructuredNode, + AsyncStructuredRel, + DateProperty, + IntegerProperty, + StringProperty, + adb, +) +from neomodel.exceptions import RequiredProperty, UniqueProperty + + +class User(AsyncStructuredNode): + email = StringProperty(unique_index=True, required=True) + age = IntegerProperty(index=True) + + @property + def email_alias(self): + return self.email + + @email_alias.setter # noqa + def email_alias(self, value): + self.email = value + + +class NodeWithoutProperty(AsyncStructuredNode): + pass + + +@mark_async_test +async def test_issue_233(): + class BaseIssue233(AsyncStructuredNode): + __abstract_node__ = True + + def __getitem__(self, item): + return self.__dict__[item] + + class Issue233(BaseIssue233): + uid = StringProperty(unique_index=True, required=True) + + i = await Issue233(uid="testgetitem").save() + assert i["uid"] == "testgetitem" + + +def test_issue_72(): + user = User(email="foo@bar.com") + assert user.age is None + + +@mark_async_test +async def test_required(): + with raises(RequiredProperty): + await User(age=3).save() + + +def test_repr_and_str(): + u = User(email="robin@test.com", age=3) + assert repr(u) == "" + assert str(u) == "{'email': 'robin@test.com', 'age': 3}" + + +@mark_async_test +async def test_get_and_get_or_none(): + u = User(email="robin@test.com", age=3) + assert await u.save() + rob = await User.nodes.get(email="robin@test.com") + assert rob.email == "robin@test.com" + assert rob.age == 3 + + rob = await User.nodes.get_or_none(email="robin@test.com") + assert rob.email == "robin@test.com" + + n = await User.nodes.get_or_none(email="robin@nothere.com") + assert n is None + + +@mark_async_test +async def test_first_and_first_or_none(): + u = User(email="matt@test.com", age=24) + assert await u.save() + u2 = User(email="tbrady@test.com", age=40) + assert await u2.save() + tbrady = await User.nodes.order_by("-age").first() + assert tbrady.email == "tbrady@test.com" + assert tbrady.age == 40 + + tbrady = await User.nodes.order_by("-age").first_or_none() + assert tbrady.email == "tbrady@test.com" + + n = await User.nodes.first_or_none(email="matt@nothere.com") + assert n is None + + +def test_bare_init_without_save(): + """ + If a node model is initialised without being saved, accessing its `element_id` should + return None. + """ + assert User().element_id is None + + +@mark_async_test +async def test_save_to_model(): + u = User(email="jim@test.com", age=3) + assert await u.save() + assert u.element_id is not None + assert u.email == "jim@test.com" + assert u.age == 3 + + +@mark_async_test +async def test_save_node_without_properties(): + n = NodeWithoutProperty() + assert await n.save() + assert n.element_id is not None + + +@mark_async_test +async def test_unique(): + await adb.install_labels(User) + await User(email="jim1@test.com", age=3).save() + with raises(UniqueProperty): + await User(email="jim1@test.com", age=3).save() + + +@mark_async_test +async def test_update_unique(): + u = await User(email="jimxx@test.com", age=3).save() + await u.save() # this shouldn't fail + + +@mark_async_test +async def test_update(): + user = await User(email="jim2@test.com", age=3).save() + assert user + user.email = "jim2000@test.com" + await user.save() + jim = await User.nodes.get(email="jim2000@test.com") + assert jim + assert jim.email == "jim2000@test.com" + + +@mark_async_test +async def test_save_through_magic_property(): + user = await User(email_alias="blah@test.com", age=8).save() + assert user.email_alias == "blah@test.com" + user = await User.nodes.get(email="blah@test.com") + assert user.email == "blah@test.com" + assert user.email_alias == "blah@test.com" + + user1 = await User(email="blah1@test.com", age=8).save() + assert user1.email_alias == "blah1@test.com" + user1.email_alias = "blah2@test.com" + assert await user1.save() + user2 = await User.nodes.get(email="blah2@test.com") + assert user2 + + +class Customer2(AsyncStructuredNode): + __label__ = "customers" + email = StringProperty(unique_index=True, required=True) + age = IntegerProperty(index=True) + + +@mark_async_test +async def test_not_updated_on_unique_error(): + await adb.install_labels(Customer2) + await Customer2(email="jim@bob.com", age=7).save() + test = await Customer2(email="jim1@bob.com", age=2).save() + test.email = "jim@bob.com" + with raises(UniqueProperty): + await test.save() + customers = await Customer2.nodes + assert customers[0].email != customers[1].email + assert (await Customer2.nodes.get(email="jim@bob.com")).age == 7 + assert (await Customer2.nodes.get(email="jim1@bob.com")).age == 2 + + +@mark_async_test +async def test_label_not_inherited(): + class Customer3(Customer2): + address = StringProperty() + + assert Customer3.__label__ == "Customer3" + c = await Customer3(email="test@test.com").save() + assert "customers" in await c.labels() + assert "Customer3" in await c.labels() + + c = await Customer2.nodes.get(email="test@test.com") + assert isinstance(c, Customer2) + assert "customers" in await c.labels() + assert "Customer3" in await c.labels() + + +@mark_async_test +async def test_refresh(): + c = await Customer2(email="my@email.com", age=16).save() + c.my_custom_prop = "value" + copy = await Customer2.nodes.get(email="my@email.com") + copy.age = 20 + await copy.save() + + assert c.age == 16 + + await c.refresh() + assert c.age == 20 + assert c.my_custom_prop == "value" + + c = Customer2.inflate(c.element_id) + c.age = 30 + await c.refresh() + + assert c.age == 20 + + _db_version = await adb.database_version + if _db_version.startswith("4"): + c = Customer2.inflate(999) + else: + c = Customer2.inflate("4:xxxxxx:999") + with raises(Customer2.DoesNotExist): + await c.refresh() + + +@mark_async_test +async def test_setting_value_to_none(): + c = await Customer2(email="alice@bob.com", age=42).save() + assert c.age is not None + + c.age = None + await c.save() + + copy = await Customer2.nodes.get(email="alice@bob.com") + assert copy.age is None + + +@mark_async_test +async def test_inheritance(): + class User(AsyncStructuredNode): + __abstract_node__ = True + name = StringProperty(unique_index=True) + + class Shopper(User): + balance = IntegerProperty(index=True) + + async def credit_account(self, amount): + self.balance = self.balance + int(amount) + await self.save() + + jim = await Shopper(name="jimmy", balance=300).save() + await jim.credit_account(50) + + assert Shopper.__label__ == "Shopper" + assert jim.balance == 350 + assert len(jim.inherited_labels()) == 1 + assert len(await jim.labels()) == 1 + assert (await jim.labels())[0] == "Shopper" + + +@mark_async_test +async def test_inherited_optional_labels(): + class BaseOptional(AsyncStructuredNode): + __optional_labels__ = ["Alive"] + name = StringProperty(unique_index=True) + + class ExtendedOptional(BaseOptional): + __optional_labels__ = ["RewardsMember"] + balance = IntegerProperty(index=True) + + async def credit_account(self, amount): + self.balance = self.balance + int(amount) + await self.save() + + henry = await ExtendedOptional(name="henry", balance=300).save() + await henry.credit_account(50) + + assert ExtendedOptional.__label__ == "ExtendedOptional" + assert henry.balance == 350 + assert len(henry.inherited_labels()) == 2 + assert len(await henry.labels()) == 2 + + assert set(henry.inherited_optional_labels()) == {"Alive", "RewardsMember"} + + +@mark_async_test +async def test_mixins(): + class UserMixin: + name = StringProperty(unique_index=True) + password = StringProperty() + + class CreditMixin: + balance = IntegerProperty(index=True) + + async def credit_account(self, amount): + self.balance = self.balance + int(amount) + await self.save() + + class Shopper2(AsyncStructuredNode, UserMixin, CreditMixin): + pass + + jim = await Shopper2(name="jimmy", balance=300).save() + await jim.credit_account(50) + + assert Shopper2.__label__ == "Shopper2" + assert jim.balance == 350 + assert len(jim.inherited_labels()) == 1 + assert len(await jim.labels()) == 1 + assert (await jim.labels())[0] == "Shopper2" + + +@mark_async_test +async def test_date_property(): + class DateTest(AsyncStructuredNode): + birthdate = DateProperty() + + user = await DateTest(birthdate=datetime.now()).save() + + +def test_reserved_property_keys(): + error_match = r".*is not allowed as it conflicts with neomodel internals.*" + with raises(ValueError, match=error_match): + + class ReservedPropertiesDeletedNode(AsyncStructuredNode): + deleted = StringProperty() + + with raises(ValueError, match=error_match): + + class ReservedPropertiesIdNode(AsyncStructuredNode): + id = StringProperty() + + with raises(ValueError, match=error_match): + + class ReservedPropertiesElementIdNode(AsyncStructuredNode): + element_id = StringProperty() + + with raises(ValueError, match=error_match): + + class ReservedPropertiesIdRel(AsyncStructuredRel): + id = StringProperty() + + with raises(ValueError, match=error_match): + + class ReservedPropertiesElementIdRel(AsyncStructuredRel): + element_id = StringProperty() + + error_match = r"Property names 'source' and 'target' are not allowed as they conflict with neomodel internals." + with raises(ValueError, match=error_match): + + class ReservedPropertiesSourceRel(AsyncStructuredRel): + source = StringProperty() + + with raises(ValueError, match=error_match): + + class ReservedPropertiesTargetRel(AsyncStructuredRel): + target = StringProperty() diff --git a/test/async_/test_multiprocessing.py b/test/async_/test_multiprocessing.py new file mode 100644 index 00000000..9bf46598 --- /dev/null +++ b/test/async_/test_multiprocessing.py @@ -0,0 +1,24 @@ +from multiprocessing.pool import ThreadPool as Pool +from test._async_compat import mark_async_test + +from neomodel import AsyncStructuredNode, StringProperty, adb + + +class ThingyMaBob(AsyncStructuredNode): + name = StringProperty(unique_index=True, required=True) + + +async def thing_create(name): + name = str(name) + (thing,) = await ThingyMaBob.get_or_create({"name": name}) + return thing.name, name + + +@mark_async_test +async def test_concurrency(): + with Pool(5) as p: + results = p.map(thing_create, range(50)) + for to_unpack in results: + returned, sent = await to_unpack + assert returned == sent + await adb.close_connection() diff --git a/test/async_/test_paths.py b/test/async_/test_paths.py new file mode 100644 index 00000000..59a5e385 --- /dev/null +++ b/test/async_/test_paths.py @@ -0,0 +1,96 @@ +from test._async_compat import mark_async_test + +from neomodel import ( + AsyncNeomodelPath, + AsyncRelationshipTo, + AsyncStructuredNode, + AsyncStructuredRel, + IntegerProperty, + StringProperty, + UniqueIdProperty, + adb, +) + + +class PersonLivesInCity(AsyncStructuredRel): + """ + Relationship with data that will be instantiated as "stand-alone" + """ + + some_num = IntegerProperty(index=True, default=12) + + +class CountryOfOrigin(AsyncStructuredNode): + code = StringProperty(unique_index=True, required=True) + + +class CityOfResidence(AsyncStructuredNode): + name = StringProperty(required=True) + country = AsyncRelationshipTo(CountryOfOrigin, "FROM_COUNTRY") + + +class PersonOfInterest(AsyncStructuredNode): + uid = UniqueIdProperty() + name = StringProperty(unique_index=True) + age = IntegerProperty(index=True, default=0) + + country = AsyncRelationshipTo(CountryOfOrigin, "IS_FROM") + city = AsyncRelationshipTo(CityOfResidence, "LIVES_IN", model=PersonLivesInCity) + + +@mark_async_test +async def test_path_instantiation(): + """ + Neo4j driver paths should be instantiated as neomodel paths, with all of + their nodes and relationships resolved to their Python objects wherever + such a mapping is available. + """ + + c1 = await CountryOfOrigin(code="GR").save() + c2 = await CountryOfOrigin(code="FR").save() + + ct1 = await CityOfResidence(name="Athens", country=c1).save() + ct2 = await CityOfResidence(name="Paris", country=c2).save() + + p1 = await PersonOfInterest(name="Bill", age=22).save() + await p1.country.connect(c1) + await p1.city.connect(ct1) + + p2 = await PersonOfInterest(name="Jean", age=28).save() + await p2.country.connect(c2) + await p2.city.connect(ct2) + + p3 = await PersonOfInterest(name="Bo", age=32).save() + await p3.country.connect(c1) + await p3.city.connect(ct2) + + p4 = await PersonOfInterest(name="Drop", age=16).save() + await p4.country.connect(c1) + await p4.city.connect(ct2) + + # Retrieve a single path + q = await adb.cypher_query( + "MATCH p=(:CityOfResidence)<-[:LIVES_IN]-(:PersonOfInterest)-[:IS_FROM]->(:CountryOfOrigin) RETURN p LIMIT 1", + resolve_objects=True, + ) + + path_object = q[0][0][0] + path_nodes = path_object.nodes + path_rels = path_object.relationships + + assert type(path_object) is AsyncNeomodelPath + assert type(path_nodes[0]) is CityOfResidence + assert type(path_nodes[1]) is PersonOfInterest + assert type(path_nodes[2]) is CountryOfOrigin + + assert type(path_rels[0]) is PersonLivesInCity + assert type(path_rels[1]) is AsyncStructuredRel + + await c1.delete() + await c2.delete() + await ct1.delete() + await ct2.delete() + await p1.delete() + await p2.delete() + await p3.delete() + await p4.delete() diff --git a/test/async_/test_properties.py b/test/async_/test_properties.py new file mode 100644 index 00000000..0679cf89 --- /dev/null +++ b/test/async_/test_properties.py @@ -0,0 +1,453 @@ +from datetime import date, datetime +from test._async_compat import mark_async_test + +from pytest import mark, raises +from pytz import timezone + +from neomodel import AsyncStructuredNode, adb +from neomodel.exceptions import ( + DeflateError, + InflateError, + RequiredProperty, + UniqueProperty, +) +from neomodel.properties import ( + ArrayProperty, + DateProperty, + DateTimeFormatProperty, + DateTimeProperty, + EmailProperty, + IntegerProperty, + JSONProperty, + NormalizedProperty, + RegexProperty, + StringProperty, + UniqueIdProperty, +) +from neomodel.util import _get_node_properties + + +class FooBar: + pass + + +def test_string_property_exceeds_max_length(): + """ + StringProperty is defined by two properties: `max_length` and `choices` that are mutually exclusive. Furthermore, + max_length must be a positive non-zero number. + """ + # Try to define a property that has both choices and max_length + with raises(ValueError): + some_string_property = StringProperty( + choices={"One": "1", "Two": "2"}, max_length=22 + ) + + # Try to define a string property that has a negative zero length + with raises(ValueError): + another_string_property = StringProperty(max_length=-35) + + # Try to validate a long string + a_string_property = StringProperty(required=True, max_length=5) + with raises(ValueError): + a_string_property.normalize("The quick brown fox jumps over the lazy dog") + + # Try to validate a "valid" string, as per the max_length setting. + valid_string = "Owen" + normalised_string = a_string_property.normalize(valid_string) + assert ( + valid_string == normalised_string + ), "StringProperty max_length test passed but values do not match." + + +@mark_async_test +async def test_string_property_w_choice(): + class TestChoices(AsyncStructuredNode): + SEXES = {"F": "Female", "M": "Male", "O": "Other"} + sex = StringProperty(required=True, choices=SEXES) + + try: + await TestChoices(sex="Z").save() + except DeflateError as e: + assert "choice" in str(e) + else: + assert False, "DeflateError not raised." + + node = await TestChoices(sex="M").save() + assert node.get_sex_display() == "Male" + + +def test_deflate_inflate(): + prop = IntegerProperty(required=True) + prop.name = "age" + prop.owner = FooBar + + try: + prop.inflate("six") + except InflateError as e: + assert "inflate property" in str(e) + else: + assert False, "DeflateError not raised." + + try: + prop.deflate("six") + except DeflateError as e: + assert "deflate property" in str(e) + else: + assert False, "DeflateError not raised." + + +def test_datetimes_timezones(): + prop = DateTimeProperty() + prop.name = "foo" + prop.owner = FooBar + t = datetime.utcnow() + gr = timezone("Europe/Athens") + gb = timezone("Europe/London") + dt1 = gr.localize(t) + dt2 = gb.localize(t) + time1 = prop.inflate(prop.deflate(dt1)) + time2 = prop.inflate(prop.deflate(dt2)) + assert time1.utctimetuple() == dt1.utctimetuple() + assert time1.utctimetuple() < time2.utctimetuple() + assert time1.tzname() == "UTC" + + +def test_date(): + prop = DateProperty() + prop.name = "foo" + prop.owner = FooBar + somedate = date(2012, 12, 15) + assert prop.deflate(somedate) == "2012-12-15" + assert prop.inflate("2012-12-15") == somedate + + +def test_datetime_format(): + some_format = "%Y-%m-%d %H:%M:%S" + prop = DateTimeFormatProperty(format=some_format) + prop.name = "foo" + prop.owner = FooBar + some_datetime = datetime(2019, 3, 19, 15, 36, 25) + assert prop.deflate(some_datetime) == "2019-03-19 15:36:25" + assert prop.inflate("2019-03-19 15:36:25") == some_datetime + + +def test_datetime_exceptions(): + prop = DateTimeProperty() + prop.name = "created" + prop.owner = FooBar + faulty = "dgdsg" + + try: + prop.inflate(faulty) + except InflateError as e: + assert "inflate property" in str(e) + else: + assert False, "InflateError not raised." + + try: + prop.deflate(faulty) + except DeflateError as e: + assert "deflate property" in str(e) + else: + assert False, "DeflateError not raised." + + +def test_date_exceptions(): + prop = DateProperty() + prop.name = "date" + prop.owner = FooBar + faulty = "2012-14-13" + + try: + prop.inflate(faulty) + except InflateError as e: + assert "inflate property" in str(e) + else: + assert False, "InflateError not raised." + + try: + prop.deflate(faulty) + except DeflateError as e: + assert "deflate property" in str(e) + else: + assert False, "DeflateError not raised." + + +def test_json(): + prop = JSONProperty() + prop.name = "json" + prop.owner = FooBar + + value = {"test": [1, 2, 3]} + + assert prop.deflate(value) == '{"test": [1, 2, 3]}' + assert prop.inflate('{"test": [1, 2, 3]}') == value + + +@mark_async_test +async def test_default_value(): + class DefaultTestValue(AsyncStructuredNode): + name_xx = StringProperty(default="jim", index=True) + + a = DefaultTestValue() + assert a.name_xx == "jim" + await a.save() + + +@mark_async_test +async def test_default_value_callable(): + def uid_generator(): + return "xx" + + class DefaultTestValueTwo(AsyncStructuredNode): + uid = StringProperty(default=uid_generator, index=True) + + a = await DefaultTestValueTwo().save() + assert a.uid == "xx" + + +@mark_async_test +async def test_default_value_callable_type(): + # check our object gets converted to str without serializing and reload + def factory(): + class Foo: + def __str__(self): + return "123" + + return Foo() + + class DefaultTestValueThree(AsyncStructuredNode): + uid = StringProperty(default=factory, index=True) + + x = DefaultTestValueThree() + assert x.uid == "123" + await x.save() + assert x.uid == "123" + await x.refresh() + assert x.uid == "123" + + +@mark_async_test +async def test_independent_property_name(): + class TestDBNamePropertyNode(AsyncStructuredNode): + name_ = StringProperty(db_property="name") + + x = TestDBNamePropertyNode() + x.name_ = "jim" + await x.save() + + # check database property name on low level + results, meta = await adb.cypher_query("MATCH (n:TestDBNamePropertyNode) RETURN n") + node_properties = _get_node_properties(results[0][0]) + assert node_properties["name"] == "jim" + + node_properties = _get_node_properties(results[0][0]) + assert not "name_" in node_properties + assert not hasattr(x, "name") + assert hasattr(x, "name_") + assert (await TestDBNamePropertyNode.nodes.filter(name_="jim").all())[ + 0 + ].name_ == x.name_ + assert (await TestDBNamePropertyNode.nodes.get(name_="jim")).name_ == x.name_ + + await x.delete() + + +@mark_async_test +async def test_independent_property_name_get_or_create(): + class TestNode(AsyncStructuredNode): + uid = UniqueIdProperty() + name_ = StringProperty(db_property="name", required=True) + + # create the node + await TestNode.get_or_create({"uid": 123, "name_": "jim"}) + # test that the node is retrieved correctly + x = (await TestNode.get_or_create({"uid": 123, "name_": "jim"}))[0] + + # check database property name on low level + results, _ = await adb.cypher_query("MATCH (n:TestNode) RETURN n") + node_properties = _get_node_properties(results[0][0]) + assert node_properties["name"] == "jim" + assert "name_" not in node_properties + + # delete node afterwards + await x.delete() + + +@mark.parametrize("normalized_class", (NormalizedProperty,)) +def test_normalized_property(normalized_class): + class TestProperty(normalized_class): + def normalize(self, value): + self._called_with = value + self._called = True + return value + "bar" + + inflate = TestProperty() + inflate_res = inflate.inflate("foo") + assert getattr(inflate, "_called", False) + assert getattr(inflate, "_called_with", None) == "foo" + assert inflate_res == "foobar" + + deflate = TestProperty() + deflate_res = deflate.deflate("bar") + assert getattr(deflate, "_called", False) + assert getattr(deflate, "_called_with", None) == "bar" + assert deflate_res == "barbar" + + default = TestProperty(default="qux") + default_res = default.default_value() + assert getattr(default, "_called", False) + assert getattr(default, "_called_with", None) == "qux" + assert default_res == "quxbar" + + +def test_regex_property(): + class MissingExpression(RegexProperty): + pass + + with raises(ValueError): + MissingExpression() + + class TestProperty(RegexProperty): + name = "test" + owner = object() + expression = r"\w+ \w+$" + + def normalize(self, value): + self._called = True + return super().normalize(value) + + prop = TestProperty() + result = prop.inflate("foo bar") + assert getattr(prop, "_called", False) + assert result == "foo bar" + + with raises(DeflateError): + prop.deflate("qux") + + +def test_email_property(): + prop = EmailProperty() + prop.name = "email" + prop.owner = object() + result = prop.inflate("foo@example.com") + assert result == "foo@example.com" + + with raises(DeflateError): + prop.deflate("foo@example") + + +@mark_async_test +async def test_uid_property(): + prop = UniqueIdProperty() + prop.name = "uid" + prop.owner = object() + myuid = prop.default_value() + assert len(myuid) + + class CheckMyId(AsyncStructuredNode): + uid = UniqueIdProperty() + + cmid = await CheckMyId().save() + assert len(cmid.uid) + + +class ArrayProps(AsyncStructuredNode): + uid = StringProperty(unique_index=True) + untyped_arr = ArrayProperty() + typed_arr = ArrayProperty(IntegerProperty()) + + +@mark_async_test +async def test_array_properties(): + # untyped + ap1 = await ArrayProps(uid="1", untyped_arr=["Tim", "Bob"]).save() + assert "Tim" in ap1.untyped_arr + ap1 = await ArrayProps.nodes.get(uid="1") + assert "Tim" in ap1.untyped_arr + + # typed + try: + await ArrayProps(uid="2", typed_arr=["a", "b"]).save() + except DeflateError as e: + assert "unsaved node" in str(e) + else: + assert False, "DeflateError not raised." + + ap2 = await ArrayProps(uid="2", typed_arr=[1, 2]).save() + assert 1 in ap2.typed_arr + ap2 = await ArrayProps.nodes.get(uid="2") + assert 2 in ap2.typed_arr + + +def test_illegal_array_base_prop_raises(): + with raises(ValueError): + ArrayProperty(StringProperty(index=True)) + + +@mark_async_test +async def test_indexed_array(): + class IndexArray(AsyncStructuredNode): + ai = ArrayProperty(unique_index=True) + + b = await IndexArray(ai=[1, 2]).save() + c = await IndexArray.nodes.get(ai=[1, 2]) + assert b.element_id == c.element_id + + +@mark_async_test +async def test_unique_index_prop_not_required(): + class ConstrainedTestNode(AsyncStructuredNode): + required_property = StringProperty(required=True) + unique_property = StringProperty(unique_index=True) + unique_required_property = StringProperty(unique_index=True, required=True) + unconstrained_property = StringProperty() + + # Create a node with a missing required property + with raises(RequiredProperty): + x = ConstrainedTestNode(required_property="required", unique_property="unique") + await x.save() + + # Create a node with a missing unique (but not required) property. + x = ConstrainedTestNode() + x.required_property = "required" + x.unique_required_property = "unique and required" + x.unconstrained_property = "no contraints" + await x.save() + + # check database property name on low level + results, meta = await adb.cypher_query("MATCH (n:ConstrainedTestNode) RETURN n") + node_properties = _get_node_properties(results[0][0]) + assert node_properties["unique_required_property"] == "unique and required" + + # delete node afterwards + await x.delete() + + +@mark_async_test +async def test_unique_index_prop_enforced(): + class UniqueNullableNameNode(AsyncStructuredNode): + name = StringProperty(unique_index=True) + + await adb.install_labels(UniqueNullableNameNode) + # Nameless + x = UniqueNullableNameNode() + await x.save() + y = UniqueNullableNameNode() + await y.save() + + # Named + z = UniqueNullableNameNode(name="named") + await z.save() + with raises(UniqueProperty): + a = UniqueNullableNameNode(name="named") + await a.save() + + # Check nodes are in database + results, _ = await adb.cypher_query("MATCH (n:UniqueNullableNameNode) RETURN n") + assert len(results) == 3 + + # Delete nodes afterwards + await x.delete() + await y.delete() + await z.delete() diff --git a/test/async_/test_relationship_models.py b/test/async_/test_relationship_models.py new file mode 100644 index 00000000..95fe4714 --- /dev/null +++ b/test/async_/test_relationship_models.py @@ -0,0 +1,171 @@ +from datetime import datetime +from test._async_compat import mark_async_test + +import pytz +from pytest import raises + +from neomodel import ( + AsyncRelationship, + AsyncRelationshipTo, + AsyncStructuredNode, + AsyncStructuredRel, + DateTimeProperty, + DeflateError, + StringProperty, +) +from neomodel._async_compat.util import AsyncUtil + +HOOKS_CALLED = {"pre_save": 0, "post_save": 0} + + +class FriendRel(AsyncStructuredRel): + since = DateTimeProperty(default=lambda: datetime.now(pytz.utc)) + + +class HatesRel(FriendRel): + reason = StringProperty() + + def pre_save(self): + HOOKS_CALLED["pre_save"] += 1 + + def post_save(self): + HOOKS_CALLED["post_save"] += 1 + + +class Badger(AsyncStructuredNode): + name = StringProperty(unique_index=True) + friend = AsyncRelationship("Badger", "FRIEND", model=FriendRel) + hates = AsyncRelationshipTo("Stoat", "HATES", model=HatesRel) + + +class Stoat(AsyncStructuredNode): + name = StringProperty(unique_index=True) + hates = AsyncRelationshipTo("Badger", "HATES", model=HatesRel) + + +@mark_async_test +async def test_either_connect_with_rel_model(): + paul = await Badger(name="Paul").save() + tom = await Badger(name="Tom").save() + + # creating rels + new_rel = await tom.friend.disconnect(paul) + new_rel = await tom.friend.connect(paul) + assert isinstance(new_rel, FriendRel) + assert isinstance(new_rel.since, datetime) + + # updating properties + new_rel.since = datetime.now(pytz.utc) + assert isinstance(await new_rel.save(), FriendRel) + + # start and end nodes are the opposite of what you'd expect when using either.. + # I've tried everything possible to correct this to no avail + paul = await new_rel.start_node() + tom = await new_rel.end_node() + assert paul.name == "Tom" + assert tom.name == "Paul" + + +@mark_async_test +async def test_direction_connect_with_rel_model(): + paul = await Badger(name="Paul the badger").save() + ian = await Stoat(name="Ian the stoat").save() + + rel = await ian.hates.connect( + paul, {"reason": "thinks paul should bath more often"} + ) + assert isinstance(rel.since, datetime) + assert isinstance(rel, FriendRel) + assert rel.reason.startswith("thinks") + rel.reason = "he smells" + await rel.save() + + ian = await rel.start_node() + assert isinstance(ian, Stoat) + paul = await rel.end_node() + assert isinstance(paul, Badger) + + assert ian.name.startswith("Ian") + assert paul.name.startswith("Paul") + + rel = await ian.hates.relationship(paul) + assert isinstance(rel, HatesRel) + assert isinstance(rel.since, datetime) + await rel.save() + + # test deflate checking + rel.since = "2:30pm" + with raises(DeflateError): + await rel.save() + + # check deflate check via connect + with raises(DeflateError): + await paul.hates.connect( + ian, + { + "reason": "thinks paul should bath more often", + "since": "2:30pm", + }, + ) + + +@mark_async_test +async def test_traversal_where_clause(): + phill = await Badger(name="Phill the badger").save() + tim = await Badger(name="Tim the badger").save() + bob = await Badger(name="Bob the badger").save() + rel = await tim.friend.connect(bob) + now = datetime.now(pytz.utc) + assert rel.since < now + rel2 = await tim.friend.connect(phill) + assert rel2.since > now + friends = tim.friend.match(since__gt=now) + assert len(await friends.all()) == 1 + + +@mark_async_test +async def test_multiple_rels_exist_issue_223(): + # check a badger can dislike a stoat for multiple reasons + phill = await Badger(name="Phill").save() + ian = await Stoat(name="Stoat").save() + + rel_a = await phill.hates.connect(ian, {"reason": "a"}) + rel_b = await phill.hates.connect(ian, {"reason": "b"}) + assert rel_a.element_id != rel_b.element_id + + if AsyncUtil.is_async_code: + ian_a = (await phill.hates.match(reason="a"))[0] + ian_b = (await phill.hates.match(reason="b"))[0] + else: + ian_a = phill.hates.match(reason="a")[0] + ian_b = phill.hates.match(reason="b")[0] + assert ian_a.element_id == ian_b.element_id + + +@mark_async_test +async def test_retrieve_all_rels(): + tom = await Badger(name="tom").save() + ian = await Stoat(name="ian").save() + + rel_a = await tom.hates.connect(ian, {"reason": "a"}) + rel_b = await tom.hates.connect(ian, {"reason": "b"}) + + rels = await tom.hates.all_relationships(ian) + assert len(rels) == 2 + assert rels[0].element_id in [rel_a.element_id, rel_b.element_id] + assert rels[1].element_id in [rel_a.element_id, rel_b.element_id] + + +@mark_async_test +async def test_save_hook_on_rel_model(): + HOOKS_CALLED["pre_save"] = 0 + HOOKS_CALLED["post_save"] = 0 + + paul = await Badger(name="PaulB").save() + ian = await Stoat(name="IanS").save() + + rel = await ian.hates.connect(paul, {"reason": "yadda yadda"}) + await rel.save() + + assert HOOKS_CALLED["pre_save"] == 2 + assert HOOKS_CALLED["post_save"] == 2 diff --git a/test/async_/test_relationships.py b/test/async_/test_relationships.py new file mode 100644 index 00000000..9835e8b7 --- /dev/null +++ b/test/async_/test_relationships.py @@ -0,0 +1,209 @@ +from test._async_compat import mark_async_test + +from pytest import raises + +from neomodel import ( + AsyncOne, + AsyncRelationship, + AsyncRelationshipFrom, + AsyncRelationshipTo, + AsyncStructuredNode, + AsyncStructuredRel, + IntegerProperty, + Q, + StringProperty, + adb, +) + + +class PersonWithRels(AsyncStructuredNode): + name = StringProperty(unique_index=True) + age = IntegerProperty(index=True) + is_from = AsyncRelationshipTo("Country", "IS_FROM") + knows = AsyncRelationship("PersonWithRels", "KNOWS") + + @property + def special_name(self): + return self.name + + def special_power(self): + return "I have no powers" + + +class Country(AsyncStructuredNode): + code = StringProperty(unique_index=True) + inhabitant = AsyncRelationshipFrom(PersonWithRels, "IS_FROM") + president = AsyncRelationshipTo(PersonWithRels, "PRESIDENT", cardinality=AsyncOne) + + +class SuperHero(PersonWithRels): + power = StringProperty(index=True) + + def special_power(self): + return "I have powers" + + +@mark_async_test +async def test_actions_on_deleted_node(): + u = await PersonWithRels(name="Jim2", age=3).save() + await u.delete() + with raises(ValueError): + await u.is_from.connect(None) + + with raises(ValueError): + await u.is_from.get() + + with raises(ValueError): + await u.save() + + +@mark_async_test +async def test_bidirectional_relationships(): + u = await PersonWithRels(name="Jim", age=3).save() + assert u + + de = await Country(code="DE").save() + assert de + + assert not await u.is_from.all() + + assert u.is_from.__class__.__name__ == "AsyncZeroOrMore" + await u.is_from.connect(de) + + assert len(await u.is_from.all()) == 1 + + assert await u.is_from.is_connected(de) + + b = (await u.is_from.all())[0] + assert b.__class__.__name__ == "Country" + assert b.code == "DE" + + s = (await b.inhabitant.all())[0] + assert s.name == "Jim" + + await u.is_from.disconnect(b) + assert not await u.is_from.is_connected(b) + + +@mark_async_test +async def test_either_direction_connect(): + rey = await PersonWithRels(name="Rey", age=3).save() + sakis = await PersonWithRels(name="Sakis", age=3).save() + + await rey.knows.connect(sakis) + assert await rey.knows.is_connected(sakis) + assert await sakis.knows.is_connected(rey) + await sakis.knows.connect(rey) + + result, _ = await sakis.cypher( + f"""MATCH (us), (them) + WHERE {await adb.get_id_method()}(us)=$self and {await adb.get_id_method()}(them)=$them + MATCH (us)-[r:KNOWS]-(them) RETURN COUNT(r)""", + {"them": await adb.parse_element_id(rey.element_id)}, + ) + assert int(result[0][0]) == 1 + + rel = await rey.knows.relationship(sakis) + assert isinstance(rel, AsyncStructuredRel) + + rels = await rey.knows.all_relationships(sakis) + assert isinstance(rels[0], AsyncStructuredRel) + + +@mark_async_test +async def test_search_and_filter_and_exclude(): + fred = await PersonWithRels(name="Fred", age=13).save() + zz = await Country(code="ZZ").save() + zx = await Country(code="ZX").save() + zt = await Country(code="ZY").save() + await fred.is_from.connect(zz) + await fred.is_from.connect(zx) + await fred.is_from.connect(zt) + result = await fred.is_from.filter(code="ZX") + assert result[0].code == "ZX" + + result = await fred.is_from.filter(code="ZY") + assert result[0].code == "ZY" + + result = await fred.is_from.exclude(code="ZZ").exclude(code="ZY") + assert result[0].code == "ZX" and len(result) == 1 + + result = await fred.is_from.exclude(Q(code__contains="Y")) + assert len(result) == 2 + + result = await fred.is_from.filter(Q(code__contains="Z")) + assert len(result) == 3 + + +@mark_async_test +async def test_custom_methods(): + u = await PersonWithRels(name="Joe90", age=13).save() + assert u.special_power() == "I have no powers" + u = await SuperHero(name="Joe91", age=13, power="xxx").save() + assert u.special_power() == "I have powers" + assert u.special_name == "Joe91" + + +@mark_async_test +async def test_valid_reconnection(): + p = await PersonWithRels(name="ElPresidente", age=93).save() + assert p + + pp = await PersonWithRels(name="TheAdversary", age=33).save() + assert pp + + c = await Country(code="CU").save() + assert c + + await c.president.connect(p) + assert await c.president.is_connected(p) + + # the coup d'etat + await c.president.reconnect(p, pp) + assert await c.president.is_connected(pp) + + # reelection time + await c.president.reconnect(pp, pp) + assert await c.president.is_connected(pp) + + +@mark_async_test +async def test_valid_replace(): + brady = await PersonWithRels(name="Tom Brady", age=40).save() + assert brady + + gronk = await PersonWithRels(name="Rob Gronkowski", age=28).save() + assert gronk + + colbert = await PersonWithRels(name="Stephen Colbert", age=53).save() + assert colbert + + hanks = await PersonWithRels(name="Tom Hanks", age=61).save() + assert hanks + + await brady.knows.connect(gronk) + await brady.knows.connect(colbert) + assert len(await brady.knows.all()) == 2 + assert await brady.knows.is_connected(gronk) + assert await brady.knows.is_connected(colbert) + + await brady.knows.replace(hanks) + assert len(await brady.knows.all()) == 1 + assert await brady.knows.is_connected(hanks) + assert not await brady.knows.is_connected(gronk) + assert not await brady.knows.is_connected(colbert) + + +@mark_async_test +async def test_props_relationship(): + u = await PersonWithRels(name="Mar", age=20).save() + assert u + + c = await Country(code="AT").save() + assert c + + c2 = await Country(code="LA").save() + assert c2 + + with raises(NotImplementedError): + await c.inhabitant.connect(u, properties={"city": "Thessaloniki"}) diff --git a/test/async_/test_relative_relationships.py b/test/async_/test_relative_relationships.py new file mode 100644 index 00000000..7b283f84 --- /dev/null +++ b/test/async_/test_relative_relationships.py @@ -0,0 +1,24 @@ +from test._async_compat import mark_async_test +from test.async_.test_relationships import Country + +from neomodel import AsyncRelationshipTo, AsyncStructuredNode, StringProperty + + +class Cat(AsyncStructuredNode): + name = StringProperty() + # Relationship is defined using a relative class path + is_from = AsyncRelationshipTo(".test_relationships.Country", "IS_FROM") + + +@mark_async_test +async def test_relative_relationship(): + a = await Cat(name="snufkin").save() + assert a + + c = await Country(code="MG").save() + assert c + + # connecting an instance of the class defined above + # the next statement will fail if there's a type mismatch + await a.is_from.connect(c) + assert await a.is_from.is_connected(c) diff --git a/test/async_/test_transactions.py b/test/async_/test_transactions.py new file mode 100644 index 00000000..59d523c5 --- /dev/null +++ b/test/async_/test_transactions.py @@ -0,0 +1,193 @@ +from test._async_compat import mark_async_test + +import pytest +from neo4j.api import Bookmarks +from neo4j.exceptions import ClientError, TransactionError +from pytest import raises + +from neomodel import AsyncStructuredNode, StringProperty, UniqueProperty, adb + + +class APerson(AsyncStructuredNode): + name = StringProperty(unique_index=True) + + +@mark_async_test +async def test_rollback_and_commit_transaction(): + for p in await APerson.nodes: + await p.delete() + + await APerson(name="Roger").save() + + await adb.begin() + await APerson(name="Terry S").save() + await adb.rollback() + + assert len(await APerson.nodes) == 1 + + await adb.begin() + await APerson(name="Terry S").save() + await adb.commit() + + assert len(await APerson.nodes) == 2 + + +@adb.transaction +async def in_a_tx(*names): + for n in names: + await APerson(name=n).save() + + +@mark_async_test +async def test_transaction_decorator(): + await adb.install_labels(APerson) + for p in await APerson.nodes: + await p.delete() + + # should work + await in_a_tx("Roger") + + # should bail but raise correct error + with raises(UniqueProperty): + await in_a_tx("Jim", "Roger") + + assert "Jim" not in [p.name for p in await APerson.nodes] + + +@mark_async_test +async def test_transaction_as_a_context(): + async with adb.transaction: + await APerson(name="Tim").save() + + assert await APerson.nodes.filter(name="Tim") + + with raises(UniqueProperty): + async with adb.transaction: + await APerson(name="Tim").save() + + +@mark_async_test +async def test_query_inside_transaction(): + for p in await APerson.nodes: + await p.delete() + + async with adb.transaction: + await APerson(name="Alice").save() + await APerson(name="Bob").save() + + assert len([p.name for p in await APerson.nodes]) == 2 + + +@mark_async_test +async def test_read_transaction(): + await APerson(name="Johnny").save() + + async with adb.read_transaction: + people = await APerson.nodes + assert people + + with raises(TransactionError): + async with adb.read_transaction: + with raises(ClientError) as e: + await APerson(name="Gina").save() + assert e.value.code == "Neo.ClientError.Statement.AccessMode" + + +@mark_async_test +async def test_write_transaction(): + async with adb.write_transaction: + await APerson(name="Amelia").save() + + amelia = await APerson.nodes.get(name="Amelia") + assert amelia + + +@mark_async_test +async def double_transaction(): + await adb.begin() + with raises(SystemError, match=r"Transaction in progress"): + await adb.begin() + + await adb.rollback() + + +@adb.transaction.with_bookmark +async def in_a_tx_with_bookmark(*names): + for n in names: + await APerson(name=n).save() + + +@mark_async_test +async def test_bookmark_transaction_decorator(): + for p in await APerson.nodes: + await p.delete() + + # should work + result, bookmarks = await in_a_tx_with_bookmark("Ruth", bookmarks=None) + assert result is None + assert isinstance(bookmarks, Bookmarks) + + # should bail but raise correct error + with raises(UniqueProperty): + await in_a_tx_with_bookmark("Jane", "Ruth") + + assert "Jane" not in [p.name for p in await APerson.nodes] + + +@mark_async_test +async def test_bookmark_transaction_as_a_context(): + async with adb.transaction as transaction: + await APerson(name="Tanya").save() + assert isinstance(transaction.last_bookmark, Bookmarks) + + assert await APerson.nodes.filter(name="Tanya") + + with raises(UniqueProperty): + async with adb.transaction as transaction: + await APerson(name="Tanya").save() + assert not hasattr(transaction, "last_bookmark") + + +@pytest.fixture +def spy_on_db_begin(monkeypatch): + spy_calls = [] + original_begin = adb.begin + + def begin_spy(*args, **kwargs): + spy_calls.append((args, kwargs)) + return original_begin(*args, **kwargs) + + monkeypatch.setattr(adb, "begin", begin_spy) + return spy_calls + + +@mark_async_test +async def test_bookmark_passed_in_to_context(spy_on_db_begin): + transaction = adb.transaction + async with transaction: + pass + + assert (spy_on_db_begin)[-1] == ((), {"access_mode": None, "bookmarks": None}) + last_bookmark = transaction.last_bookmark + + transaction.bookmarks = last_bookmark + async with transaction: + pass + assert spy_on_db_begin[-1] == ( + (), + {"access_mode": None, "bookmarks": last_bookmark}, + ) + + +@mark_async_test +async def test_query_inside_bookmark_transaction(): + for p in await APerson.nodes: + await p.delete() + + async with adb.transaction as transaction: + await APerson(name="Alice").save() + await APerson(name="Bob").save() + + assert len([p.name for p in await APerson.nodes]) == 2 + + assert isinstance(transaction.last_bookmark, Bookmarks) diff --git a/test/conftest.py b/test/conftest.py index 7ec261d6..291eedf5 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,11 +1,9 @@ from __future__ import print_function import os -import warnings import pytest -from neomodel import clear_neo4j_database, config, db from neomodel.util import version_tag_to_integer NEO4J_URL = os.environ.get("NEO4J_URL", "bolt://localhost:7687") @@ -30,64 +28,33 @@ def pytest_addoption(parser): @pytest.hookimpl def pytest_collection_modifyitems(items): - connect_to_aura_items = [] - normal_items = [] + async_items = [] + sync_items = [] + async_connect_to_aura_items = [] + sync_connect_to_aura_items = [] - # Separate all tests into two groups: those with "connect_to_aura" in their name, and all others for item in items: + # Check the directory of the item + directory = item.fspath.dirname.split("/")[-1] + if "connect_to_aura" in item.name: - connect_to_aura_items.append(item) + if directory == "async_": + async_connect_to_aura_items.append(item) + elif directory == "sync_": + sync_connect_to_aura_items.append(item) else: - normal_items.append(item) - - # Add all normal tests back to the front of the list - new_order = normal_items - - # Add all connect_to_aura tests to the end of the list - new_order.extend(connect_to_aura_items) - - # Replace the original items list with the new order - items[:] = new_order - - -@pytest.hookimpl -def pytest_sessionstart(session): - """ - Provides initial connection to the database and sets up the rest of the test suite - - :param session: The session object. Please see `_ - :type Session object: For more information please see `_ - """ - - warnings.simplefilter("default") - - config.DATABASE_URL = os.environ.get( - "NEO4J_BOLT_URL", "bolt://neo4j:foobarbaz@localhost:7687" - ) - config.AUTO_INSTALL_LABELS = True - - # Clear the database if required - database_is_populated, _ = db.cypher_query( - "MATCH (a) return count(a)>0 as database_is_populated" - ) - if database_is_populated[0][0] and not session.config.getoption("resetdb"): - raise SystemError( - "Please note: The database seems to be populated.\n\tEither delete all nodes and edges manually, or set the --resetdb parameter when calling pytest\n\n\tpytest --resetdb." - ) - else: - clear_neo4j_database(db, clear_constraints=True, clear_indexes=True) - - db.cypher_query( - "CREATE OR REPLACE USER troygreene SET PASSWORD 'foobarbaz' CHANGE NOT REQUIRED" + if directory == "async_": + async_items.append(item) + elif directory == "sync_": + sync_items.append(item) + + new_order = ( + async_items + + async_connect_to_aura_items + + sync_items + + sync_connect_to_aura_items ) - if db.database_edition == "enterprise": - db.cypher_query("GRANT ROLE publisher TO troygreene") - db.cypher_query("GRANT IMPERSONATE (troygreene) ON DBMS TO admin") - - -@pytest.hookimpl -def pytest_unconfigure(): - db.close_connection() + items[:] = new_order def check_and_skip_neo4j_least_version(required_least_neo4j_version, message): @@ -112,8 +79,3 @@ def check_and_skip_neo4j_least_version(required_least_neo4j_version, message): "Neo4j version: {}. {}." "Skipping test.".format(os.environ["NEO4J_VERSION"], message) ) - - -@pytest.fixture -def skip_neo4j_before_330(): - check_and_skip_neo4j_least_version(330, "Neo4J version does not support this test") diff --git a/test/sync_/__init__.py b/test/sync_/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/sync_/conftest.py b/test/sync_/conftest.py new file mode 100644 index 00000000..d2cd787e --- /dev/null +++ b/test/sync_/conftest.py @@ -0,0 +1,48 @@ +import os +import warnings +from test._async_compat import mark_sync_session_auto_fixture + +from neomodel import config, db + + +@mark_sync_session_auto_fixture +def setup_neo4j_session(request): + """ + Provides initial connection to the database and sets up the rest of the test suite + + :param request: The request object. Please see `_ + :type Request object: For more information please see `_ + """ + + warnings.simplefilter("default") + + config.DATABASE_URL = os.environ.get( + "NEO4J_BOLT_URL", "bolt://neo4j:foobarbaz@localhost:7687" + ) + + # Clear the database if required + database_is_populated, _ = db.cypher_query( + "MATCH (a) return count(a)>0 as database_is_populated" + ) + if database_is_populated[0][0] and not request.config.getoption("resetdb"): + raise SystemError( + "Please note: The database seems to be populated.\n\tEither delete all nodes and edges manually, or set the --resetdb parameter when calling pytest\n\n\tpytest --resetdb." + ) + + db.clear_neo4j_database(clear_constraints=True, clear_indexes=True) + + db.install_all_labels() + + db.cypher_query( + "CREATE OR REPLACE USER troygreene SET PASSWORD 'foobarbaz' CHANGE NOT REQUIRED" + ) + db_edition = db.database_edition + if db_edition == "enterprise": + db.cypher_query("GRANT ROLE publisher TO troygreene") + db.cypher_query("GRANT IMPERSONATE (troygreene) ON DBMS TO admin") + + +@mark_sync_session_auto_fixture +def cleanup(): + yield + db.close_connection() diff --git a/test/test_alias.py b/test/sync_/test_alias.py similarity index 83% rename from test/test_alias.py rename to test/sync_/test_alias.py index c63119aa..420e62a0 100644 --- a/test/test_alias.py +++ b/test/sync_/test_alias.py @@ -1,3 +1,5 @@ +from test._async_compat import mark_sync_test + from neomodel import AliasProperty, StringProperty, StructuredNode @@ -12,12 +14,14 @@ class AliasTestNode(StructuredNode): long_name = MagicProperty(to="name") +@mark_sync_test def test_property_setup_hook(): - tim = AliasTestNode(long_name="tim").save() + timmy = AliasTestNode(long_name="timmy").save() assert AliasTestNode.setup_hook_called - assert tim.name == "tim" + assert timmy.name == "timmy" +@mark_sync_test def test_alias(): jim = AliasTestNode(full_name="Jim").save() assert jim.name == "Jim" diff --git a/test/test_batch.py b/test/sync_/test_batch.py similarity index 79% rename from test/test_batch.py rename to test/sync_/test_batch.py index c2d1ec86..80812d31 100644 --- a/test/test_batch.py +++ b/test/sync_/test_batch.py @@ -1,3 +1,5 @@ +from test._async_compat import mark_sync_test + from pytest import raises from neomodel import ( @@ -9,6 +11,7 @@ UniqueIdProperty, config, ) +from neomodel._async_compat.util import Util from neomodel.exceptions import DeflateError, UniqueProperty config.AUTO_INSTALL_LABELS = True @@ -20,6 +23,7 @@ class UniqueUser(StructuredNode): age = IntegerProperty() +@mark_sync_test def test_unique_id_property_batch(): users = UniqueUser.create({"name": "bob", "age": 2}, {"name": "ben", "age": 3}) @@ -36,6 +40,7 @@ class Customer(StructuredNode): age = IntegerProperty(index=True) +@mark_sync_test def test_batch_create(): users = Customer.create( {"email": "jim1@aol.com", "age": 11}, @@ -51,6 +56,7 @@ def test_batch_create(): assert Customer.nodes.get(email="jim1@aol.com") +@mark_sync_test def test_batch_create_or_update(): users = Customer.create_or_update( {"email": "merge1@aol.com", "age": 11}, @@ -60,7 +66,8 @@ def test_batch_create_or_update(): ) assert len(users) == 4 assert users[1] == users[3] - assert Customer.nodes.get(email="merge1@aol.com").age == 11 + merge_1: Customer = Customer.nodes.get(email="merge1@aol.com") + assert merge_1.age == 11 more_users = Customer.create_or_update( {"email": "merge1@aol.com", "age": 22}, @@ -68,9 +75,11 @@ def test_batch_create_or_update(): ) assert len(more_users) == 2 assert users[0] == more_users[0] - assert Customer.nodes.get(email="merge1@aol.com").age == 22 + merge_1 = Customer.nodes.get(email="merge1@aol.com") + assert merge_1.age == 22 +@mark_sync_test def test_batch_validation(): # test validation in batch create with raises(DeflateError): @@ -79,8 +88,9 @@ def test_batch_validation(): ) +@mark_sync_test def test_batch_index_violation(): - for u in Customer.nodes.all(): + for u in Customer.nodes: u.delete() users = Customer.create( @@ -94,7 +104,10 @@ def test_batch_index_violation(): ) # not found - assert not Customer.nodes.filter(email="jim7@aol.com") + if Util.is_async_code: + assert not Customer.nodes.filter(email="jim7@aol.com").__bool__() + else: + assert not Customer.nodes.filter(email="jim7@aol.com") class Dog(StructuredNode): @@ -107,11 +120,14 @@ class Person(StructuredNode): pets = RelationshipFrom("Dog", "owner") +@mark_sync_test def test_get_or_create_with_rel(): - bob = Person.get_or_create({"name": "Bob"})[0] + create_bob = Person.get_or_create({"name": "Bob"}) + bob = create_bob[0] bobs_gizmo = Dog.get_or_create({"name": "Gizmo"}, relationship=bob.pets) - tim = Person.get_or_create({"name": "Tim"})[0] + create_tim = Person.get_or_create({"name": "Tim"}) + tim = create_tim[0] tims_gizmo = Dog.get_or_create({"name": "Gizmo"}, relationship=tim.pets) # not the same gizmo diff --git a/test/test_cardinality.py b/test/sync_/test_cardinality.py similarity index 83% rename from test/test_cardinality.py rename to test/sync_/test_cardinality.py index 3c850db0..9e83762c 100644 --- a/test/test_cardinality.py +++ b/test/sync_/test_cardinality.py @@ -1,3 +1,5 @@ +from test._async_compat import mark_sync_test + from pytest import raises from neomodel import ( @@ -39,44 +41,53 @@ class ToothBrush(StructuredNode): name = StringProperty() +@mark_sync_test def test_cardinality_zero_or_more(): m = Monkey(name="tim").save() assert m.dryers.all() == [] - assert m.dryers.single() is None + single_dryer = m.driver.single() + assert single_dryer is None h = HairDryer(version=1).save() m.dryers.connect(h) assert len(m.dryers.all()) == 1 - assert m.dryers.single().version == 1 + single_dryer = m.dryers.single() + assert single_dryer.version == 1 m.dryers.disconnect(h) assert m.dryers.all() == [] - assert m.dryers.single() is None + single_dryer = m.driver.single() + assert single_dryer is None h2 = HairDryer(version=2).save() m.dryers.connect(h) m.dryers.connect(h2) m.dryers.disconnect_all() assert m.dryers.all() == [] - assert m.dryers.single() is None + single_dryer = m.driver.single() + assert single_dryer is None +@mark_sync_test def test_cardinality_zero_or_one(): m = Monkey(name="bob").save() assert m.driver.all() == [] + single_driver = m.driver.single() assert m.driver.single() is None h = ScrewDriver(version=1).save() m.driver.connect(h) assert len(m.driver.all()) == 1 - assert m.driver.single().version == 1 + single_driver = m.driver.single() + assert single_driver.version == 1 j = ScrewDriver(version=2).save() with raises(AttemptedCardinalityViolation): m.driver.connect(j) m.driver.reconnect(h, j) - assert m.driver.single().version == 2 + single_driver = m.driver.single() + assert single_driver.version == 2 # Forcing creation of a second ToothBrush to go around # AttemptedCardinalityViolation @@ -94,6 +105,7 @@ def test_cardinality_zero_or_one(): m.driver.all() +@mark_sync_test def test_cardinality_one_or_more(): m = Monkey(name="jerry").save() @@ -105,7 +117,8 @@ def test_cardinality_one_or_more(): c = Car(version=2).save() m.car.connect(c) - assert m.car.single().version == 2 + single_car = m.car.single() + assert single_car.version == 2 cars = m.car.all() assert len(cars) == 1 @@ -123,6 +136,7 @@ def test_cardinality_one_or_more(): assert len(cars) == 1 +@mark_sync_test def test_cardinality_one(): m = Monkey(name="jerry").save() @@ -136,9 +150,10 @@ def test_cardinality_one(): b = ToothBrush(name="Jim").save() m.toothbrush.connect(b) - assert m.toothbrush.single().name == "Jim" + single_toothbrush = m.toothbrush.single() + assert single_toothbrush.name == "Jim" - x = ToothBrush(name="Jim").save + x = ToothBrush(name="Jim").save() with raises(AttemptedCardinalityViolation): m.toothbrush.connect(x) diff --git a/test/test_connection.py b/test/sync_/test_connection.py similarity index 87% rename from test/test_connection.py rename to test/sync_/test_connection.py index fb3524bb..e7c0d7ce 100644 --- a/test/test_connection.py +++ b/test/sync_/test_connection.py @@ -1,15 +1,15 @@ import os -import time +from test._async_compat import mark_sync_test +from test.conftest import NEO4J_PASSWORD, NEO4J_URL, NEO4J_USERNAME import pytest -from neo4j import GraphDatabase +from neo4j import Driver, GraphDatabase from neo4j.debug import watch from neomodel import StringProperty, StructuredNode, config, db -from .conftest import NEO4J_PASSWORD, NEO4J_URL, NEO4J_USERNAME - +@mark_sync_test @pytest.fixture(autouse=True) def setup_teardown(): yield @@ -25,6 +25,7 @@ def neo4j_logging(): yield +@mark_sync_test def get_current_database_name() -> str: """ Fetches the name of the currently active database from the Neo4j database. @@ -42,6 +43,7 @@ class Pastry(StructuredNode): name = StringProperty(unique_index=True) +@mark_sync_test def test_set_connection_driver_works(): # Verify that current connection is up assert Pastry(name="Chocolatine").save() @@ -54,13 +56,16 @@ def test_set_connection_driver_works(): assert Pastry(name="Croissant").save() +@mark_sync_test def test_config_driver_works(): # Verify that current connection is up assert Pastry(name="Chausson aux pommes").save() db.close_connection() # Test connection using a driver defined in config - driver = GraphDatabase().driver(NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) + driver: Driver = GraphDatabase().driver( + NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD) + ) config.DRIVER = driver assert Pastry(name="Grignette").save() @@ -70,11 +75,10 @@ def test_config_driver_works(): config.DRIVER = None -@pytest.mark.skipif( - db.database_edition != "enterprise", - reason="Skipping test for community edition - no multi database in CE", -) +@mark_sync_test def test_connect_to_non_default_database(): + if not db.edition_is_enterprise(): + pytest.skip("Skipping test for community edition - no multi database in CE") database_name = "pastries" db.cypher_query(f"CREATE DATABASE {database_name} IF NOT EXISTS") db.close_connection() @@ -105,6 +109,7 @@ def test_connect_to_non_default_database(): config.DATABASE_NAME = None +@mark_sync_test @pytest.mark.parametrize( "url", ["bolt://user:password", "http://user:password@localhost:7687"] ) @@ -116,6 +121,7 @@ def test_wrong_url_format(url): db.set_connection(url=url) +@mark_sync_test @pytest.mark.parametrize("protocol", ["neo4j+s", "neo4j+ssc", "bolt+s", "bolt+ssc"]) def test_connect_to_aura(protocol): cypher_return = "hello world" diff --git a/test/sync_/test_contrib/__init__.py b/test/sync_/test_contrib/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/test_contrib/test_semi_structured.py b/test/sync_/test_contrib/test_semi_structured.py similarity index 87% rename from test/test_contrib/test_semi_structured.py rename to test/sync_/test_contrib/test_semi_structured.py index fe73a2bd..f4b9746b 100644 --- a/test/test_contrib/test_semi_structured.py +++ b/test/sync_/test_contrib/test_semi_structured.py @@ -1,3 +1,5 @@ +from test._async_compat import mark_sync_test + from neomodel import IntegerProperty, StringProperty from neomodel.contrib import SemiStructuredNode @@ -11,11 +13,13 @@ class Dummy(SemiStructuredNode): pass +@mark_sync_test def test_to_save_to_model_with_required_only(): u = UserProf(email="dummy@test.com") assert u.save() +@mark_sync_test def test_save_to_model_with_extras(): u = UserProf(email="jim@test.com", age=3, bar=99) u.foo = True @@ -25,6 +29,7 @@ def test_save_to_model_with_extras(): assert u.bar == 99 +@mark_sync_test def test_save_empty_model(): dummy = Dummy() assert dummy.save() diff --git a/test/sync_/test_contrib/test_spatial_datatypes.py b/test/sync_/test_contrib/test_spatial_datatypes.py new file mode 100644 index 00000000..b35ced3a --- /dev/null +++ b/test/sync_/test_contrib/test_spatial_datatypes.py @@ -0,0 +1,397 @@ +""" +Provides a test case for data types required by issue 374 - "Support for Point property type". + +At the moment, only one new datatype is offered: NeomodelPoint + +For more information please see: https://github.com/neo4j-contrib/neomodel/issues/374 +""" + +import os + +import pytest +import shapely + +import neomodel +import neomodel.contrib.spatial_properties +from neomodel.util import version_tag_to_integer + + +def check_and_skip_neo4j_least_version(required_least_neo4j_version, message): + """ + Checks if the NEO4J_VERSION is at least `required_least_neo4j_version` and skips a test if not. + + WARNING: If the NEO4J_VERSION variable is not set, this function returns True, allowing the test to go ahead. + + :param required_least_neo4j_version: The least version to check. This must be the numberic representation of the + version. That is: '3.4.0' would be passed as 340. + :type required_least_neo4j_version: int + :param message: An informative message as to why the calling test had to be skipped. + :type message: str + :return: A boolean value of True if the version reported is at least `required_least_neo4j_version` + """ + if "NEO4J_VERSION" in os.environ: + if ( + version_tag_to_integer(os.environ["NEO4J_VERSION"]) + < required_least_neo4j_version + ): + pytest.skip( + "Neo4j version: {}. {}." + "Skipping test.".format(os.environ["NEO4J_VERSION"], message) + ) + + +def basic_type_assertions( + ground_truth, tested_object, test_description, check_neo4j_points=False +): + """ + Tests that `tested_object` has been created as intended. + + :param ground_truth: The object as it is supposed to have been created. + :type ground_truth: NeomodelPoint or neo4j.v1.spatial.Point + :param tested_object: The object as it results from one of the contructors. + :type tested_object: NeomodelPoint or neo4j.v1.spatial.Point + :param test_description: A brief description of the test being performed. + :type test_description: str + :param check_neo4j_points: Whether to assert between NeomodelPoint or neo4j.v1.spatial.Point objects. + :type check_neo4j_points: bool + :return: + """ + if check_neo4j_points: + assert isinstance( + tested_object, type(ground_truth) + ), "{} did not return Neo4j Point".format(test_description) + assert ( + tested_object.srid == ground_truth.srid + ), "{} does not have the expected SRID({})".format( + test_description, ground_truth.srid + ) + assert len(tested_object) == len( + ground_truth + ), "Dimensionality mismatch. Expected {}, had {}".format( + len(ground_truth.coords), len(tested_object.coords) + ) + else: + assert isinstance( + tested_object, type(ground_truth) + ), "{} did not return NeomodelPoint".format(test_description) + assert ( + tested_object.crs == ground_truth.crs + ), "{} does not have the expected CRS({})".format( + test_description, ground_truth.crs + ) + assert len(tested_object.coords[0]) == len( + ground_truth.coords[0] + ), "Dimensionality mismatch. Expected {}, had {}".format( + len(ground_truth.coords[0]), len(tested_object.coords[0]) + ) + + +# Object Construction +def test_coord_constructor(): + """ + Tests all the possible ways by which a NeomodelPoint can be instantiated successfully via passing coordinates. + :return: + """ + + # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test. + check_and_skip_neo4j_least_version( + 340, "This version does not support spatial data types." + ) + + # Implicit cartesian point with coords + ground_truth_object = neomodel.contrib.spatial_properties.NeomodelPoint((0.0, 0.0)) + new_point = neomodel.contrib.spatial_properties.NeomodelPoint((0.0, 0.0)) + basic_type_assertions( + ground_truth_object, + new_point, + "Implicit 2d cartesian point instantiation", + ) + + ground_truth_object = neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 0.0, 0.0) + ) + new_point = neomodel.contrib.spatial_properties.NeomodelPoint((0.0, 0.0, 0.0)) + basic_type_assertions( + ground_truth_object, + new_point, + "Implicit 3d cartesian point instantiation", + ) + + # Explicit geographical point with coords + ground_truth_object = neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 0.0), crs="wgs-84" + ) + new_point = neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 0.0), crs="wgs-84" + ) + basic_type_assertions( + ground_truth_object, + new_point, + "Explicit 2d geographical point with tuple of coords instantiation", + ) + + ground_truth_object = neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 0.0, 0.0), crs="wgs-84-3d" + ) + new_point = neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 0.0, 0.0), crs="wgs-84-3d" + ) + basic_type_assertions( + ground_truth_object, + new_point, + "Explicit 3d geographical point with tuple of coords instantiation", + ) + + # Cartesian point with named arguments + ground_truth_object = neomodel.contrib.spatial_properties.NeomodelPoint( + x=0.0, y=0.0 + ) + new_point = neomodel.contrib.spatial_properties.NeomodelPoint(x=0.0, y=0.0) + basic_type_assertions( + ground_truth_object, + new_point, + "Cartesian 2d point with named arguments", + ) + + ground_truth_object = neomodel.contrib.spatial_properties.NeomodelPoint( + x=0.0, y=0.0, z=0.0 + ) + new_point = neomodel.contrib.spatial_properties.NeomodelPoint(x=0.0, y=0.0, z=0.0) + basic_type_assertions( + ground_truth_object, + new_point, + "Cartesian 3d point with named arguments", + ) + + # Geographical point with named arguments + ground_truth_object = neomodel.contrib.spatial_properties.NeomodelPoint( + longitude=0.0, latitude=0.0 + ) + new_point = neomodel.contrib.spatial_properties.NeomodelPoint( + longitude=0.0, latitude=0.0 + ) + basic_type_assertions( + ground_truth_object, + new_point, + "Geographical 2d point with named arguments", + ) + + ground_truth_object = neomodel.contrib.spatial_properties.NeomodelPoint( + longitude=0.0, latitude=0.0, height=0.0 + ) + new_point = neomodel.contrib.spatial_properties.NeomodelPoint( + longitude=0.0, latitude=0.0, height=0.0 + ) + basic_type_assertions( + ground_truth_object, + new_point, + "Geographical 3d point with named arguments", + ) + + +def test_copy_constructors(): + """ + Tests all the possible ways by which a NeomodelPoint can be instantiated successfully via a copy constructor call. + + :return: + """ + + # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test. + check_and_skip_neo4j_least_version( + 340, "This version does not support spatial data types." + ) + + # Instantiate from Shapely point + + # Implicit cartesian from shapely point + ground_truth = neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 0.0), crs="cartesian" + ) + shapely_point = shapely.geometry.Point((0.0, 0.0)) + new_point = neomodel.contrib.spatial_properties.NeomodelPoint(shapely_point) + basic_type_assertions( + ground_truth, new_point, "Implicit cartesian by shapely Point" + ) + + # Explicit geographical by shapely point + ground_truth = neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 0.0, 0.0), crs="wgs-84-3d" + ) + shapely_point = shapely.geometry.Point((0.0, 0.0, 0.0)) + new_point = neomodel.contrib.spatial_properties.NeomodelPoint( + shapely_point, crs="wgs-84-3d" + ) + basic_type_assertions( + ground_truth, new_point, "Explicit geographical by shapely Point" + ) + + # Copy constructor for NeomodelPoints + ground_truth = neomodel.contrib.spatial_properties.NeomodelPoint((0.0, 0.0)) + other_neomodel_point = neomodel.contrib.spatial_properties.NeomodelPoint((0.0, 0.0)) + new_point = neomodel.contrib.spatial_properties.NeomodelPoint(other_neomodel_point) + basic_type_assertions(ground_truth, new_point, "NeomodelPoint copy constructor") + + +def test_prohibited_constructor_forms(): + """ + Tests all the possible forms by which construction of NeomodelPoints should fail. + + :return: + """ + + # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test. + check_and_skip_neo4j_least_version( + 340, "This version does not support spatial data types." + ) + + # Absurd CRS + with pytest.raises(ValueError, match=r"Invalid CRS\(blue_hotel\)"): + _ = neomodel.contrib.spatial_properties.NeomodelPoint((0, 0), crs="blue_hotel") + + # Absurd coord dimensionality + with pytest.raises( + ValueError, + ): + _ = neomodel.contrib.spatial_properties.NeomodelPoint( + (0, 0, 0, 0, 0, 0, 0), crs="cartesian" + ) + + # Absurd datatype passed to copy constructor + with pytest.raises( + TypeError, + ): + _ = neomodel.contrib.spatial_properties.NeomodelPoint( + "it don't mean a thing if it ain't got that swing", + crs="cartesian", + ) + + # Trying to instantiate a point with any of BOTH x,y,z or longitude, latitude, height + with pytest.raises(ValueError, match="Invalid instantiation via arguments"): + _ = neomodel.contrib.spatial_properties.NeomodelPoint( + x=0.0, + y=0.0, + longitude=0.0, + latitude=2.0, + height=-2.0, + crs="cartesian", + ) + + # Trying to instantiate a point with absolutely NO parameters + with pytest.raises(ValueError, match="Invalid instantiation via no arguments"): + _ = neomodel.contrib.spatial_properties.NeomodelPoint() + + +def test_property_accessors_depending_on_crs_shapely_lt_2(): + """ + Tests that points are accessed via their respective accessors. + + :return: + """ + + # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test. + check_and_skip_neo4j_least_version( + 340, "This version does not support spatial data types." + ) + + # Check the version of Shapely installed to run the appropriate tests: + try: + from shapely import __version__ + except ImportError: + pytest.skip("Shapely not installed") + + if int("".join(__version__.split(".")[0:3])) >= 200: + pytest.skip("Shapely 2 is installed, skipping earlier version test") + + # Geometrical points only have x,y,z coordinates + new_point = neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 0.0, 0.0), crs="cartesian-3d" + ) + with pytest.raises(AttributeError, match=r'Invalid coordinate \("longitude"\)'): + new_point.longitude + with pytest.raises(AttributeError, match=r'Invalid coordinate \("latitude"\)'): + new_point.latitude + with pytest.raises(AttributeError, match=r'Invalid coordinate \("height"\)'): + new_point.height + + # Geographical points only have longitude, latitude, height coordinates + new_point = neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 0.0, 0.0), crs="wgs-84-3d" + ) + with pytest.raises(AttributeError, match=r'Invalid coordinate \("x"\)'): + new_point.x + with pytest.raises(AttributeError, match=r'Invalid coordinate \("y"\)'): + new_point.y + with pytest.raises(AttributeError, match=r'Invalid coordinate \("z"\)'): + new_point.z + + +def test_property_accessors_depending_on_crs_shapely_gte_2(): + """ + Tests that points are accessed via their respective accessors. + + :return: + """ + + # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test. + check_and_skip_neo4j_least_version( + 340, "This version does not support spatial data types." + ) + + # Check the version of Shapely installed to run the appropriate tests: + try: + from shapely import __version__ + except ImportError: + pytest.skip("Shapely not installed") + + if int("".join(__version__.split(".")[0:3])) < 200: + pytest.skip("Shapely < 2.0.0 is installed, skipping test") + # Geometrical points only have x,y,z coordinates + new_point = neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 0.0, 0.0), crs="cartesian-3d" + ) + with pytest.raises(TypeError, match=r'Invalid coordinate \("longitude"\)'): + new_point.longitude + with pytest.raises(TypeError, match=r'Invalid coordinate \("latitude"\)'): + new_point.latitude + with pytest.raises(TypeError, match=r'Invalid coordinate \("height"\)'): + new_point.height + + # Geographical points only have longitude, latitude, height coordinates + new_point = neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 0.0, 0.0), crs="wgs-84-3d" + ) + with pytest.raises(TypeError, match=r'Invalid coordinate \("x"\)'): + new_point.x + with pytest.raises(TypeError, match=r'Invalid coordinate \("y"\)'): + new_point.y + with pytest.raises(TypeError, match=r'Invalid coordinate \("z"\)'): + new_point.z + + +def test_property_accessors(): + """ + Tests that points are accessed via their respective accessors and that these accessors return the right values. + + :return: + """ + + # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test. + check_and_skip_neo4j_least_version( + 340, "This version does not support spatial data types." + ) + + # Geometrical points + new_point = neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 1.0, 2.0), crs="cartesian-3d" + ) + assert new_point.x == 0.0, "Expected x coordinate to be 0.0" + assert new_point.y == 1.0, "Expected y coordinate to be 1.0" + assert new_point.z == 2.0, "Expected z coordinate to be 2.0" + + # Geographical points + new_point = neomodel.contrib.spatial_properties.NeomodelPoint( + (0.0, 1.0, 2.0), crs="wgs-84-3d" + ) + assert new_point.longitude == 0.0, "Expected longitude to be 0.0" + assert new_point.latitude == 1.0, "Expected latitude to be 1.0" + assert new_point.height == 2.0, "Expected height to be 2.0" diff --git a/test/test_contrib/test_spatial_properties.py b/test/sync_/test_contrib/test_spatial_properties.py similarity index 96% rename from test/test_contrib/test_spatial_properties.py rename to test/sync_/test_contrib/test_spatial_properties.py index 03177c02..f33f4fb6 100644 --- a/test/test_contrib/test_spatial_properties.py +++ b/test/sync_/test_contrib/test_spatial_properties.py @@ -4,8 +4,8 @@ For more information please see: https://github.com/neo4j-contrib/neomodel/issues/374 """ -import os import random +from test._async_compat import mark_sync_test import neo4j.spatial import pytest @@ -156,6 +156,7 @@ def test_deflate(): ) +@mark_sync_test def test_default_value(): """ Tests that the default value passing mechanism works as expected with NeomodelPoint values. @@ -194,6 +195,7 @@ class LocalisableEntity(neomodel.StructuredNode): ), ("Default value assignment failed.") +@mark_sync_test def test_array_of_points(): """ Tests that Arrays of Points work as expected. @@ -236,6 +238,7 @@ class AnotherLocalisableEntity(neomodel.StructuredNode): ], "Array of Points incorrect values." +@mark_sync_test def test_simple_storage_retrieval(): """ Performs a simple Create, Retrieve via .save(), .get() which, due to the way Q objects operate, tests the @@ -264,6 +267,7 @@ class TestStorageRetrievalProperty(neomodel.StructuredNode): assert a_restaurant.description == a_property.description + def test_equality_with_other_objects(): """ Performs equality tests and ensures tha ``NeomodelPoint`` can be compared with ShapelyPoint and NeomodelPoint only. @@ -277,6 +281,9 @@ def test_equality_with_other_objects(): if int("".join(__version__.split(".")[0:3])) < 200: pytest.skip(f"Shapely 2.0 not present (Current version is {__version__}") - assert neomodel.contrib.spatial_properties.NeomodelPoint((0,0)) == neomodel.contrib.spatial_properties.NeomodelPoint(x=0, y=0) - assert neomodel.contrib.spatial_properties.NeomodelPoint((0,0)) == shapely.geometry.Point((0,0)) - + assert neomodel.contrib.spatial_properties.NeomodelPoint( + (0, 0) + ) == neomodel.contrib.spatial_properties.NeomodelPoint(x=0, y=0) + assert neomodel.contrib.spatial_properties.NeomodelPoint( + (0, 0) + ) == shapely.geometry.Point((0, 0)) diff --git a/test/test_cypher.py b/test/sync_/test_cypher.py similarity index 82% rename from test/test_cypher.py rename to test/sync_/test_cypher.py index 7c2e6fd6..944959c4 100644 --- a/test/test_cypher.py +++ b/test/sync_/test_cypher.py @@ -1,12 +1,13 @@ import builtins +from test._async_compat import mark_sync_test import pytest from neo4j.exceptions import ClientError as CypherError from numpy import ndarray from pandas import DataFrame, Series -from neomodel import StringProperty, StructuredNode -from neomodel.core import db +from neomodel import StringProperty, StructuredNode, db +from neomodel._async_compat.util import Util class User2(StructuredNode): @@ -36,6 +37,7 @@ def mocked_import(name, *args, **kwargs): monkeypatch.setattr(builtins, "__import__", mocked_import) +@mark_sync_test def test_cypher(): """ test result format is backward compatible with earlier versions of neomodel @@ -58,6 +60,7 @@ def test_cypher(): assert "a" in meta and "b" in meta +@mark_sync_test def test_cypher_syntax_error(): jim = User2(email="jim1@test.com").save() try: @@ -69,8 +72,13 @@ def test_cypher_syntax_error(): assert False, "CypherError not raised." +@mark_sync_test @pytest.mark.parametrize("hide_available_pkg", ["pandas"], indirect=True) def test_pandas_not_installed(hide_available_pkg): + # We run only the async version, because this fails on second run + # because import error is thrown only when pandas.py is imported + if not Util.is_async_code: + pytest.skip("This test is only") with pytest.raises(ImportError): with pytest.warns( UserWarning, @@ -81,6 +89,7 @@ def test_pandas_not_installed(hide_available_pkg): _ = to_dataframe(db.cypher_query("MATCH (a) RETURN a.name AS name")) +@mark_sync_test def test_pandas_integration(): from neomodel.integration.pandas import to_dataframe, to_series @@ -113,18 +122,24 @@ def test_pandas_integration(): assert df["name"].tolist() == ["jimla", "jimlo"] +@mark_sync_test @pytest.mark.parametrize("hide_available_pkg", ["numpy"], indirect=True) def test_numpy_not_installed(hide_available_pkg): + # We run only the async version, because this fails on second run + # because import error is thrown only when numpy.py is imported + if not Util.is_async_code: + pytest.skip("This test is only") with pytest.raises(ImportError): with pytest.warns( UserWarning, - match="The neomodel.integration.numpy module expects pandas to be installed", + match="The neomodel.integration.numpy module expects numpy to be installed", ): from neomodel.integration.numpy import to_ndarray _ = to_ndarray(db.cypher_query("MATCH (a) RETURN a.name AS name")) +@mark_sync_test def test_numpy_integration(): from neomodel.integration.numpy import to_ndarray @@ -132,9 +147,11 @@ def test_numpy_integration(): jimlu = UserNP(email="jimlu@test.com", name="jimlu").save() array = to_ndarray( - db.cypher_query("MATCH (a:UserNP) RETURN a.name AS name, a.email AS email") + db.cypher_query( + "MATCH (a:UserNP) RETURN a.name AS name, a.email AS email ORDER BY name" + ) ) assert isinstance(array, ndarray) assert array.shape == (2, 2) - assert array[0][0] == "jimly" + assert array[0][0] == "jimlu" diff --git a/test/test_database_management.py b/test/sync_/test_database_management.py similarity index 84% rename from test/test_database_management.py rename to test/sync_/test_database_management.py index 2a2ece34..9f663994 100644 --- a/test/test_database_management.py +++ b/test/sync_/test_database_management.py @@ -1,3 +1,5 @@ +from test._async_compat import mark_sync_test + import pytest from neo4j.exceptions import AuthError @@ -8,7 +10,6 @@ StructuredNode, StructuredRel, db, - util, ) @@ -26,26 +27,28 @@ class Venue(StructuredNode): in_city = RelationshipTo(City, relation_type="IN", model=InCity) +@mark_sync_test def test_clear_database(): venue = Venue(name="Royal Albert Hall", creator="Queen Victoria").save() city = City(name="London").save() venue.in_city.connect(city) # Clear only the data - util.clear_neo4j_database(db) + db.clear_neo4j_database() database_is_populated, _ = db.cypher_query( "MATCH (a) return count(a)>0 as database_is_populated" ) assert database_is_populated[0][0] is False + db.install_all_labels() indexes = db.list_indexes(exclude_token_lookup=True) constraints = db.list_constraints() assert len(indexes) > 0 assert len(constraints) > 0 # Clear constraints and indexes too - util.clear_neo4j_database(db, clear_constraints=True, clear_indexes=True) + db.clear_neo4j_database(clear_constraints=True, clear_indexes=True) indexes = db.list_indexes(exclude_token_lookup=True) constraints = db.list_constraints() @@ -53,13 +56,14 @@ def test_clear_database(): assert len(constraints) == 0 +@mark_sync_test def test_change_password(): prev_password = "foobarbaz" new_password = "newpassword" prev_url = f"bolt://neo4j:{prev_password}@localhost:7687" new_url = f"bolt://neo4j:{new_password}@localhost:7687" - util.change_neo4j_password(db, "neo4j", new_password) + db.change_neo4j_password("neo4j", new_password) db.close_connection() db.set_connection(url=new_url) @@ -71,7 +75,7 @@ def test_change_password(): db.close_connection() db.set_connection(url=new_url) - util.change_neo4j_password(db, "neo4j", prev_password) + db.change_neo4j_password("neo4j", prev_password) db.close_connection() db.set_connection(url=prev_url) diff --git a/test/test_dbms_awareness.py b/test/sync_/test_dbms_awareness.py similarity index 69% rename from test/test_dbms_awareness.py rename to test/sync_/test_dbms_awareness.py index 93dc032d..f0f7fb68 100644 --- a/test/test_dbms_awareness.py +++ b/test/sync_/test_dbms_awareness.py @@ -1,14 +1,17 @@ -from pytest import mark +from test._async_compat import mark_sync_test + +import pytest from neomodel import db from neomodel.util import version_tag_to_integer -@mark.skipif( - db.database_version != "5.7.0", reason="Testing a specific database version" -) +@mark_sync_test def test_version_awareness(): - assert db.database_version == "5.7.0" + db_version = db.database_version + if db_version != "5.7.0": + pytest.skip("Testing a specific database version") + assert db_version == "5.7.0" assert db.version_is_higher_than("5.7") assert db.version_is_higher_than("5.6.0") assert db.version_is_higher_than("5") @@ -17,8 +20,10 @@ def test_version_awareness(): assert not db.version_is_higher_than("5.8") +@mark_sync_test def test_edition_awareness(): - if db.database_edition == "enterprise": + db_edition = db.database_edition + if db_edition == "enterprise": assert db.edition_is_enterprise() else: assert not db.edition_is_enterprise() diff --git a/test/test_driver_options.py b/test/sync_/test_driver_options.py similarity index 69% rename from test/test_driver_options.py rename to test/sync_/test_driver_options.py index 26f16640..f244d174 100644 --- a/test/test_driver_options.py +++ b/test/sync_/test_driver_options.py @@ -1,3 +1,5 @@ +from test._async_compat import mark_sync_test + import pytest from neo4j.exceptions import ClientError from pytest import raises @@ -6,28 +8,28 @@ from neomodel.exceptions import FeatureNotSupported -@pytest.mark.skipif( - not db.edition_is_enterprise(), reason="Skipping test for community edition" -) +@mark_sync_test def test_impersonate(): + if not db.edition_is_enterprise(): + pytest.skip("Skipping test for community edition") with db.impersonate(user="troygreene"): results, _ = db.cypher_query("RETURN 'Doo Wacko !'") assert results[0][0] == "Doo Wacko !" -@pytest.mark.skipif( - not db.edition_is_enterprise(), reason="Skipping test for community edition" -) +@mark_sync_test def test_impersonate_unauthorized(): + if not db.edition_is_enterprise(): + pytest.skip("Skipping test for community edition") with db.impersonate(user="unknownuser"): with raises(ClientError): _ = db.cypher_query("RETURN 'Gabagool'") -@pytest.mark.skipif( - not db.edition_is_enterprise(), reason="Skipping test for community edition" -) +@mark_sync_test def test_impersonate_multiple_transactions(): + if not db.edition_is_enterprise(): + pytest.skip("Skipping test for community edition") with db.impersonate(user="troygreene"): with db.transaction: results, _ = db.cypher_query("RETURN 'Doo Wacko !'") @@ -41,10 +43,10 @@ def test_impersonate_multiple_transactions(): assert results[0][0] == "neo4j" -@pytest.mark.skipif( - db.edition_is_enterprise(), reason="Skipping test for enterprise edition" -) +@mark_sync_test def test_impersonate_community(): + if db.edition_is_enterprise(): + pytest.skip("Skipping test for enterprise edition") with raises(FeatureNotSupported): with db.impersonate(user="troygreene"): _ = db.cypher_query("RETURN 'Gabagoogoo'") diff --git a/test/test_exceptions.py b/test/sync_/test_exceptions.py similarity index 93% rename from test/test_exceptions.py rename to test/sync_/test_exceptions.py index 546c13fe..fe8cfe36 100644 --- a/test/test_exceptions.py +++ b/test/sync_/test_exceptions.py @@ -1,4 +1,5 @@ import pickle +from test._async_compat import mark_sync_test from neomodel import DoesNotExist, StringProperty, StructuredNode @@ -7,6 +8,7 @@ class EPerson(StructuredNode): name = StringProperty(unique_index=True) +@mark_sync_test def test_object_does_not_exist(): try: EPerson.nodes.get(name="johnny") diff --git a/test/test_hooks.py b/test/sync_/test_hooks.py similarity index 92% rename from test/test_hooks.py rename to test/sync_/test_hooks.py index 158db079..a6f742e0 100644 --- a/test/test_hooks.py +++ b/test/sync_/test_hooks.py @@ -1,3 +1,5 @@ +from test._async_compat import mark_sync_test + from neomodel import StringProperty, StructuredNode HOOKS_CALLED = {} @@ -22,6 +24,7 @@ def post_delete(self): HOOKS_CALLED["post_delete"] = 1 +@mark_sync_test def test_hooks(): ht = HookTest(name="k").save() ht.delete() diff --git a/test/test_indexing.py b/test/sync_/test_indexing.py similarity index 83% rename from test/test_indexing.py rename to test/sync_/test_indexing.py index 0b5e8fba..f39c22ef 100644 --- a/test/test_indexing.py +++ b/test/sync_/test_indexing.py @@ -1,14 +1,9 @@ +from test._async_compat import mark_sync_test + import pytest from pytest import raises -from neomodel import ( - IntegerProperty, - StringProperty, - StructuredNode, - UniqueProperty, - install_labels, -) -from neomodel.core import db +from neomodel import IntegerProperty, StringProperty, StructuredNode, UniqueProperty, db from neomodel.exceptions import ConstraintValidationFailed @@ -17,8 +12,9 @@ class Human(StructuredNode): age = IntegerProperty(index=True) +@mark_sync_test def test_unique_error(): - install_labels(Human) + db.install_labels(Human) Human(name="j1m", age=13).save() try: Human(name="j1m", age=14).save() @@ -29,10 +25,10 @@ def test_unique_error(): assert False, "UniqueProperty not raised." -@pytest.mark.skipif( - not db.edition_is_enterprise(), reason="Skipping test for community edition" -) +@mark_sync_test def test_existence_constraint_error(): + if not db.edition_is_enterprise(): + pytest.skip("Skipping test for community edition") db.cypher_query( "CREATE CONSTRAINT test_existence_constraint FOR (n:Human) REQUIRE n.age IS NOT NULL" ) @@ -42,6 +38,7 @@ def test_existence_constraint_error(): db.cypher_query("DROP CONSTRAINT test_existence_constraint") +@mark_sync_test def test_optional_properties_dont_get_indexed(): Human(name="99", age=99).save() h = Human.nodes.get(age=99) @@ -54,19 +51,21 @@ def test_optional_properties_dont_get_indexed(): assert h.name == "98" +@mark_sync_test def test_escaped_chars(): _name = "sarah:test" Human(name=_name, age=3).save() r = Human.nodes.filter(name=_name) - assert r assert r[0].name == _name +@mark_sync_test def test_does_not_exist(): with raises(Human.DoesNotExist): Human.nodes.get(name="XXXX") +@mark_sync_test def test_custom_label_name(): class Giraffe(StructuredNode): __label__ = "Giraffes" diff --git a/test/test_issue112.py b/test/sync_/test_issue112.py similarity index 82% rename from test/test_issue112.py rename to test/sync_/test_issue112.py index 3e379932..f580f146 100644 --- a/test/test_issue112.py +++ b/test/sync_/test_issue112.py @@ -1,3 +1,5 @@ +from test._async_compat import mark_sync_test + from neomodel import RelationshipTo, StructuredNode @@ -5,6 +7,7 @@ class SomeModel(StructuredNode): test = RelationshipTo("SomeModel", "SELF") +@mark_sync_test def test_len_relationship(): t1 = SomeModel().save() t2 = SomeModel().save() diff --git a/test/test_issue283.py b/test/sync_/test_issue283.py similarity index 67% rename from test/test_issue283.py rename to test/sync_/test_issue283.py index b7a63f8f..a059f7f2 100644 --- a/test/test_issue283.py +++ b/test/sync_/test_issue283.py @@ -9,14 +9,24 @@ idea remains the same: "Instantiate the correct type of node at the end of a relationship as specified by the model" """ - -import datetime -import os import random +from test._async_compat import mark_sync_test import pytest -import neomodel +from neomodel import ( + DateTimeProperty, + FloatProperty, + RelationshipClassNotDefined, + RelationshipClassRedefined, + RelationshipTo, + StringProperty, + StructuredNode, + StructuredRel, + UniqueIdProperty, + db, +) +from neomodel.exceptions import NodeClassAlreadyDefined, NodeClassNotDefined try: basestring @@ -25,7 +35,7 @@ # Set up a very simple model for the tests -class PersonalRelationship(neomodel.StructuredRel): +class PersonalRelationship(StructuredRel): """ A very simple relationship between two basePersons that simply records the date at which an acquaintance was established. @@ -33,16 +43,16 @@ class PersonalRelationship(neomodel.StructuredRel): basePerson without any further effort. """ - on_date = neomodel.DateTimeProperty(default_now=True) + on_date = DateTimeProperty(default_now=True) -class BasePerson(neomodel.StructuredNode): +class BasePerson(StructuredNode): """ Base class for defining some basic sort of an actor. """ - name = neomodel.StringProperty(required=True, unique_index=True) - friends_with = neomodel.RelationshipTo( + name = StringProperty(required=True, unique_index=True) + friends_with = RelationshipTo( "BasePerson", "FRIENDS_WITH", model=PersonalRelationship ) @@ -52,7 +62,7 @@ class TechnicalPerson(BasePerson): A Technical person specialises BasePerson by adding their expertise. """ - expertise = neomodel.StringProperty(required=True) + expertise = StringProperty(required=True) class PilotPerson(BasePerson): @@ -61,15 +71,15 @@ class PilotPerson(BasePerson): can operate. """ - airplane = neomodel.StringProperty(required=True) + airplane = StringProperty(required=True) -class BaseOtherPerson(neomodel.StructuredNode): +class BaseOtherPerson(StructuredNode): """ An obviously "wrong" class of actor to befriend BasePersons with. """ - car_color = neomodel.StringProperty(required=True) + car_color = StringProperty(required=True) class SomePerson(BaseOtherPerson): @@ -81,6 +91,7 @@ class SomePerson(BaseOtherPerson): # Test cases +@mark_sync_test def test_automatic_result_resolution(): """ Node objects at the end of relationships are instantiated to their @@ -88,30 +99,29 @@ def test_automatic_result_resolution(): """ # Create a few entities - A = TechnicalPerson.get_or_create( - {"name": "Grumpy", "expertise": "Grumpiness"} - )[0] - B = TechnicalPerson.get_or_create( - {"name": "Happy", "expertise": "Unicorns"} - )[0] - C = TechnicalPerson.get_or_create( - {"name": "Sleepy", "expertise": "Pillows"} - )[0] + A = (TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"}))[ + 0 + ] + B = (TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"}))[0] + C = (TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"}))[0] # Add connections A.friends_with.connect(B) B.friends_with.connect(C) C.friends_with.connect(A) + test = A.friends_with + # If A is friends with B, then A's friends_with objects should be # TechnicalPerson (!NOT basePerson!) - assert type(A.friends_with[0]) is TechnicalPerson + assert type((A.friends_with)[0]) is TechnicalPerson A.delete() B.delete() C.delete() +@mark_sync_test def test_recursive_automatic_result_resolution(): """ Node objects are instantiated to native Python objects, both at the top @@ -120,21 +130,17 @@ def test_recursive_automatic_result_resolution(): """ # Create a few entities - A = TechnicalPerson.get_or_create( - {"name": "Grumpier", "expertise": "Grumpiness"} - )[0] - B = TechnicalPerson.get_or_create( - {"name": "Happier", "expertise": "Grumpiness"} - )[0] - C = TechnicalPerson.get_or_create( - {"name": "Sleepier", "expertise": "Pillows"} - )[0] - D = TechnicalPerson.get_or_create( - {"name": "Sneezier", "expertise": "Pillows"} + A = ( + TechnicalPerson.get_or_create({"name": "Grumpier", "expertise": "Grumpiness"}) )[0] + B = (TechnicalPerson.get_or_create({"name": "Happier", "expertise": "Grumpiness"}))[ + 0 + ] + C = (TechnicalPerson.get_or_create({"name": "Sleepier", "expertise": "Pillows"}))[0] + D = (TechnicalPerson.get_or_create({"name": "Sneezier", "expertise": "Pillows"}))[0] # Retrieve mixed results, both at the top level and nested - L, _ = neomodel.db.cypher_query( + L, _ = db.cypher_query( "MATCH (a:TechnicalPerson) " "WHERE a.expertise='Grumpiness' " "WITH collect(a) as Alpha " @@ -158,6 +164,7 @@ def test_recursive_automatic_result_resolution(): D.delete() +@mark_sync_test def test_validation_with_inheritance_from_db(): """ Objects descending from the specified class of a relationship's end-node are @@ -166,22 +173,22 @@ def test_validation_with_inheritance_from_db(): # Create a few entities # Technical Persons - A = TechnicalPerson.get_or_create( - {"name": "Grumpy", "expertise": "Grumpiness"} - )[0] - B = TechnicalPerson.get_or_create( - {"name": "Happy", "expertise": "Unicorns"} - )[0] - C = TechnicalPerson.get_or_create( - {"name": "Sleepy", "expertise": "Pillows"} - )[0] + A = (TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"}))[ + 0 + ] + B = (TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"}))[0] + C = (TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"}))[0] # Pilot Persons - D = PilotPerson.get_or_create( - {"name": "Porco Rosso", "airplane": "Savoia-Marchetti"} + D = ( + PilotPerson.get_or_create( + {"name": "Porco Rosso", "airplane": "Savoia-Marchetti"} + ) )[0] - E = PilotPerson.get_or_create( - {"name": "Jack Dalton", "airplane": "Beechcraft Model 18"} + E = ( + PilotPerson.get_or_create( + {"name": "Jack Dalton", "airplane": "Beechcraft Model 18"} + ) )[0] # TechnicalPersons can befriend PilotPersons and vice-versa and that's fine @@ -201,13 +208,13 @@ def test_validation_with_inheritance_from_db(): # This now means that friends_with of a TechnicalPerson can # either be TechnicalPerson or Pilot Person (!NOT basePerson!) - assert (type(A.friends_with[0]) is TechnicalPerson) or ( - type(A.friends_with[0]) is PilotPerson + assert (type((A.friends_with)[0]) is TechnicalPerson) or ( + type((A.friends_with)[0]) is PilotPerson ) - assert (type(A.friends_with[1]) is TechnicalPerson) or ( - type(A.friends_with[1]) is PilotPerson + assert (type((A.friends_with)[1]) is TechnicalPerson) or ( + type((A.friends_with)[1]) is PilotPerson ) - assert type(D.friends_with[0]) is PilotPerson + assert type((D.friends_with)[0]) is PilotPerson A.delete() B.delete() @@ -216,6 +223,7 @@ def test_validation_with_inheritance_from_db(): E.delete() +@mark_sync_test def test_validation_enforcement_to_db(): """ If a connection between wrong types is attempted, raise an exception @@ -223,22 +231,22 @@ def test_validation_enforcement_to_db(): # Create a few entities # Technical Persons - A = TechnicalPerson.get_or_create( - {"name": "Grumpy", "expertise": "Grumpiness"} - )[0] - B = TechnicalPerson.get_or_create( - {"name": "Happy", "expertise": "Unicorns"} - )[0] - C = TechnicalPerson.get_or_create( - {"name": "Sleepy", "expertise": "Pillows"} - )[0] + A = (TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"}))[ + 0 + ] + B = (TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"}))[0] + C = (TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"}))[0] # Pilot Persons - D = PilotPerson.get_or_create( - {"name": "Porco Rosso", "airplane": "Savoia-Marchetti"} + D = ( + PilotPerson.get_or_create( + {"name": "Porco Rosso", "airplane": "Savoia-Marchetti"} + ) )[0] - E = PilotPerson.get_or_create( - {"name": "Jack Dalton", "airplane": "Beechcraft Model 18"} + E = ( + PilotPerson.get_or_create( + {"name": "Jack Dalton", "airplane": "Beechcraft Model 18"} + ) )[0] # Some Person @@ -265,6 +273,7 @@ def test_validation_enforcement_to_db(): F.delete() +@mark_sync_test def test_failed_result_resolution(): """ A Neo4j driver node FROM the database contains labels that are unaware to @@ -273,39 +282,39 @@ def test_failed_result_resolution(): """ class RandomPerson(BasePerson): - randomness = neomodel.FloatProperty(default=random.random) + randomness = FloatProperty(default=random.random) # A Technical Person... - A = TechnicalPerson.get_or_create( - {"name": "Grumpy", "expertise": "Grumpiness"} - )[0] + A = (TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"}))[ + 0 + ] # A Random Person... - B = RandomPerson.get_or_create({"name": "Mad Hatter"})[0] + B = (RandomPerson.get_or_create({"name": "Mad Hatter"}))[0] A.friends_with.connect(B) # Simulate the condition where the definition of class RandomPerson is not # known yet. - del neomodel.db._NODE_CLASS_REGISTRY[ - frozenset(["RandomPerson", "BasePerson"]) - ] + del db._NODE_CLASS_REGISTRY[frozenset(["RandomPerson", "BasePerson"])] # Now try to instantiate a RandomPerson - A = TechnicalPerson.get_or_create( - {"name": "Grumpy", "expertise": "Grumpiness"} - )[0] + A = (TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"}))[ + 0 + ] with pytest.raises( - neomodel.exceptions.NodeClassNotDefined, + NodeClassNotDefined, match=r"Node with labels .* does not resolve to any of the known objects.*", ): - for some_friend in A.friends_with: + friends = A.friends_with.all() + for some_friend in friends: print(some_friend.name) A.delete() B.delete() +@mark_sync_test def test_node_label_mismatch(): """ A Neo4j driver node FROM the database contains a superset of the known @@ -313,23 +322,21 @@ def test_node_label_mismatch(): """ class SuperTechnicalPerson(TechnicalPerson): - superness = neomodel.FloatProperty(default=1.0) + superness = FloatProperty(default=1.0) class UltraTechnicalPerson(SuperTechnicalPerson): - ultraness = neomodel.FloatProperty(default=3.1415928) + ultraness = FloatProperty(default=3.1415928) # Create a TechnicalPerson... - A = TechnicalPerson.get_or_create( - {"name": "Grumpy", "expertise": "Grumpiness"} - )[0] + A = (TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"}))[ + 0 + ] # ...that is connected to an UltraTechnicalPerson - F = UltraTechnicalPerson( - name="Chewbaka", expertise="Aarrr wgh ggwaaah" - ).save() + F = UltraTechnicalPerson(name="Chewbaka", expertise="Aarrr wgh ggwaaah").save() A.friends_with.connect(F) # Forget about the UltraTechnicalPerson - del neomodel.db._NODE_CLASS_REGISTRY[ + del db._NODE_CLASS_REGISTRY[ frozenset( [ "UltraTechnicalPerson", @@ -343,17 +350,18 @@ class UltraTechnicalPerson(SuperTechnicalPerson): # Recall a TechnicalPerson and enumerate its friends. # One of them is UltraTechnicalPerson which would be returned as a valid # node to a friends_with query but is currently unknown to the node class registry. - A = TechnicalPerson.get_or_create( - {"name": "Grumpy", "expertise": "Grumpiness"} - )[0] - with pytest.raises(neomodel.exceptions.NodeClassNotDefined): - for some_friend in A.friends_with: + A = (TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"}))[ + 0 + ] + with pytest.raises(NodeClassNotDefined): + friends = A.friends_with.all() + for some_friend in friends: print(some_friend.name) def test_attempted_class_redefinition(): """ - A neomodel.StructuredNode class is attempted to be redefined. + A StructuredNode class is attempted to be redefined. """ def redefine_class_locally(): @@ -361,23 +369,22 @@ def redefine_class_locally(): # SomePerson here. # The internal structure of the SomePerson entity does not matter at all here. class SomePerson(BaseOtherPerson): - uid = neomodel.UniqueIdProperty() + uid = UniqueIdProperty() with pytest.raises( - neomodel.exceptions.NodeClassAlreadyDefined, + NodeClassAlreadyDefined, match=r"Class .* with labels .* already defined:.*", ): redefine_class_locally() +@mark_sync_test def test_relationship_result_resolution(): """ A query returning a "Relationship" object can now instantiate it to a data model class """ # Test specific data - A = PilotPerson( - name="Zantford Granville", airplane="Gee Bee Model R" - ).save() + A = PilotPerson(name="Zantford Granville", airplane="Gee Bee Model R").save() B = PilotPerson(name="Thomas Granville", airplane="Gee Bee Model R").save() C = PilotPerson(name="Robert Granville", airplane="Gee Bee Model R").save() D = PilotPerson(name="Mark Granville", airplane="Gee Bee Model R").save() @@ -388,7 +395,7 @@ def test_relationship_result_resolution(): C.friends_with.connect(D) D.friends_with.connect(E) - query_data = neomodel.db.cypher_query( + query_data = db.cypher_query( "MATCH (a:PilotPerson)-[r:FRIENDS_WITH]->(b:PilotPerson) " "WHERE a.airplane='Gee Bee Model R' and b.airplane='Gee Bee Model R' " "RETURN DISTINCT r", @@ -399,6 +406,7 @@ def test_relationship_result_resolution(): assert isinstance(query_data[0][0][0], PersonalRelationship) +@mark_sync_test def test_properly_inherited_relationship(): """ A relationship class extends an existing relationship model that must extended the same previously associated @@ -408,11 +416,11 @@ def test_properly_inherited_relationship(): # Extends an existing relationship by adding the "relationship_strength" attribute. # `ExtendedPersonalRelationship` will now substitute `PersonalRelationship` EVERYWHERE in the system. class ExtendedPersonalRelationship(PersonalRelationship): - relationship_strength = neomodel.FloatProperty(default=random.random) + relationship_strength = FloatProperty(default=random.random) # Extends SomePerson, establishes "enriched" relationships with any BaseOtherPerson class ExtendedSomePerson(SomePerson): - friends_with = neomodel.RelationshipTo( + friends_with = RelationshipTo( "BaseOtherPerson", "FRIENDS_WITH", model=ExtendedPersonalRelationship, @@ -426,7 +434,7 @@ class ExtendedSomePerson(SomePerson): A.friends_with.connect(B) A.friends_with.connect(C) - query_data = neomodel.db.cypher_query( + query_data = db.cypher_query( "MATCH (:ExtendedSomePerson)-[r:FRIENDS_WITH]->(:ExtendedSomePerson) " "RETURN DISTINCT r", resolve_objects=True, @@ -441,20 +449,21 @@ def test_improperly_inherited_relationship(): :return: """ - class NewRelationship(neomodel.StructuredRel): - profile_match_factor = neomodel.FloatProperty() + class NewRelationship(StructuredRel): + profile_match_factor = FloatProperty() with pytest.raises( - neomodel.RelationshipClassRedefined, + RelationshipClassRedefined, match=r"Relationship of type .* redefined as .*", ): class NewSomePerson(SomePerson): - friends_with = neomodel.RelationshipTo( + friends_with = RelationshipTo( "BaseOtherPerson", "FRIENDS_WITH", model=NewRelationship ) +@mark_sync_test def test_resolve_inexistent_relationship(): """ Attempting to resolve an inexistent relationship should raise an exception @@ -462,13 +471,13 @@ def test_resolve_inexistent_relationship(): """ # Forget about the FRIENDS_WITH Relationship. - del neomodel.db._NODE_CLASS_REGISTRY[frozenset(["FRIENDS_WITH"])] + del db._NODE_CLASS_REGISTRY[frozenset(["FRIENDS_WITH"])] with pytest.raises( - neomodel.RelationshipClassNotDefined, + RelationshipClassNotDefined, match=r"Relationship of type .* does not resolve to any of the known objects.*", ): - query_data = neomodel.db.cypher_query( + query_data = db.cypher_query( "MATCH (:ExtendedSomePerson)-[r:FRIENDS_WITH]->(:ExtendedSomePerson) " "RETURN DISTINCT r", resolve_objects=True, diff --git a/test/test_issue600.py b/test/sync_/test_issue600.py similarity index 61% rename from test/test_issue600.py rename to test/sync_/test_issue600.py index 6851efd2..f6b5a10b 100644 --- a/test/test_issue600.py +++ b/test/sync_/test_issue600.py @@ -4,13 +4,9 @@ The issue is outlined here: https://github.com/neo4j-contrib/neomodel/issues/600 """ -import datetime -import os -import random +from test._async_compat import mark_sync_test -import pytest - -import neomodel +from neomodel import Relationship, StructuredNode, StructuredRel try: basestring @@ -18,7 +14,7 @@ basestring = str -class Class1(neomodel.StructuredRel): +class Class1(StructuredRel): pass @@ -30,36 +26,37 @@ class SubClass2(Class1): pass -class RelationshipDefinerSecondSibling(neomodel.StructuredNode): - rel_1 = neomodel.Relationship( +class RelationshipDefinerSecondSibling(StructuredNode): + rel_1 = Relationship( "RelationshipDefinerSecondSibling", "SOME_REL_LABEL", model=Class1 ) - rel_2 = neomodel.Relationship( + rel_2 = Relationship( "RelationshipDefinerSecondSibling", "SOME_REL_LABEL", model=SubClass1 ) - rel_3 = neomodel.Relationship( + rel_3 = Relationship( "RelationshipDefinerSecondSibling", "SOME_REL_LABEL", model=SubClass2 ) -class RelationshipDefinerParentLast(neomodel.StructuredNode): - rel_2 = neomodel.Relationship( +class RelationshipDefinerParentLast(StructuredNode): + rel_2 = Relationship( "RelationshipDefinerParentLast", "SOME_REL_LABEL", model=SubClass1 ) - rel_3 = neomodel.Relationship( + rel_3 = Relationship( "RelationshipDefinerParentLast", "SOME_REL_LABEL", model=SubClass2 ) - rel_1 = neomodel.Relationship( + rel_1 = Relationship( "RelationshipDefinerParentLast", "SOME_REL_LABEL", model=Class1 ) # Test cases +@mark_sync_test def test_relationship_definer_second_sibling(): # Create a few entities - A = RelationshipDefinerSecondSibling.get_or_create({})[0] - B = RelationshipDefinerSecondSibling.get_or_create({})[0] - C = RelationshipDefinerSecondSibling.get_or_create({})[0] + A = (RelationshipDefinerSecondSibling.get_or_create({}))[0] + B = (RelationshipDefinerSecondSibling.get_or_create({}))[0] + C = (RelationshipDefinerSecondSibling.get_or_create({}))[0] # Add connections A.rel_1.connect(B) @@ -72,11 +69,12 @@ def test_relationship_definer_second_sibling(): C.delete() +@mark_sync_test def test_relationship_definer_parent_last(): # Create a few entities - A = RelationshipDefinerParentLast.get_or_create({})[0] - B = RelationshipDefinerParentLast.get_or_create({})[0] - C = RelationshipDefinerParentLast.get_or_create({})[0] + A = (RelationshipDefinerParentLast.get_or_create({}))[0] + B = (RelationshipDefinerParentLast.get_or_create({}))[0] + C = (RelationshipDefinerParentLast.get_or_create({}))[0] # Add connections A.rel_1.connect(B) diff --git a/test/test_label_drop.py b/test/sync_/test_label_drop.py similarity index 87% rename from test/test_label_drop.py rename to test/sync_/test_label_drop.py index 389d19e0..e4834817 100644 --- a/test/test_label_drop.py +++ b/test/sync_/test_label_drop.py @@ -1,9 +1,8 @@ -from neo4j.exceptions import ClientError +from test._async_compat import mark_sync_test -from neomodel import StringProperty, StructuredNode, config -from neomodel.core import db, remove_all_labels +from neo4j.exceptions import ClientError -config.AUTO_INSTALL_LABELS = True +from neomodel import StringProperty, StructuredNode, db class ConstraintAndIndex(StructuredNode): @@ -11,14 +10,16 @@ class ConstraintAndIndex(StructuredNode): last_name = StringProperty(index=True) +@mark_sync_test def test_drop_labels(): + db.install_labels(ConstraintAndIndex) constraints_before = db.list_constraints() indexes_before = db.list_indexes(exclude_token_lookup=True) assert len(constraints_before) > 0 assert len(indexes_before) > 0 - remove_all_labels() + db.remove_all_labels() constraints = db.list_constraints() indexes = db.list_indexes(exclude_token_lookup=True) diff --git a/test/test_label_install.py b/test/sync_/test_label_install.py similarity index 78% rename from test/test_label_install.py rename to test/sync_/test_label_install.py index 46f55467..14bfe107 100644 --- a/test/test_label_install.py +++ b/test/sync_/test_label_install.py @@ -1,3 +1,5 @@ +from test._async_compat import mark_sync_test + import pytest from neomodel import ( @@ -6,15 +8,10 @@ StructuredNode, StructuredRel, UniqueIdProperty, - config, - install_all_labels, - install_labels, + db, ) -from neomodel.core import db, drop_constraints from neomodel.exceptions import ConstraintValidationFailed, FeatureNotSupported -config.AUTO_INSTALL_LABELS = False - class NodeWithIndex(StructuredNode): name = StringProperty(index=True) @@ -47,24 +44,13 @@ class SomeNotUniqueNode(StructuredNode): id_ = UniqueIdProperty(db_property="id") -config.AUTO_INSTALL_LABELS = True - - -def test_labels_were_not_installed(): - bob = NodeWithConstraint(name="bob").save() - bob2 = NodeWithConstraint(name="bob").save() - bob3 = NodeWithConstraint(name="bob").save() - assert bob.element_id != bob3.element_id - - for n in NodeWithConstraint.nodes.all(): - n.delete() - - +@mark_sync_test def test_install_all(): - drop_constraints() - install_labels(AbstractNode) + db.drop_constraints() + db.drop_indexes() + db.install_labels(AbstractNode) # run install all labels - install_all_labels() + db.install_all_labels() indexes = db.list_indexes() index_names = [index["name"] for index in indexes] @@ -79,25 +65,28 @@ def test_install_all(): _drop_constraints_for_label_and_property("NoConstraintsSetup", "name") +@mark_sync_test def test_install_label_twice(capsys): + db.drop_constraints() + db.drop_indexes() expected_std_out = ( "{code: Neo.ClientError.Schema.EquivalentSchemaRuleAlreadyExists}" ) - install_labels(AbstractNode) - install_labels(AbstractNode) + db.install_labels(AbstractNode) + db.install_labels(AbstractNode) - install_labels(NodeWithIndex) - install_labels(NodeWithIndex, quiet=False) + db.install_labels(NodeWithIndex) + db.install_labels(NodeWithIndex, quiet=False) captured = capsys.readouterr() assert expected_std_out in captured.out - install_labels(NodeWithConstraint) - install_labels(NodeWithConstraint, quiet=False) + db.install_labels(NodeWithConstraint) + db.install_labels(NodeWithConstraint, quiet=False) captured = capsys.readouterr() assert expected_std_out in captured.out - install_labels(OtherNodeWithRelationship) - install_labels(OtherNodeWithRelationship, quiet=False) + db.install_labels(OtherNodeWithRelationship) + db.install_labels(OtherNodeWithRelationship, quiet=False) captured = capsys.readouterr() assert expected_std_out in captured.out @@ -111,15 +100,16 @@ class OtherNodeWithUniqueIndexRelationship(StructuredNode): NodeWithRelationship, "UNIQUE_INDEX_REL", model=UniqueIndexRelationship ) - install_labels(OtherNodeWithUniqueIndexRelationship) - install_labels(OtherNodeWithUniqueIndexRelationship, quiet=False) + db.install_labels(OtherNodeWithUniqueIndexRelationship) + db.install_labels(OtherNodeWithUniqueIndexRelationship, quiet=False) captured = capsys.readouterr() assert expected_std_out in captured.out +@mark_sync_test def test_install_labels_db_property(capsys): - drop_constraints() - install_labels(SomeNotUniqueNode, quiet=False) + db.drop_constraints() + db.install_labels(SomeNotUniqueNode, quiet=False) captured = capsys.readouterr() assert "id" in captured.out # make sure that the id_ constraint doesn't exist @@ -131,8 +121,11 @@ def test_install_labels_db_property(capsys): _drop_constraints_for_label_and_property("SomeNotUniqueNode", "id") -@pytest.mark.skipif(db.version_is_higher_than("5.7"), reason="Not supported before 5.7") +@mark_sync_test def test_relationship_unique_index_not_supported(): + if db.version_is_higher_than("5.7"): + pytest.skip("Not supported before 5.7") + class UniqueIndexRelationship(StructuredRel): name = StringProperty(unique_index=True) @@ -150,9 +143,14 @@ class NodeWithUniqueIndexRelationship(StructuredNode): model=UniqueIndexRelationship, ) + db.install_labels(NodeWithUniqueIndexRelationship) -@pytest.mark.skipif(not db.version_is_higher_than("5.7"), reason="Supported from 5.7") + +@mark_sync_test def test_relationship_unique_index(): + if not db.version_is_higher_than("5.7"): + pytest.skip("Not supported before 5.7") + class UniqueIndexRelationshipBis(StructuredRel): name = StringProperty(unique_index=True) @@ -166,7 +164,7 @@ class NodeWithUniqueIndexRelationship(StructuredNode): model=UniqueIndexRelationshipBis, ) - install_labels(UniqueIndexRelationshipBis) + db.install_labels(NodeWithUniqueIndexRelationship) node1 = NodeWithUniqueIndexRelationship().save() node2 = TargetNodeForUniqueIndexRelationship().save() node3 = TargetNodeForUniqueIndexRelationship().save() diff --git a/test/test_match_api.py b/test/sync_/test_match_api.py similarity index 82% rename from test/test_match_api.py rename to test/sync_/test_match_api.py index ee6b337e..fe63badd 100644 --- a/test/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -1,4 +1,5 @@ from datetime import datetime +from test._async_compat import mark_sync_test from pytest import raises @@ -13,8 +14,9 @@ StructuredNode, StructuredRel, ) +from neomodel._async_compat.util import Util from neomodel.exceptions import MultipleNodesReturned -from neomodel.match import NodeSet, Optional, QueryBuilder, Traversal +from neomodel.sync_.match import NodeSet, Optional, QueryBuilder, Traversal class SupplierRel(StructuredRel): @@ -45,6 +47,7 @@ class Extension(StructuredNode): extension = RelationshipTo("Extension", "extension") +@mark_sync_test def test_filter_exclude_via_labels(): Coffee(name="Java", price=99).save() @@ -71,6 +74,7 @@ def test_filter_exclude_via_labels(): assert results[0].name == "Kenco" +@mark_sync_test def test_simple_has_via_label(): nescafe = Coffee(name="Nescafe", price=99).save() tesco = Supplier(name="Tesco", delivery_cost=2).save() @@ -91,6 +95,7 @@ def test_simple_has_via_label(): assert "NOT" in qb._ast.where[0] +@mark_sync_test def test_get(): Coffee(name="1", price=3).save() assert Coffee.nodes.get(name="1") @@ -104,6 +109,7 @@ def test_get(): Coffee.nodes.get(price=3) +@mark_sync_test def test_simple_traverse_with_filter(): nescafe = Coffee(name="Nescafe2", price=99).save() tesco = Supplier(name="Sainsburys", delivery_cost=2).save() @@ -111,7 +117,8 @@ def test_simple_traverse_with_filter(): qb = QueryBuilder(NodeSet(source=nescafe).suppliers.match(since__lt=datetime.now())) - results = qb.build_ast()._execute() + _ast = qb.build_ast() + results = _ast._execute() assert qb._ast.lookup assert qb._ast.match @@ -120,6 +127,7 @@ def test_simple_traverse_with_filter(): assert results[0].name == "Sainsburys" +@mark_sync_test def test_double_traverse(): nescafe = Coffee(name="Nescafe plus", price=99).save() tesco = Supplier(name="Asda", delivery_cost=2).save() @@ -135,19 +143,23 @@ def test_double_traverse(): assert results[1].name == "Nescafe plus" +@mark_sync_test def test_count(): Coffee(name="Nescafe Gold", price=99).save() - count = QueryBuilder(NodeSet(source=Coffee)).build_ast()._count() + ast = QueryBuilder(NodeSet(source=Coffee)).build_ast() + count = ast._count() assert count > 0 Coffee(name="Kawa", price=27).save() node_set = NodeSet(source=Coffee) node_set.skip = 1 node_set.limit = 1 - count = QueryBuilder(node_set).build_ast()._count() + ast = QueryBuilder(node_set).build_ast() + count = ast._count() assert count == 1 +@mark_sync_test def test_len_and_iter_and_bool(): iterations = 0 @@ -162,6 +174,7 @@ def test_len_and_iter_and_bool(): assert len(Coffee.nodes) == 0 +@mark_sync_test def test_slice(): for c in Coffee.nodes: c.delete() @@ -170,13 +183,22 @@ def test_slice(): Coffee(name="Britains finest").save() Coffee(name="Japans finest").save() - assert len(list(Coffee.nodes.all()[1:])) == 2 - assert len(list(Coffee.nodes.all()[:1])) == 1 - assert isinstance(Coffee.nodes[1], Coffee) - assert isinstance(Coffee.nodes[0], Coffee) - assert len(list(Coffee.nodes.all()[1:2])) == 1 - - + # Branching tests because async needs extra brackets + if Util.is_async_code: + assert len(list((Coffee.nodes)[1:])) == 2 + assert len(list((Coffee.nodes)[:1])) == 1 + assert isinstance((Coffee.nodes)[1], Coffee) + assert isinstance((Coffee.nodes)[0], Coffee) + assert len(list((Coffee.nodes)[1:2])) == 1 + else: + assert len(list(Coffee.nodes[1:])) == 2 + assert len(list(Coffee.nodes[:1])) == 1 + assert isinstance(Coffee.nodes[1], Coffee) + assert isinstance(Coffee.nodes[0], Coffee) + assert len(list(Coffee.nodes[1:2])) == 1 + + +@mark_sync_test def test_issue_208(): # calls to match persist across queries. @@ -191,13 +213,16 @@ def test_issue_208(): assert len(b.suppliers.match(courier="dhl")) +@mark_sync_test def test_issue_589(): node1 = Extension().save() node2 = Extension().save() + assert node2 not in node1.extension node1.extension.connect(node2) - assert node2 in node1.extension.all() + assert node2 in node1.extension +@mark_sync_test def test_contains(): expensive = Coffee(price=1000, name="Pricey").save() asda = Coffee(name="Asda", price=1).save() @@ -206,14 +231,21 @@ def test_contains(): assert asda not in Coffee.nodes.filter(price__gt=999) # bad value raises - with raises(ValueError): - 2 in Coffee.nodes + with raises(ValueError, match=r"Expecting StructuredNode instance"): + if Util.is_async_code: + assert Coffee.nodes.__contains__(2) + else: + assert 2 in Coffee.nodes # unsaved - with raises(ValueError): - Coffee() in Coffee.nodes + with raises(ValueError, match=r"Unsaved node"): + if Util.is_async_code: + assert Coffee.nodes.__contains__(Coffee()) + else: + assert Coffee() in Coffee.nodes +@mark_sync_test def test_order_by(): for c in Coffee.nodes: c.delete() @@ -222,8 +254,12 @@ def test_order_by(): c2 = Coffee(name="Britains finest", price=10).save() c3 = Coffee(name="Japans finest", price=35).save() - assert Coffee.nodes.order_by("price").all()[0].price == 5 - assert Coffee.nodes.order_by("-price").all()[0].price == 35 + if Util.is_async_code: + assert ((Coffee.nodes.order_by("price"))[0]).price == 5 + assert ((Coffee.nodes.order_by("-price"))[0]).price == 35 + else: + assert (Coffee.nodes.order_by("price")[0]).price == 5 + assert (Coffee.nodes.order_by("-price")[0]).price == 35 ns = Coffee.nodes.order_by("-price") qb = QueryBuilder(ns).build_ast() @@ -248,12 +284,13 @@ def test_order_by(): l.coffees.connect(c2) l.coffees.connect(c3) - ordered_n = [n for n in l.coffees.order_by("name").all()] + ordered_n = [n for n in l.coffees.order_by("name")] assert ordered_n[0] == c2 assert ordered_n[1] == c1 assert ordered_n[2] == c3 +@mark_sync_test def test_extra_filters(): for c in Coffee.nodes: c.delete() @@ -263,22 +300,22 @@ def test_extra_filters(): c3 = Coffee(name="Japans finest", price=35, id_=3).save() c4 = Coffee(name="US extra-fine", price=None, id_=4).save() - coffees_5_10 = Coffee.nodes.filter(price__in=[10, 5]).all() + coffees_5_10 = Coffee.nodes.filter(price__in=[10, 5]) assert len(coffees_5_10) == 2, "unexpected number of results" assert c1 in coffees_5_10, "doesnt contain 5 price coffee" assert c2 in coffees_5_10, "doesnt contain 10 price coffee" - finest_coffees = Coffee.nodes.filter(name__iendswith=" Finest").all() + finest_coffees = Coffee.nodes.filter(name__iendswith=" Finest") assert len(finest_coffees) == 3, "unexpected number of results" assert c1 in finest_coffees, "doesnt contain 1st finest coffee" assert c2 in finest_coffees, "doesnt contain 2nd finest coffee" assert c3 in finest_coffees, "doesnt contain 3rd finest coffee" - unpriced_coffees = Coffee.nodes.filter(price__isnull=True).all() + unpriced_coffees = Coffee.nodes.filter(price__isnull=True) assert len(unpriced_coffees) == 1, "unexpected number of results" assert c4 in unpriced_coffees, "doesnt contain unpriced coffee" - coffees_with_id_gte_3 = Coffee.nodes.filter(id___gte=3).all() + coffees_with_id_gte_3 = Coffee.nodes.filter(id___gte=3) assert len(coffees_with_id_gte_3) == 2, "unexpected number of results" assert c3 in coffees_with_id_gte_3 assert c4 in coffees_with_id_gte_3 @@ -317,6 +354,7 @@ def test_traversal_definition_keys_are_valid(): ) +@mark_sync_test def test_empty_filters(): """Test this case: ``` @@ -353,6 +391,7 @@ def test_empty_filters(): ), "doesnt contain c1 in ``filter_empty_filter``" +@mark_sync_test def test_q_filters(): # Test where no children and self.connector != conn ? for c in Coffee.nodes: @@ -418,7 +457,7 @@ def test_q_filters(): combined_coffees = Coffee.nodes.filter( Q(price=35), Q(name="Latte") | Q(name="Cappuccino") - ) + ).all() assert len(combined_coffees) == 2 assert c5 in combined_coffees assert c6 in combined_coffees @@ -443,6 +482,7 @@ def test_qbase(): assert len(test_hash) == 1 +@mark_sync_test def test_traversal_filter_left_hand_statement(): nescafe = Coffee(name="Nescafe2", price=99).save() nescafe_gold = Coffee(name="Nescafe gold", price=11).save() @@ -462,6 +502,7 @@ def test_traversal_filter_left_hand_statement(): assert lidl in lidl_supplier +@mark_sync_test def test_fetch_relations(): arabica = Species(name="Arabica").save() robusta = Species(name="Robusta").save() @@ -491,14 +532,27 @@ def test_fetch_relations(): ) assert result[0][0] is None - # len() should only consider Suppliers - count = len( - Supplier.nodes.filter(name="Sainsburys") - .fetch_relations("coffees__species") - .all() - ) - assert count == 1 + if Util.is_async_code: + count = ( + Supplier.nodes.filter(name="Sainsburys") + .fetch_relations("coffees__species") + .__len__() + ) + assert count == 1 - assert tesco in Supplier.nodes.fetch_relations("coffees__species").filter( - name="Sainsburys" - ) + assert ( + Supplier.nodes.fetch_relations("coffees__species") + .filter(name="Sainsburys") + .__contains__(tesco) + ) + else: + count = len( + Supplier.nodes.filter(name="Sainsburys") + .fetch_relations("coffees__species") + .all() + ) + assert count == 1 + + assert tesco in Supplier.nodes.fetch_relations("coffees__species").filter( + name="Sainsburys" + ) diff --git a/test/test_migration_neo4j_5.py b/test/sync_/test_migration_neo4j_5.py similarity index 84% rename from test/test_migration_neo4j_5.py rename to test/sync_/test_migration_neo4j_5.py index a4730329..83e090cb 100644 --- a/test/test_migration_neo4j_5.py +++ b/test/sync_/test_migration_neo4j_5.py @@ -1,3 +1,5 @@ +from test._async_compat import mark_sync_test + import pytest from neomodel import ( @@ -6,8 +8,8 @@ StringProperty, StructuredNode, StructuredRel, + db, ) -from neomodel.core import db class Album(StructuredNode): @@ -23,25 +25,27 @@ class Band(StructuredNode): released = RelationshipTo(Album, relation_type="RELEASED", model=Released) +@mark_sync_test def test_read_elements_id(): the_hives = Band(name="The Hives").save() lex_hives = Album(name="Lex Hives").save() released_rel = the_hives.released.connect(lex_hives) # Validate element_id properties - assert lex_hives.element_id == the_hives.released.single().element_id + assert lex_hives.element_id == (the_hives.released.single()).element_id assert released_rel._start_node_element_id == the_hives.element_id assert released_rel._end_node_element_id == lex_hives.element_id # Validate id properties # Behaviour is dependent on Neo4j version - if db.database_version.startswith("4"): + db_version = db.database_version + if db_version.startswith("4"): # Nodes' ids assert lex_hives.id == int(lex_hives.element_id) - assert lex_hives.id == the_hives.released.single().id + assert lex_hives.id == (the_hives.released.single()).id # Relationships' ids - assert isinstance(released_rel.element_id, int) - assert released_rel.element_id == released_rel.id + assert isinstance(released_rel.element_id, str) + assert int(released_rel.element_id) == released_rel.id assert released_rel._start_node_id == int(the_hives.element_id) assert released_rel._end_node_id == int(lex_hives.element_id) else: diff --git a/test/test_models.py b/test/sync_/test_models.py similarity index 90% rename from test/test_models.py rename to test/sync_/test_models.py index 827c705a..3698b612 100644 --- a/test/test_models.py +++ b/test/sync_/test_models.py @@ -1,6 +1,7 @@ from __future__ import print_function from datetime import datetime +from test._async_compat import mark_sync_test from pytest import raises @@ -10,9 +11,8 @@ StringProperty, StructuredNode, StructuredRel, - install_labels, + db, ) -from neomodel.core import db from neomodel.exceptions import RequiredProperty, UniqueProperty @@ -33,6 +33,7 @@ class NodeWithoutProperty(StructuredNode): pass +@mark_sync_test def test_issue_233(): class BaseIssue233(StructuredNode): __abstract_node__ = True @@ -52,22 +53,19 @@ def test_issue_72(): assert user.age is None +@mark_sync_test def test_required(): - try: + with raises(RequiredProperty): User(age=3).save() - except RequiredProperty: - assert True - else: - assert False def test_repr_and_str(): u = User(email="robin@test.com", age=3) - print(repr(u)) - print(str(u)) - assert True + assert repr(u) == "" + assert str(u) == "{'email': 'robin@test.com', 'age': 3}" +@mark_sync_test def test_get_and_get_or_none(): u = User(email="robin@test.com", age=3) assert u.save() @@ -82,6 +80,7 @@ def test_get_and_get_or_none(): assert n is None +@mark_sync_test def test_first_and_first_or_none(): u = User(email="matt@test.com", age=24) assert u.save() @@ -103,9 +102,10 @@ def test_bare_init_without_save(): If a node model is initialised without being saved, accessing its `element_id` should return None. """ - assert(User().element_id is None) + assert User().element_id is None +@mark_sync_test def test_save_to_model(): u = User(email="jim@test.com", age=3) assert u.save() @@ -114,24 +114,28 @@ def test_save_to_model(): assert u.age == 3 +@mark_sync_test def test_save_node_without_properties(): n = NodeWithoutProperty() assert n.save() assert n.element_id is not None +@mark_sync_test def test_unique(): - install_labels(User) + db.install_labels(User) User(email="jim1@test.com", age=3).save() with raises(UniqueProperty): User(email="jim1@test.com", age=3).save() +@mark_sync_test def test_update_unique(): u = User(email="jimxx@test.com", age=3).save() u.save() # this shouldn't fail +@mark_sync_test def test_update(): user = User(email="jim2@test.com", age=3).save() assert user @@ -142,6 +146,7 @@ def test_update(): assert jim.email == "jim2000@test.com" +@mark_sync_test def test_save_through_magic_property(): user = User(email_alias="blah@test.com", age=8).save() assert user.email_alias == "blah@test.com" @@ -163,19 +168,21 @@ class Customer2(StructuredNode): age = IntegerProperty(index=True) +@mark_sync_test def test_not_updated_on_unique_error(): - install_labels(Customer2) + db.install_labels(Customer2) Customer2(email="jim@bob.com", age=7).save() test = Customer2(email="jim1@bob.com", age=2).save() test.email = "jim@bob.com" with raises(UniqueProperty): test.save() - customers = Customer2.nodes.all() + customers = Customer2.nodes assert customers[0].email != customers[1].email - assert Customer2.nodes.get(email="jim@bob.com").age == 7 - assert Customer2.nodes.get(email="jim1@bob.com").age == 2 + assert (Customer2.nodes.get(email="jim@bob.com")).age == 7 + assert (Customer2.nodes.get(email="jim1@bob.com")).age == 2 +@mark_sync_test def test_label_not_inherited(): class Customer3(Customer2): address = StringProperty() @@ -191,6 +198,7 @@ class Customer3(Customer2): assert "Customer3" in c.labels() +@mark_sync_test def test_refresh(): c = Customer2(email="my@email.com", age=16).save() c.my_custom_prop = "value" @@ -210,7 +218,8 @@ def test_refresh(): assert c.age == 20 - if db.database_version.startswith("4"): + _db_version = db.database_version + if _db_version.startswith("4"): c = Customer2.inflate(999) else: c = Customer2.inflate("4:xxxxxx:999") @@ -218,6 +227,7 @@ def test_refresh(): c.refresh() +@mark_sync_test def test_setting_value_to_none(): c = Customer2(email="alice@bob.com", age=42).save() assert c.age is not None @@ -229,6 +239,7 @@ def test_setting_value_to_none(): assert copy.age is None +@mark_sync_test def test_inheritance(): class User(StructuredNode): __abstract_node__ = True @@ -248,9 +259,10 @@ def credit_account(self, amount): assert jim.balance == 350 assert len(jim.inherited_labels()) == 1 assert len(jim.labels()) == 1 - assert jim.labels()[0] == "Shopper" + assert (jim.labels())[0] == "Shopper" +@mark_sync_test def test_inherited_optional_labels(): class BaseOptional(StructuredNode): __optional_labels__ = ["Alive"] @@ -275,6 +287,7 @@ def credit_account(self, amount): assert set(henry.inherited_optional_labels()) == {"Alive", "RewardsMember"} +@mark_sync_test def test_mixins(): class UserMixin: name = StringProperty(unique_index=True) @@ -297,9 +310,10 @@ class Shopper2(StructuredNode, UserMixin, CreditMixin): assert jim.balance == 350 assert len(jim.inherited_labels()) == 1 assert len(jim.labels()) == 1 - assert jim.labels()[0] == "Shopper2" + assert (jim.labels())[0] == "Shopper2" +@mark_sync_test def test_date_property(): class DateTest(StructuredNode): birthdate = DateProperty() diff --git a/test/test_multiprocessing.py b/test/sync_/test_multiprocessing.py similarity index 78% rename from test/test_multiprocessing.py rename to test/sync_/test_multiprocessing.py index fb00675d..2d9167f9 100644 --- a/test/test_multiprocessing.py +++ b/test/sync_/test_multiprocessing.py @@ -1,4 +1,5 @@ from multiprocessing.pool import ThreadPool as Pool +from test._async_compat import mark_sync_test from neomodel import StringProperty, StructuredNode, db @@ -13,9 +14,11 @@ def thing_create(name): return thing.name, name +@mark_sync_test def test_concurrency(): with Pool(5) as p: results = p.map(thing_create, range(50)) - for returned, sent in results: + for to_unpack in results: + returned, sent = to_unpack assert returned == sent db.close_connection() diff --git a/test/test_paths.py b/test/sync_/test_paths.py similarity index 67% rename from test/test_paths.py rename to test/sync_/test_paths.py index 8c6fef28..8e0ccf90 100644 --- a/test/test_paths.py +++ b/test/sync_/test_paths.py @@ -1,41 +1,56 @@ -from neomodel import (StringProperty, StructuredNode, UniqueIdProperty, - db, RelationshipTo, IntegerProperty, NeomodelPath, StructuredRel) +from test._async_compat import mark_sync_test + +from neomodel import ( + IntegerProperty, + NeomodelPath, + RelationshipTo, + StringProperty, + StructuredNode, + StructuredRel, + UniqueIdProperty, + db, +) + class PersonLivesInCity(StructuredRel): """ Relationship with data that will be instantiated as "stand-alone" """ + some_num = IntegerProperty(index=True, default=12) + class CountryOfOrigin(StructuredNode): code = StringProperty(unique_index=True, required=True) + class CityOfResidence(StructuredNode): name = StringProperty(required=True) - country = RelationshipTo(CountryOfOrigin, 'FROM_COUNTRY') + country = RelationshipTo(CountryOfOrigin, "FROM_COUNTRY") + class PersonOfInterest(StructuredNode): uid = UniqueIdProperty() name = StringProperty(unique_index=True) age = IntegerProperty(index=True, default=0) - country = RelationshipTo(CountryOfOrigin, 'IS_FROM') - city = RelationshipTo(CityOfResidence, 'LIVES_IN', model=PersonLivesInCity) + country = RelationshipTo(CountryOfOrigin, "IS_FROM") + city = RelationshipTo(CityOfResidence, "LIVES_IN", model=PersonLivesInCity) +@mark_sync_test def test_path_instantiation(): """ - Neo4j driver paths should be instantiated as neomodel paths, with all of - their nodes and relationships resolved to their Python objects wherever + Neo4j driver paths should be instantiated as neomodel paths, with all of + their nodes and relationships resolved to their Python objects wherever such a mapping is available. """ - c1=CountryOfOrigin(code="GR").save() - c2=CountryOfOrigin(code="FR").save() - - ct1 = CityOfResidence(name="Athens", country = c1).save() - ct2 = CityOfResidence(name="Paris", country = c2).save() + c1 = CountryOfOrigin(code="GR").save() + c2 = CountryOfOrigin(code="FR").save() + ct1 = CityOfResidence(name="Athens", country=c1).save() + ct2 = CityOfResidence(name="Paris", country=c2).save() p1 = PersonOfInterest(name="Bill", age=22).save() p1.country.connect(c1) @@ -54,7 +69,10 @@ def test_path_instantiation(): p4.city.connect(ct2) # Retrieve a single path - q = db.cypher_query("MATCH p=(:CityOfResidence)<-[:LIVES_IN]-(:PersonOfInterest)-[:IS_FROM]->(:CountryOfOrigin) RETURN p LIMIT 1", resolve_objects = True) + q = db.cypher_query( + "MATCH p=(:CityOfResidence)<-[:LIVES_IN]-(:PersonOfInterest)-[:IS_FROM]->(:CountryOfOrigin) RETURN p LIMIT 1", + resolve_objects=True, + ) path_object = q[0][0][0] path_nodes = path_object.nodes @@ -76,4 +94,3 @@ def test_path_instantiation(): p2.delete() p3.delete() p4.delete() - diff --git a/test/test_properties.py b/test/sync_/test_properties.py similarity index 94% rename from test/test_properties.py rename to test/sync_/test_properties.py index 454ada26..28866738 100644 --- a/test/test_properties.py +++ b/test/sync_/test_properties.py @@ -1,9 +1,10 @@ from datetime import date, datetime +from test._async_compat import mark_sync_test from pytest import mark, raises from pytz import timezone -from neomodel import StructuredNode, config, db +from neomodel import StructuredNode, db from neomodel.exceptions import ( DeflateError, InflateError, @@ -25,8 +26,6 @@ ) from neomodel.util import _get_node_properties -config.AUTO_INSTALL_LABELS = True - class FooBar: pass @@ -60,6 +59,7 @@ def test_string_property_exceeds_max_length(): ), "StringProperty max_length test passed but values do not match." +@mark_sync_test def test_string_property_w_choice(): class TestChoices(StructuredNode): SEXES = {"F": "Female", "M": "Male", "O": "Other"} @@ -84,7 +84,6 @@ def test_deflate_inflate(): try: prop.inflate("six") except InflateError as e: - assert True assert "inflate property" in str(e) else: assert False, "DeflateError not raised." @@ -185,6 +184,7 @@ def test_json(): assert prop.inflate('{"test": [1, 2, 3]}') == value +@mark_sync_test def test_default_value(): class DefaultTestValue(StructuredNode): name_xx = StringProperty(default="jim", index=True) @@ -194,6 +194,7 @@ class DefaultTestValue(StructuredNode): a.save() +@mark_sync_test def test_default_value_callable(): def uid_generator(): return "xx" @@ -205,6 +206,7 @@ class DefaultTestValueTwo(StructuredNode): assert a.uid == "xx" +@mark_sync_test def test_default_value_callable_type(): # check our object gets converted to str without serializing and reload def factory(): @@ -225,6 +227,7 @@ class DefaultTestValueThree(StructuredNode): assert x.uid == "123" +@mark_sync_test def test_independent_property_name(): class TestDBNamePropertyNode(StructuredNode): name_ = StringProperty(db_property="name") @@ -242,12 +245,13 @@ class TestDBNamePropertyNode(StructuredNode): assert not "name_" in node_properties assert not hasattr(x, "name") assert hasattr(x, "name_") - assert TestDBNamePropertyNode.nodes.filter(name_="jim").all()[0].name_ == x.name_ - assert TestDBNamePropertyNode.nodes.get(name_="jim").name_ == x.name_ + assert (TestDBNamePropertyNode.nodes.filter(name_="jim").all())[0].name_ == x.name_ + assert (TestDBNamePropertyNode.nodes.get(name_="jim")).name_ == x.name_ x.delete() +@mark_sync_test def test_independent_property_name_get_or_create(): class TestNode(StructuredNode): uid = UniqueIdProperty() @@ -256,10 +260,10 @@ class TestNode(StructuredNode): # create the node TestNode.get_or_create({"uid": 123, "name_": "jim"}) # test that the node is retrieved correctly - x = TestNode.get_or_create({"uid": 123, "name_": "jim"})[0] + x = (TestNode.get_or_create({"uid": 123, "name_": "jim"}))[0] # check database property name on low level - results, meta = db.cypher_query("MATCH (n:TestNode) RETURN n") + results, _ = db.cypher_query("MATCH (n:TestNode) RETURN n") node_properties = _get_node_properties(results[0][0]) assert node_properties["name"] == "jim" assert "name_" not in node_properties @@ -331,6 +335,7 @@ def test_email_property(): prop.deflate("foo@example") +@mark_sync_test def test_uid_property(): prop = UniqueIdProperty() prop.name = "uid" @@ -351,6 +356,7 @@ class ArrayProps(StructuredNode): typed_arr = ArrayProperty(IntegerProperty()) +@mark_sync_test def test_array_properties(): # untyped ap1 = ArrayProps(uid="1", untyped_arr=["Tim", "Bob"]).save() @@ -377,6 +383,7 @@ def test_illegal_array_base_prop_raises(): ArrayProperty(StringProperty(index=True)) +@mark_sync_test def test_indexed_array(): class IndexArray(StructuredNode): ai = ArrayProperty(unique_index=True) @@ -386,6 +393,7 @@ class IndexArray(StructuredNode): assert b.element_id == c.element_id +@mark_sync_test def test_unique_index_prop_not_required(): class ConstrainedTestNode(StructuredNode): required_property = StringProperty(required=True) @@ -414,10 +422,12 @@ class ConstrainedTestNode(StructuredNode): x.delete() +@mark_sync_test def test_unique_index_prop_enforced(): class UniqueNullableNameNode(StructuredNode): name = StringProperty(unique_index=True) + db.install_labels(UniqueNullableNameNode) # Nameless x = UniqueNullableNameNode() x.save() @@ -432,7 +442,7 @@ class UniqueNullableNameNode(StructuredNode): a.save() # Check nodes are in database - results, meta = db.cypher_query("MATCH (n:UniqueNullableNameNode) RETURN n") + results, _ = db.cypher_query("MATCH (n:UniqueNullableNameNode) RETURN n") assert len(results) == 3 # Delete nodes afterwards diff --git a/test/test_relationship_models.py b/test/sync_/test_relationship_models.py similarity index 90% rename from test/test_relationship_models.py rename to test/sync_/test_relationship_models.py index 82760c73..837f53d6 100644 --- a/test/test_relationship_models.py +++ b/test/sync_/test_relationship_models.py @@ -1,4 +1,5 @@ from datetime import datetime +from test._async_compat import mark_sync_test import pytz from pytest import raises @@ -12,6 +13,7 @@ StructuredNode, StructuredRel, ) +from neomodel._async_compat.util import Util HOOKS_CALLED = {"pre_save": 0, "post_save": 0} @@ -41,6 +43,7 @@ class Stoat(StructuredNode): hates = RelationshipTo("Badger", "HATES", model=HatesRel) +@mark_sync_test def test_either_connect_with_rel_model(): paul = Badger(name="Paul").save() tom = Badger(name="Tom").save() @@ -63,6 +66,7 @@ def test_either_connect_with_rel_model(): assert tom.name == "Paul" +@mark_sync_test def test_direction_connect_with_rel_model(): paul = Badger(name="Paul the badger").save() ian = Stoat(name="Ian the stoat").save() @@ -103,6 +107,7 @@ def test_direction_connect_with_rel_model(): ) +@mark_sync_test def test_traversal_where_clause(): phill = Badger(name="Phill the badger").save() tim = Badger(name="Tim the badger").save() @@ -113,9 +118,10 @@ def test_traversal_where_clause(): rel2 = tim.friend.connect(phill) assert rel2.since > now friends = tim.friend.match(since__gt=now) - assert len(friends) == 1 + assert len(friends.all()) == 1 +@mark_sync_test def test_multiple_rels_exist_issue_223(): # check a badger can dislike a stoat for multiple reasons phill = Badger(name="Phill").save() @@ -125,11 +131,16 @@ def test_multiple_rels_exist_issue_223(): rel_b = phill.hates.connect(ian, {"reason": "b"}) assert rel_a.element_id != rel_b.element_id - ian_a = phill.hates.match(reason="a")[0] - ian_b = phill.hates.match(reason="b")[0] + if Util.is_async_code: + ian_a = (phill.hates.match(reason="a"))[0] + ian_b = (phill.hates.match(reason="b"))[0] + else: + ian_a = phill.hates.match(reason="a")[0] + ian_b = phill.hates.match(reason="b")[0] assert ian_a.element_id == ian_b.element_id +@mark_sync_test def test_retrieve_all_rels(): tom = Badger(name="tom").save() ian = Stoat(name="ian").save() @@ -143,6 +154,7 @@ def test_retrieve_all_rels(): assert rels[1].element_id in [rel_a.element_id, rel_b.element_id] +@mark_sync_test def test_save_hook_on_rel_model(): HOOKS_CALLED["pre_save"] = 0 HOOKS_CALLED["post_save"] = 0 diff --git a/test/test_relationships.py b/test/sync_/test_relationships.py similarity index 91% rename from test/test_relationships.py rename to test/sync_/test_relationships.py index 92d75064..39057a39 100644 --- a/test/test_relationships.py +++ b/test/sync_/test_relationships.py @@ -1,3 +1,5 @@ +from test._async_compat import mark_sync_test + from pytest import raises from neomodel import ( @@ -10,8 +12,8 @@ StringProperty, StructuredNode, StructuredRel, + db, ) -from neomodel.core import db class PersonWithRels(StructuredNode): @@ -41,6 +43,7 @@ def special_power(self): return "I have powers" +@mark_sync_test def test_actions_on_deleted_node(): u = PersonWithRels(name="Jim2", age=3).save() u.delete() @@ -54,6 +57,7 @@ def test_actions_on_deleted_node(): u.save() +@mark_sync_test def test_bidirectional_relationships(): u = PersonWithRels(name="Jim", age=3).save() assert u @@ -61,26 +65,27 @@ def test_bidirectional_relationships(): de = Country(code="DE").save() assert de - assert not u.is_from + assert not u.is_from.all() assert u.is_from.__class__.__name__ == "ZeroOrMore" u.is_from.connect(de) - assert len(u.is_from) == 1 + assert len(u.is_from.all()) == 1 assert u.is_from.is_connected(de) - b = u.is_from.all()[0] + b = (u.is_from.all())[0] assert b.__class__.__name__ == "Country" assert b.code == "DE" - s = b.inhabitant.all()[0] + s = (b.inhabitant.all())[0] assert s.name == "Jim" u.is_from.disconnect(b) assert not u.is_from.is_connected(b) +@mark_sync_test def test_either_direction_connect(): rey = PersonWithRels(name="Rey", age=3).save() sakis = PersonWithRels(name="Sakis", age=3).save() @@ -94,7 +99,7 @@ def test_either_direction_connect(): f"""MATCH (us), (them) WHERE {db.get_id_method()}(us)=$self and {db.get_id_method()}(them)=$them MATCH (us)-[r:KNOWS]-(them) RETURN COUNT(r)""", - {"them": rey.element_id}, + {"them": db.parse_element_id(rey.element_id)}, ) assert int(result[0][0]) == 1 @@ -105,6 +110,7 @@ def test_either_direction_connect(): assert isinstance(rels[0], StructuredRel) +@mark_sync_test def test_search_and_filter_and_exclude(): fred = PersonWithRels(name="Fred", age=13).save() zz = Country(code="ZZ").save() @@ -129,6 +135,7 @@ def test_search_and_filter_and_exclude(): assert len(result) == 3 +@mark_sync_test def test_custom_methods(): u = PersonWithRels(name="Joe90", age=13).save() assert u.special_power() == "I have no powers" @@ -137,6 +144,7 @@ def test_custom_methods(): assert u.special_name == "Joe91" +@mark_sync_test def test_valid_reconnection(): p = PersonWithRels(name="ElPresidente", age=93).save() assert p @@ -159,6 +167,7 @@ def test_valid_reconnection(): assert c.president.is_connected(pp) +@mark_sync_test def test_valid_replace(): brady = PersonWithRels(name="Tom Brady", age=40).save() assert brady @@ -174,17 +183,18 @@ def test_valid_replace(): brady.knows.connect(gronk) brady.knows.connect(colbert) - assert len(brady.knows) == 2 + assert len(brady.knows.all()) == 2 assert brady.knows.is_connected(gronk) assert brady.knows.is_connected(colbert) brady.knows.replace(hanks) - assert len(brady.knows) == 1 + assert len(brady.knows.all()) == 1 assert brady.knows.is_connected(hanks) assert not brady.knows.is_connected(gronk) assert not brady.knows.is_connected(colbert) +@mark_sync_test def test_props_relationship(): u = PersonWithRels(name="Mar", age=20).save() assert u diff --git a/test/test_relative_relationships.py b/test/sync_/test_relative_relationships.py similarity index 83% rename from test/test_relative_relationships.py rename to test/sync_/test_relative_relationships.py index db78d038..a01e28f9 100644 --- a/test/test_relative_relationships.py +++ b/test/sync_/test_relative_relationships.py @@ -1,6 +1,7 @@ -from neomodel import RelationshipTo, StringProperty, StructuredNode +from test._async_compat import mark_sync_test +from test.sync_.test_relationships import Country -from .test_relationships import Country +from neomodel import RelationshipTo, StringProperty, StructuredNode class Cat(StructuredNode): @@ -9,6 +10,7 @@ class Cat(StructuredNode): is_from = RelationshipTo(".test_relationships.Country", "IS_FROM") +@mark_sync_test def test_relative_relationship(): a = Cat(name="snufkin").save() assert a diff --git a/test/test_transactions.py b/test/sync_/test_transactions.py similarity index 82% rename from test/test_transactions.py rename to test/sync_/test_transactions.py index 0481e2a7..834b538e 100644 --- a/test/test_transactions.py +++ b/test/sync_/test_transactions.py @@ -1,22 +1,18 @@ +from test._async_compat import mark_sync_test + import pytest from neo4j.api import Bookmarks from neo4j.exceptions import ClientError, TransactionError from pytest import raises -from neomodel import ( - StringProperty, - StructuredNode, - UniqueProperty, - config, - db, - install_labels, -) +from neomodel import StringProperty, StructuredNode, UniqueProperty, db class APerson(StructuredNode): name = StringProperty(unique_index=True) +@mark_sync_test def test_rollback_and_commit_transaction(): for p in APerson.nodes: p.delete() @@ -42,14 +38,14 @@ def in_a_tx(*names): APerson(name=n).save() +@mark_sync_test def test_transaction_decorator(): - install_labels(APerson) + db.install_labels(APerson) for p in APerson.nodes: p.delete() # should work in_a_tx("Roger") - assert True # should bail but raise correct error with raises(UniqueProperty): @@ -58,6 +54,7 @@ def test_transaction_decorator(): assert "Jim" not in [p.name for p in APerson.nodes] +@mark_sync_test def test_transaction_as_a_context(): with db.transaction: APerson(name="Tim").save() @@ -69,6 +66,7 @@ def test_transaction_as_a_context(): APerson(name="Tim").save() +@mark_sync_test def test_query_inside_transaction(): for p in APerson.nodes: p.delete() @@ -80,11 +78,12 @@ def test_query_inside_transaction(): assert len([p.name for p in APerson.nodes]) == 2 +@mark_sync_test def test_read_transaction(): APerson(name="Johnny").save() with db.read_transaction: - people = APerson.nodes.all() + people = APerson.nodes assert people with raises(TransactionError): @@ -94,6 +93,7 @@ def test_read_transaction(): assert e.value.code == "Neo.ClientError.Statement.AccessMode" +@mark_sync_test def test_write_transaction(): with db.write_transaction: APerson(name="Amelia").save() @@ -102,6 +102,7 @@ def test_write_transaction(): assert amelia +@mark_sync_test def double_transaction(): db.begin() with raises(SystemError, match=r"Transaction in progress"): @@ -111,28 +112,30 @@ def double_transaction(): @db.transaction.with_bookmark -def in_a_tx(*names): +def in_a_tx_with_bookmark(*names): for n in names: APerson(name=n).save() -def test_bookmark_transaction_decorator(skip_neo4j_before_330): +@mark_sync_test +def test_bookmark_transaction_decorator(): for p in APerson.nodes: p.delete() # should work - result, bookmarks = in_a_tx("Ruth", bookmarks=None) + result, bookmarks = in_a_tx_with_bookmark("Ruth", bookmarks=None) assert result is None assert isinstance(bookmarks, Bookmarks) # should bail but raise correct error with raises(UniqueProperty): - in_a_tx("Jane", "Ruth") + in_a_tx_with_bookmark("Jane", "Ruth") assert "Jane" not in [p.name for p in APerson.nodes] -def test_bookmark_transaction_as_a_context(skip_neo4j_before_330): +@mark_sync_test +def test_bookmark_transaction_as_a_context(): with db.transaction as transaction: APerson(name="Tanya").save() assert isinstance(transaction.last_bookmark, Bookmarks) @@ -158,12 +161,13 @@ def begin_spy(*args, **kwargs): return spy_calls -def test_bookmark_passed_in_to_context(skip_neo4j_before_330, spy_on_db_begin): +@mark_sync_test +def test_bookmark_passed_in_to_context(spy_on_db_begin): transaction = db.transaction with transaction: pass - assert spy_on_db_begin[-1] == ((), {"access_mode": None, "bookmarks": None}) + assert (spy_on_db_begin)[-1] == ((), {"access_mode": None, "bookmarks": None}) last_bookmark = transaction.last_bookmark transaction.bookmarks = last_bookmark @@ -175,7 +179,8 @@ def test_bookmark_passed_in_to_context(skip_neo4j_before_330, spy_on_db_begin): ) -def test_query_inside_bookmark_transaction(skip_neo4j_before_330): +@mark_sync_test +def test_query_inside_bookmark_transaction(): for p in APerson.nodes: p.delete() diff --git a/test/test_scripts.py b/test/test_scripts.py index e30fd213..5583af47 100644 --- a/test/test_scripts.py +++ b/test/test_scripts.py @@ -8,10 +8,8 @@ StructuredNode, StructuredRel, config, - db, - install_labels, - util, ) +from neomodel.sync_.core import db class ScriptsTestRel(StructuredRel): @@ -108,9 +106,9 @@ def test_neomodel_inspect_database(script_flavour): assert "usage: neomodel_inspect_database" in result.stdout assert result.returncode == 0 - util.clear_neo4j_database(db) - install_labels(ScriptsTestNode) - install_labels(ScriptsTestRel) + db.clear_neo4j_database() + db.install_labels(ScriptsTestNode) + db.install_labels(ScriptsTestRel) # Create a few nodes and a rel, with indexes and constraints node1 = ScriptsTestNode(personal_id="1", name="test").save()