diff --git a/.gitignore b/.gitignore
index a79c277d..562969f3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,7 +15,6 @@ development.env
.ropeproject
\#*\#
.eggs
-bin
lib
.vscode
pyvenv.cfg
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 81b274db..22db6a97 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,17 +1,18 @@
repos:
- - repo: https://github.com/psf/black
- rev: 22.3.0
- hooks:
- - id: black
- repo: https://github.com/PyCQA/isort
rev: 5.11.5
hooks:
- id: isort
- # - repo: local
- # hooks:
- # - id: pylint
- # name: pylint
- # entry: pylint neomodel/
- # language: system
- # always_run: true
- # pass_filenames: false
\ No newline at end of file
+ args: ["--profile", "black"]
+ - repo: https://github.com/psf/black
+ rev: 23.3.0
+ hooks:
+ - id: black
+ - repo: local
+ hooks:
+ - id: unasync
+ name: unasync
+ entry: bin/make-unasync
+ language: python
+ files: "^(neomodel/async_|test/async_)/.*"
+ additional_dependencies: [unasync, isort, black]
\ No newline at end of file
diff --git a/README.md b/README.md
index 5d4566a5..0bd3478e 100644
--- a/README.md
+++ b/README.md
@@ -34,7 +34,7 @@ GitHub repo found at .
# Documentation
-(Needs an update, but) Available on
+Available on
[readthedocs](http://neomodel.readthedocs.org).
# Upcoming breaking changes notice - \>=5.3
@@ -47,7 +47,7 @@ support for Python 3.12.
Another source of upcoming breaking changes is the addition async support to
neomodel. No date is set yet, but the work has progressed a lot in the past weeks ;
-and it will be part of a major release (potentially 6.0 to avoid misunderstandings).
+and it will be part of a major release.
You can see the progress in [this branch](https://github.com/neo4j-contrib/neomodel/tree/task/async).
Finally, we are looking at refactoring some standalone methods into the
@@ -67,6 +67,15 @@ To install from github:
$ pip install git+git://github.com/neo4j-contrib/neomodel.git@HEAD#egg=neomodel-dev
+# Performance comparison
+
+You can find some performance tests made using Locust [in this repo](https://github.com/mariusconjeaud/neomodel-locust).
+
+Two learnings from this :
+
+* The wrapping of the driver made by neomodel is very thin performance-wise : it does not add a lot of overhead ;
+* When used in a concurrent fashion, async neomodel is faster than concurrent sync neomodel, and a lot faster than serial queries.
+
# Contributing
Ideas, bugs, tests and pull requests always welcome. Please use
@@ -112,3 +121,42 @@ against all supported Python interpreters and neo4j versions: :
# in the project's root folder:
$ sh ./tests-with-docker-compose.sh
+
+## Developing with async
+
+### Transpiling async -> sync
+
+We use [this great library](https://github.com/python-trio/unasync) to automatically transpile async code into its sync version.
+
+In other words, when contributing to neomodel, only update the `async` code in `neomodel/async_`, then run : :
+
+ bin/make-unasync
+ isort .
+ black .
+
+Note that you can also use the pre-commit hooks for this.
+
+### Specific async/sync code
+This transpiling script mainly does two things :
+
+- It removes the await keywords, and the Async prefixes in class names
+- It does some specific replacements, like `adb`->`db`, `mark_async_test`->`mark_sync_test`
+
+It might be that your code should only be run for `async` or `sync`, or that you want different stubs to be run for `async` vs `sync`.
+You can use the following utility function for this - taken from the official [Neo4j python driver code](https://github.com/neo4j/neo4j-python-driver) :
+
+ # neomodel/async_/core.py
+ from neomodel._async_compat.util import AsyncUtil
+
+ # AsyncUtil.is_async_code is always True
+ if AsyncUtil.is_async_code:
+ # Specific async code
+ # This one gets run when in async mode
+ assert await Coffee.nodes.check_contains(2)
+ else:
+ # Specific sync code
+        # This one gets run when in sync mode
+ assert 2 in Coffee.nodes
+
+You can check [test_match_api](test/async_/test_match_api.py) for some good examples, and how it's transpiled into sync.
+
diff --git a/bin/make-unasync b/bin/make-unasync
new file mode 100755
index 00000000..96375526
--- /dev/null
+++ b/bin/make-unasync
@@ -0,0 +1,351 @@
+#!/usr/bin/env python3
+
+import collections
+import errno
+import os
+import re
+import sys
+import tokenize as std_tokenize
+from pathlib import Path
+
+import black
+import isort
+import isort.files
+import unasync
+
+ROOT_DIR = Path(__file__).parents[1].absolute()
+ASYNC_DIR = ROOT_DIR / "neomodel" / "async_"
+SYNC_DIR = ROOT_DIR / "neomodel" / "sync_"
+ASYNC_CONTRIB_DIR = ROOT_DIR / "neomodel" / "contrib" / "async_"
+SYNC_CONTRIB_DIR = ROOT_DIR / "neomodel" / "contrib" / "sync_"
+ASYNC_INTEGRATION_TEST_DIR = ROOT_DIR / "test" / "async_"
+SYNC_INTEGRATION_TEST_DIR = ROOT_DIR / "test" / "sync_"
+INTEGRATION_TEST_EXCLUSION_LIST = ["conftest.py"]
+UNASYNC_SUFFIX = ".unasync"
+
+PY_FILE_EXTENSIONS = {".py"}
+
+
+# copy from unasync for customization -----------------------------------------
+# https://github.com/python-trio/unasync
+# License: MIT or Apache2
+
+
+Token = collections.namedtuple("Token", ["type", "string", "start", "end", "line"])
+
+
+def _makedirs_existok(dir):
+ try:
+ os.makedirs(dir)
+ except OSError as e:
+ if e.errno != errno.EEXIST:
+ raise
+
+
+def _get_tokens(f):
+ if sys.version_info[0] == 2:
+ for tok in std_tokenize.generate_tokens(f.readline):
+ type_, string, start, end, line = tok
+ yield Token(type_, string, start, end, line)
+ else:
+ for tok in std_tokenize.tokenize(f.readline):
+ if tok.type == std_tokenize.ENCODING:
+ continue
+ yield tok
+
+
+def _tokenize(f):
+ last_end = (1, 0)
+ for tok in _get_tokens(f):
+ if last_end[0] < tok.start[0]:
+ yield "", std_tokenize.STRING, " \\\n"
+ last_end = (tok.start[0], 0)
+
+ space = ""
+ if tok.start > last_end:
+ assert tok.start[0] == last_end[0]
+ space = " " * (tok.start[1] - last_end[1])
+ yield space, tok.type, tok.string
+
+ last_end = tok.end
+ if tok.type in [std_tokenize.NEWLINE, std_tokenize.NL]:
+ last_end = (tok.end[0] + 1, 0)
+ elif sys.version_info >= (3, 12) and tok.type == std_tokenize.FSTRING_MIDDLE:
+ last_end = (
+ last_end[0],
+ last_end[1] + tok.string.count("{") + tok.string.count("}"),
+ )
+
+
+def _untokenize(tokens):
+ return "".join(space + tokval for space, tokval in tokens)
+
+
+# end of copy -----------------------------------------------------------------
+
+
+class CustomRule(unasync.Rule):
+ def __init__(self, *args, **kwargs):
+ super(CustomRule, self).__init__(*args, **kwargs)
+ self.out_files = []
+
+ def _unasync_tokens(self, tokens):
+ # copy from unasync to fix handling of multiline strings
+ # https://github.com/python-trio/unasync
+ # License: MIT or Apache2
+
+ used_space = None
+ for space, toknum, tokval in tokens:
+ if tokval in ["async", "await"]:
+ # When removing async or await, we want to use the whitespace
+ # that was before async/await before the next token so that
+ # `print(await stuff)` becomes `print(stuff)` and not
+ # `print( stuff)`
+ used_space = space
+ else:
+ if toknum == std_tokenize.NAME:
+ tokval = self._unasync_name(tokval)
+ elif toknum == std_tokenize.STRING:
+ if tokval[0] == tokval[1] and len(tokval) > 2:
+ # multiline string (`"""..."""` or `'''...'''`)
+ left_quote, name, right_quote = (
+ tokval[:3],
+ tokval[3:-3],
+ tokval[-3:],
+ )
+ else:
+ # simple string (`"..."` or `'...'`)
+ left_quote, name, right_quote = (
+ tokval[:1],
+ tokval[1:-1],
+ tokval[-1:],
+ )
+ tokval = left_quote + self._unasync_string(name) + right_quote
+ elif (
+ sys.version_info >= (3, 12)
+ and toknum == std_tokenize.FSTRING_MIDDLE
+ ):
+ tokval = tokval.replace("{", "{{").replace("}", "}}")
+ tokval = self._unasync_string(tokval)
+ if used_space is None:
+ used_space = space
+ yield (used_space, tokval)
+ used_space = None
+
+ def _unasync_string(self, name):
+ start = 0
+ end = 1
+ out = ""
+ while end < len(name):
+ sub_name = name[start:end]
+ if sub_name.isidentifier():
+ end += 1
+ else:
+ if end == start + 1:
+ out += sub_name
+ start += 1
+ end += 1
+ else:
+ out += self._unasync_name(name[start : (end - 1)])
+ start = end - 1
+
+ sub_name = name[start:]
+ if sub_name.isidentifier():
+ out += self._unasync_name(name[start:])
+ else:
+ out += sub_name
+
+ # very boiled down unasync version that removes "async" and "await"
+ # substrings.
+ out = re.subn(
+ r"(^|\s+|(?<=\W))(?:async|await)\s+", r"\1", out, flags=re.MULTILINE
+ )[0]
+ # Convert doc-reference names from 'async-xyz' to 'xyz'
+ out = re.subn(r":ref:`async-", ":ref:`", out)[0]
+ return out
+
+ def _unasync_prefix(self, name):
+ # Convert class names from 'AsyncXyz' to 'Xyz'
+ if len(name) > 5 and name.startswith("Async") and name[5].isupper():
+ return name[5:]
+ # Convert class names from '_AsyncXyz' to '_Xyz'
+ elif len(name) > 6 and name.startswith("_Async") and name[6].isupper():
+ return "_" + name[6:]
+ # Convert variable/method/function names from 'async_xyz' to 'xyz'
+ elif len(name) > 6 and name.startswith("async_"):
+ return name[6:]
+ return name
+
+ def _unasync_name(self, name):
+ # copy from unasync to customize renaming rules
+ # https://github.com/python-trio/unasync
+ # License: MIT or Apache2
+ if name in self.token_replacements:
+ return self.token_replacements[name]
+ return self._unasync_prefix(name)
+
+ def _unasync_file(self, filepath):
+ # copy from unasync to append file suffix to out path
+ # https://github.com/python-trio/unasync
+ # License: MIT or Apache2
+ with open(filepath, "rb") as f:
+ write_kwargs = {}
+ if sys.version_info[0] >= 3:
+ encoding, _ = std_tokenize.detect_encoding(f.readline)
+ write_kwargs["encoding"] = encoding
+ f.seek(0)
+ tokens = _tokenize(f)
+ tokens = self._unasync_tokens(tokens)
+ result = _untokenize(tokens)
+ outfile_path = filepath.replace(self.fromdir, self.todir)
+ outfile_path += UNASYNC_SUFFIX
+ self.out_files.append(outfile_path)
+ _makedirs_existok(os.path.dirname(outfile_path))
+ with open(outfile_path, "w", **write_kwargs) as f:
+ print(result, file=f, end="")
+
+
+def apply_unasync(files):
+ """Generate sync code from async code."""
+
+ additional_main_replacements = {
+ "adb": "db",
+ "async_": "sync_",
+ "check_bool": "__bool__",
+ "check_nonzero": "__nonzero__",
+ "check_contains": "__contains__",
+ "get_item": "__getitem__",
+ "get_len": "__len__",
+ }
+ additional_test_replacements = {
+ "async_": "sync_",
+ "check_bool": "__bool__",
+ "check_nonzero": "__nonzero__",
+ "check_contains": "__contains__",
+ "get_item": "__getitem__",
+ "get_len": "__len__",
+ "adb": "db",
+ "mark_async_test": "mark_sync_test",
+ "mark_async_session_auto_fixture": "mark_sync_session_auto_fixture",
+ }
+ rules = [
+ CustomRule(
+ fromdir=str(ASYNC_DIR),
+ todir=str(SYNC_DIR),
+ additional_replacements=additional_main_replacements,
+ ),
+ CustomRule(
+ fromdir=str(ASYNC_CONTRIB_DIR),
+ todir=str(SYNC_CONTRIB_DIR),
+ additional_replacements=additional_main_replacements,
+ ),
+ CustomRule(
+ fromdir=str(ASYNC_INTEGRATION_TEST_DIR),
+ todir=str(SYNC_INTEGRATION_TEST_DIR),
+ additional_replacements=additional_test_replacements,
+ ),
+ ]
+
+ if not files:
+ paths = list(ASYNC_DIR.rglob("*"))
+ paths += list(ASYNC_CONTRIB_DIR.rglob("*"))
+ paths += [
+ path
+ for path in ASYNC_INTEGRATION_TEST_DIR.rglob("*")
+ if path.name not in INTEGRATION_TEST_EXCLUSION_LIST
+ ]
+ else:
+ paths = [ROOT_DIR / Path(f) for f in files]
+ filtered_paths = []
+ for path in paths:
+ if path.suffix in PY_FILE_EXTENSIONS:
+ filtered_paths.append(path)
+
+ unasync.unasync_files(map(str, filtered_paths), rules)
+
+ return [Path(path) for rule in rules for path in rule.out_files]
+
+
+def apply_black(paths):
+ """Prettify generated sync code.
+
+ Since keywords are removed, black might expect a different result,
+ especially line breaks.
+ """
+ for path in paths:
+ with open(path, "r") as file:
+ code = file.read()
+
+ formatted_code = black.format_str(code, mode=black.FileMode())
+
+ with open(path, "w") as file:
+ file.write(formatted_code)
+
+ return paths
+
+
+def apply_isort(paths):
+ """Sort imports in generated sync code.
+
+ Since classes in imports are renamed from AsyncXyz to Xyz, the alphabetical
+ order of the import can change.
+ """
+ isort_config = isort.Config(
+ settings_path=str(ROOT_DIR), quiet=True, profile="black"
+ )
+
+ for path in paths:
+ isort.file(str(path), config=isort_config)
+
+ return paths
+
+
+def apply_changes(paths):
+ def files_equal(path1, path2):
+ with open(path1, "rb") as f1:
+ with open(path2, "rb") as f2:
+ data1 = f1.read(1024)
+ data2 = f2.read(1024)
+ while data1 or data2:
+ if data1 != data2:
+ changed_paths[path1] = path2
+ return False
+ data1 = f1.read(1024)
+ data2 = f2.read(1024)
+ return True
+
+ changed_paths = {}
+
+ for in_path in paths:
+ out_path = Path(str(in_path)[: -len(UNASYNC_SUFFIX)])
+ if not out_path.is_file():
+ changed_paths[in_path] = out_path
+ continue
+ if not files_equal(in_path, out_path):
+ changed_paths[in_path] = out_path
+ continue
+ in_path.unlink()
+
+ for in_path, out_path in changed_paths.items():
+ in_path.replace(out_path)
+
+ return list(changed_paths.values())
+
+
+def main():
+ files = None
+ if len(sys.argv) >= 1:
+ files = sys.argv[1:]
+ paths = apply_unasync(files)
+ paths = apply_isort(paths)
+ paths = apply_black(paths)
+ changed_paths = apply_changes(paths)
+
+ if changed_paths:
+ for path in changed_paths:
+ print("Updated " + str(path))
+ exit(1)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/doc/source/configuration.rst b/doc/source/configuration.rst
index ed1af11b..3946ec98 100644
--- a/doc/source/configuration.rst
+++ b/doc/source/configuration.rst
@@ -59,7 +59,7 @@ Note that you have to manage the driver's lifecycle yourself.
However, everything else is still handled by neomodel : sessions, transactions, etc...
-NB : Only the synchronous driver will work in this way. The asynchronous driver is not supported yet.
+NB : Only the synchronous driver will work in this way. See the next section for the preferred method, and how to pass an async driver instance.
Change/Close the connection
---------------------------
@@ -119,14 +119,7 @@ with something like: ::
Enable automatic index and constraint creation
----------------------------------------------
-After the definition of a `StructuredNode`, Neomodel can install the corresponding
-constraints and indexes at compile time. However this method is only recommended for testing::
-
- from neomodel import config
- # before loading your node definitions
- config.AUTO_INSTALL_LABELS = True
-
-Neomodel also provides the :ref:`neomodel_install_labels` script for this task,
+Neomodel provides the :ref:`neomodel_install_labels` script for this task,
however if you want to handle this manually see below.
Install indexes and constraints for a single class::
@@ -146,6 +139,9 @@ Or for an entire 'schema' ::
# + Creating unique constraint for name on label User for class yourapp.models.User
# ...
+.. note::
+ config.AUTO_INSTALL_LABELS has been removed from neomodel in version 5.3
+
Require timezones on DateTimeProperty
-------------------------------------
diff --git a/doc/source/extending.rst b/doc/source/extending.rst
index 32067009..df50f84d 100644
--- a/doc/source/extending.rst
+++ b/doc/source/extending.rst
@@ -39,7 +39,7 @@ labels, the `__optional_labels__` property must be defined as a list of strings:
__optional_labels__ = ["SuperSaver", "SeniorDiscount"]
balance = IntegerProperty(index=True)
-.. warning:: The size of the node class mapping grows exponentially with optional labels. Use with some caution.
+.. note:: The size of the node class mapping grows exponentially with optional labels. Use with some caution.
Mixins
diff --git a/doc/source/getting_started.rst b/doc/source/getting_started.rst
index 25e999e7..c756a6f6 100644
--- a/doc/source/getting_started.rst
+++ b/doc/source/getting_started.rst
@@ -22,6 +22,7 @@ Querying the graph
neomodel is mainly used as an OGM (see next section), but you can also use it for direct Cypher queries : ::
+ from neomodel import db
results, meta = db.cypher_query("RETURN 'Hello World' as message")
@@ -104,7 +105,7 @@ and cardinality will be default (ZeroOrMore).
Finally, relationship cardinality is guessed from the database by looking at existing relationships, so it might
guess wrong on edge cases.
-.. warning::
+.. note::
The script relies on the method apoc.meta.cypher.types to parse property types. So APOC must be installed on your Neo4j server
for this script to work.
@@ -246,7 +247,7 @@ the following syntax::
Person.nodes.all().fetch_relations('city__country', Optional('country'))
-.. warning::
+.. note::
+    This feature is still a work in progress for extending path traversal and fetching.
It currently stops at returning the resolved objects as they are returned in Cypher.
@@ -256,3 +257,60 @@ the following syntax::
If you want to go further in the resolution process, you have to develop your own parser (for now).
+
+Async neomodel
+==============
+
+neomodel supports asynchronous operations using the async support of neo4j driver. The examples below take a few of the above examples,
+but rewritten for async::
+
+ from neomodel import adb
+ results, meta = await adb.cypher_query("RETURN 'Hello World' as message")
+
+OGM with async ::
+
+ # Note that properties do not change, but nodes and relationships now have an Async prefix
+ from neomodel import (AsyncStructuredNode, StringProperty, IntegerProperty,
+ UniqueIdProperty, AsyncRelationshipTo)
+
+ class Country(AsyncStructuredNode):
+ code = StringProperty(unique_index=True, required=True)
+
+ class City(AsyncStructuredNode):
+ name = StringProperty(required=True)
+ country = AsyncRelationshipTo(Country, 'FROM_COUNTRY')
+
+ # Operations that interact with the database are now async
+ # Return all nodes
+ # Note that the nodes object is awaitable as is
+ all_nodes = await Country.nodes
+
+ # Relationships
+ germany = await Country(code='DE').save()
+ await jim.country.connect(germany)
+
+Most _dunder_ methods for nodes and relationships had to be overridden to support async operations. The following methods are supported ::
+
+ # Examples below are taken from the various tests. Please check them for more examples.
+ # Length
+ dogs_bonanza = await Dog.nodes.get_len()
+ # Sync equivalent - __len__
+ dogs_bonanza = len(Dog.nodes)
+ # Note that len(Dog.nodes) is more efficient than Dog.nodes.__len__
+
+ # Existence
+ assert not await Customer.nodes.filter(email="jim7@aol.com").check_bool()
+ # Sync equivalent - __bool__
+ assert not Customer.nodes.filter(email="jim7@aol.com")
+ # Also works for check_nonzero => __nonzero__
+
+ # Contains
+ assert await Coffee.nodes.check_contains(aCoffeeNode)
+ # Sync equivalent - __contains__
+ assert aCoffeeNode in Coffee.nodes
+
+ # Get item
+ assert len(list((await Coffee.nodes)[1:])) == 2
+ # Sync equivalent - __getitem__
+ assert len(list(Coffee.nodes[1:])) == 2
+
diff --git a/doc/source/index.rst b/doc/source/index.rst
index ec3372c9..1338ce27 100644
--- a/doc/source/index.rst
+++ b/doc/source/index.rst
@@ -9,6 +9,7 @@ An Object Graph Mapper (OGM) for the Neo4j_ graph database, built on the awesome
- Enforce your schema through cardinality restrictions.
- Full transaction support.
- Thread safe.
+- Async support.
- pre/post save/delete hooks.
- Django integration via django_neomodel_
@@ -40,6 +41,26 @@ To install from github::
$ pip install git+git://github.com/neo4j-contrib/neomodel.git@HEAD#egg=neomodel-dev
+.. note::
+
+ **Breaking changes in 5.3**
+
+ Introducing support for asynchronous programming to neomodel required to introduce some breaking changes:
+
+ - config.AUTO_INSTALL_LABELS has been removed. Please use the `neomodel_install_labels` (:ref:`neomodel_install_labels`) command instead.
+
+ **Deprecations in 5.3**
+
+ - Some standalone methods are moved into the Database() class and will be removed in a future release :
+ - change_neo4j_password
+ - clear_neo4j_database
+ - drop_constraints
+ - drop_indexes
+ - remove_all_labels
+ - install_labels
+ - install_all_labels
+ - Additionally, to call these methods with async, use the ones in the AsyncDatabase() _adb_ singleton.
+
Contents
========
@@ -59,6 +80,8 @@ Contents
configuration
extending
module_documentation
+ module_documentation_sync
+ module_documentation_async
Indices and tables
==================
diff --git a/doc/source/module_documentation.rst b/doc/source/module_documentation.rst
index 16a4acf2..364e207e 100644
--- a/doc/source/module_documentation.rst
+++ b/doc/source/module_documentation.rst
@@ -1,24 +1,6 @@
-=====================
-Modules documentation
-=====================
-
-Database
-========
-.. module:: neomodel.util
-.. autoclass:: neomodel.util.Database
- :members:
- :undoc-members:
-
-Core
-====
-.. automodule:: neomodel.core
- :members:
-
-.. _semistructurednode_doc:
-
-``SemiStructuredNode``
-----------------------
-.. autoclass:: neomodel.contrib.SemiStructuredNode
+============
+General API
+============
Properties
==========
@@ -32,43 +14,6 @@ Spatial Properties & Datatypes
:members:
:show-inheritance:
-Relationships
-=============
-.. automodule:: neomodel.relationship
- :members:
- :show-inheritance:
-
-.. automodule:: neomodel.relationship_manager
- :members:
- :show-inheritance:
-
-.. automodule:: neomodel.cardinality
- :members:
- :show-inheritance:
-
-Paths
-=====
-
-.. automodule:: neomodel.path
- :members:
- :show-inheritance:
-
-
-
-
-Match
-=====
-.. module:: neomodel.match
-.. autoclass:: neomodel.match.BaseSet
- :members:
- :undoc-members:
-.. autoclass:: neomodel.match.NodeSet
- :members:
- :undoc-members:
-.. autoclass:: neomodel.match.Traversal
- :members:
- :undoc-members:
-
Exceptions
==========
@@ -78,16 +23,21 @@ Exceptions
:undoc-members:
:show-inheritance:
+
Scripts
=======
-.. automodule:: neomodel.scripts.neomodel_install_labels
+.. automodule:: neomodel.scripts.neomodel_inspect_database
:members:
:undoc-members:
:show-inheritance:
-.. automodule:: neomodel.scripts.neomodel_remove_labels
+.. automodule:: neomodel.scripts.neomodel_install_labels
:members:
:undoc-members:
:show-inheritance:
+.. automodule:: neomodel.scripts.neomodel_remove_labels
+ :members:
+ :undoc-members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/doc/source/module_documentation_async.rst b/doc/source/module_documentation_async.rst
new file mode 100644
index 00000000..2150b235
--- /dev/null
+++ b/doc/source/module_documentation_async.rst
@@ -0,0 +1,55 @@
+=======================
+Async API Documentation
+=======================
+
+Core
+====
+.. automodule:: neomodel.async_.core
+ :members:
+
+.. _semistructurednode_doc:
+
+``AsyncSemiStructuredNode``
+---------------------------
+.. autoclass:: neomodel.contrib.AsyncSemiStructuredNode
+
+Relationships
+=============
+.. automodule:: neomodel.async_.relationship
+ :members:
+ :show-inheritance:
+
+.. automodule:: neomodel.async_.relationship_manager
+ :members:
+ :show-inheritance:
+
+.. automodule:: neomodel.async_.cardinality
+ :members:
+ :show-inheritance:
+
+Property Manager
+================
+.. automodule:: neomodel.async_.property_manager
+ :members:
+ :show-inheritance:
+
+Paths
+=====
+
+.. automodule:: neomodel.async_.path
+ :members:
+ :show-inheritance:
+
+Match
+=====
+.. module:: neomodel.async_.match
+.. autoclass:: neomodel.async_.match.AsyncBaseSet
+ :members:
+ :undoc-members:
+.. autoclass:: neomodel.async_.match.AsyncNodeSet
+ :members:
+ :undoc-members:
+.. autoclass:: neomodel.async_.match.AsyncTraversal
+ :members:
+ :undoc-members:
+
diff --git a/doc/source/module_documentation_sync.rst b/doc/source/module_documentation_sync.rst
new file mode 100644
index 00000000..5214a89a
--- /dev/null
+++ b/doc/source/module_documentation_sync.rst
@@ -0,0 +1,55 @@
+======================
+Sync API Documentation
+======================
+
+Core
+====
+.. automodule:: neomodel.sync_.core
+ :members:
+
+.. _semistructurednode_doc:
+
+``SemiStructuredNode``
+---------------------------
+.. autoclass:: neomodel.contrib.SemiStructuredNode
+
+Relationships
+=============
+.. automodule:: neomodel.sync_.relationship
+ :members:
+ :show-inheritance:
+
+.. automodule:: neomodel.sync_.relationship_manager
+ :members:
+ :show-inheritance:
+
+.. automodule:: neomodel.sync_.cardinality
+ :members:
+ :show-inheritance:
+
+Property Manager
+================
+.. automodule:: neomodel.sync_.property_manager
+ :members:
+ :show-inheritance:
+
+Paths
+=====
+
+.. automodule:: neomodel.sync_.path
+ :members:
+ :show-inheritance:
+
+Match
+=====
+.. module:: neomodel.sync_.match
+.. autoclass:: neomodel.sync_.match.BaseSet
+ :members:
+ :undoc-members:
+.. autoclass:: neomodel.sync_.match.NodeSet
+ :members:
+ :undoc-members:
+.. autoclass:: neomodel.sync_.match.Traversal
+ :members:
+ :undoc-members:
+
diff --git a/doc/source/queries.rst b/doc/source/queries.rst
index c3e37629..4c77a791 100644
--- a/doc/source/queries.rst
+++ b/doc/source/queries.rst
@@ -240,3 +240,19 @@ relationships to their relationship models *if such a model exists*. In other wo
relationships with data (such as ``PersonLivesInCity`` above) will be instantiated to their
respective objects or ``StrucuredRel`` otherwise. Relationships do not "reload" their
end-points (unless this is required).
+
+Async neomodel - Caveats
+========================
+
+Python does not support async dunder methods. This means that we had to implement some overrides for those.
+See the example below::
+
+ # This will not work as it uses the synchronous __bool__ method
+ assert await Customer.nodes.filter(prop="value")
+
+ # Do this instead
+ assert await Customer.nodes.filter(prop="value").check_bool()
+ assert await Customer.nodes.filter(prop="value").check_nonzero()
+
+ # Note : no changes are needed for sync so this still works :
+ assert Customer.nodes.filter(prop="value")
diff --git a/neomodel/__init__.py b/neomodel/__init__.py
index aee40a08..d7d0febb 100644
--- a/neomodel/__init__.py
+++ b/neomodel/__init__.py
@@ -1,20 +1,25 @@
# pep8: noqa
-
+from neomodel.async_.cardinality import (
+ AsyncOne,
+ AsyncOneOrMore,
+ AsyncZeroOrMore,
+ AsyncZeroOrOne,
+)
+from neomodel.async_.core import AsyncStructuredNode, adb
+from neomodel.async_.match import AsyncNodeSet, AsyncTraversal
+from neomodel.async_.path import AsyncNeomodelPath
+from neomodel.async_.property_manager import AsyncPropertyManager
+from neomodel.async_.relationship import AsyncStructuredRel
+from neomodel.async_.relationship_manager import (
+ AsyncRelationship,
+ AsyncRelationshipDefinition,
+ AsyncRelationshipFrom,
+ AsyncRelationshipManager,
+ AsyncRelationshipTo,
+)
from neomodel.exceptions import *
-from neomodel.match import EITHER, INCOMING, OUTGOING, NodeSet, Traversal
from neomodel.match_q import Q # noqa
-from neomodel.relationship_manager import (
- NotConnected,
- Relationship,
- RelationshipDefinition,
- RelationshipFrom,
- RelationshipManager,
- RelationshipTo,
-)
-
-from .cardinality import One, OneOrMore, ZeroOrMore, ZeroOrOne
-from .core import *
-from .properties import (
+from neomodel.properties import (
AliasProperty,
ArrayProperty,
BooleanProperty,
@@ -30,9 +35,30 @@
StringProperty,
UniqueIdProperty,
)
-from .relationship import StructuredRel
-from .util import change_neo4j_password, clear_neo4j_database
-from .path import NeomodelPath
+from neomodel.sync_.cardinality import One, OneOrMore, ZeroOrMore, ZeroOrOne
+from neomodel.sync_.core import (
+ StructuredNode,
+ change_neo4j_password,
+ clear_neo4j_database,
+ db,
+ drop_constraints,
+ drop_indexes,
+ install_all_labels,
+ install_labels,
+ remove_all_labels,
+)
+from neomodel.sync_.match import NodeSet, Traversal
+from neomodel.sync_.path import NeomodelPath
+from neomodel.sync_.property_manager import PropertyManager
+from neomodel.sync_.relationship import StructuredRel
+from neomodel.sync_.relationship_manager import (
+ Relationship,
+ RelationshipDefinition,
+ RelationshipFrom,
+ RelationshipManager,
+ RelationshipTo,
+)
+from neomodel.util import EITHER, INCOMING, OUTGOING
__author__ = "Robin Edwards"
__email__ = "robin.ge@gmail.com"
diff --git a/neomodel/_async_compat/util.py b/neomodel/_async_compat/util.py
new file mode 100644
index 00000000..4868c3ba
--- /dev/null
+++ b/neomodel/_async_compat/util.py
@@ -0,0 +1,9 @@
+import typing as t
+
+
+class AsyncUtil:
+ is_async_code: t.ClassVar = True
+
+
+class Util:
+ is_async_code: t.ClassVar = False
diff --git a/test/test_contrib/__init__.py b/neomodel/async_/__init__.py
similarity index 100%
rename from test/test_contrib/__init__.py
rename to neomodel/async_/__init__.py
diff --git a/neomodel/async_/cardinality.py b/neomodel/async_/cardinality.py
new file mode 100644
index 00000000..17101cec
--- /dev/null
+++ b/neomodel/async_/cardinality.py
@@ -0,0 +1,135 @@
+from neomodel.async_.relationship_manager import ( # pylint:disable=unused-import
+ AsyncRelationshipManager,
+ AsyncZeroOrMore,
+)
+from neomodel.exceptions import AttemptedCardinalityViolation, CardinalityViolation
+
+
+class AsyncZeroOrOne(AsyncRelationshipManager):
+ """A relationship to zero or one node."""
+
+ description = "zero or one relationship"
+
+ async def single(self):
+ """
+ Return the associated node.
+
+ :return: node
+ """
+ nodes = await super().all()
+ if len(nodes) == 1:
+ return nodes[0]
+ if len(nodes) > 1:
+ raise CardinalityViolation(self, len(nodes))
+ return None
+
+ async def all(self):
+ node = await self.single()
+ return [node] if node else []
+
+ async def connect(self, node, properties=None):
+ """
+ Connect to a node.
+
+ :param node:
+ :type: StructuredNode
+ :param properties: relationship properties
+ :type: dict
+ :return: True / rel instance
+ """
+ if await super().get_len():
+ raise AttemptedCardinalityViolation(
+ f"Node already has {self} can't connect more"
+ )
+ return await super().connect(node, properties)
+
+
+class AsyncOneOrMore(AsyncRelationshipManager):
+ """A relationship to zero or more nodes."""
+
+ description = "one or more relationships"
+
+ async def single(self):
+ """
+ Fetch one of the related nodes
+
+ :return: Node
+ """
+ nodes = await super().all()
+ if nodes:
+ return nodes[0]
+ raise CardinalityViolation(self, "none")
+
+ async def all(self):
+ """
+ Returns all related nodes.
+
+ :return: [node1, node2...]
+ """
+ nodes = await super().all()
+ if nodes:
+ return nodes
+ raise CardinalityViolation(self, "none")
+
+ async def disconnect(self, node):
+ """
+ Disconnect node
+ :param node:
+ :return:
+ """
+ if await super().get_len() < 2:
+ raise AttemptedCardinalityViolation("One or more expected")
+ return await super().disconnect(node)
+
+
+class AsyncOne(AsyncRelationshipManager):
+ """
+ A relationship to a single node
+ """
+
+ description = "one relationship"
+
+ async def single(self):
+ """
+ Return the associated node.
+
+ :return: node
+ """
+ nodes = await super().all()
+ if nodes:
+ if len(nodes) == 1:
+ return nodes[0]
+ raise CardinalityViolation(self, len(nodes))
+ raise CardinalityViolation(self, "none")
+
+ async def all(self):
+ """
+ Return single node in an array
+
+ :return: [node]
+ """
+ return [await self.single()]
+
+ async def disconnect(self, node):
+ raise AttemptedCardinalityViolation(
+ "Cardinality one, cannot disconnect use reconnect."
+ )
+
+ async def disconnect_all(self):
+ raise AttemptedCardinalityViolation(
+ "Cardinality one, cannot disconnect_all use reconnect."
+ )
+
+ async def connect(self, node, properties=None):
+ """
+ Connect a node
+
+ :param node:
+ :param properties: relationship properties
+ :return: True / rel instance
+ """
+ if not hasattr(self.source, "element_id") or self.source.element_id is None:
+ raise ValueError("Node has not been saved cannot connect!")
+ if await super().get_len():
+ raise AttemptedCardinalityViolation("Node already has one relationship")
+ return await super().connect(node, properties)
diff --git a/neomodel/async_/core.py b/neomodel/async_/core.py
new file mode 100644
index 00000000..21228cb9
--- /dev/null
+++ b/neomodel/async_/core.py
@@ -0,0 +1,1523 @@
+import logging
+import os
+import sys
+import time
+import warnings
+from asyncio import iscoroutinefunction
+from itertools import combinations
+from threading import local
+from typing import Optional, Sequence
+from urllib.parse import quote, unquote, urlparse
+
+from neo4j import (
+ DEFAULT_DATABASE,
+ AsyncDriver,
+ AsyncGraphDatabase,
+ AsyncResult,
+ AsyncSession,
+ AsyncTransaction,
+ basic_auth,
+)
+from neo4j.api import Bookmarks
+from neo4j.exceptions import ClientError, ServiceUnavailable, SessionExpired
+from neo4j.graph import Node, Path, Relationship
+
+from neomodel import config
+from neomodel._async_compat.util import AsyncUtil
+from neomodel.async_.property_manager import AsyncPropertyManager
+from neomodel.exceptions import (
+ ConstraintValidationFailed,
+ DoesNotExist,
+ FeatureNotSupported,
+ NodeClassAlreadyDefined,
+ NodeClassNotDefined,
+ RelationshipClassNotDefined,
+ UniqueProperty,
+)
+from neomodel.hooks import hooks
+from neomodel.properties import Property
+from neomodel.util import (
+ _get_node_properties,
+ _UnsavedNode,
+ classproperty,
+ deprecated,
+ version_tag_to_integer,
+)
+
+logger = logging.getLogger(__name__)
+
+RULE_ALREADY_EXISTS = "Neo.ClientError.Schema.EquivalentSchemaRuleAlreadyExists"
+INDEX_ALREADY_EXISTS = "Neo.ClientError.Schema.IndexAlreadyExists"
+CONSTRAINT_ALREADY_EXISTS = "Neo.ClientError.Schema.ConstraintAlreadyExists"
+STREAMING_WARNING = "streaming is not supported by bolt, please remove the kwarg"
+NOT_COROUTINE_ERROR = "The decorated function must be a coroutine"
+
+
# make sure the connection url has been set prior to executing the wrapped function
def ensure_connection(func):
    """Decorator that ensures a connection is established before executing the decorated function.

    The connection is built lazily on the first call, from either
    config.DATABASE_URL or config.DRIVER (URL takes precedence).

    Args:
        func (callable): The coroutine function to be decorated.

    Returns:
        callable: The decorated coroutine.

    """

    async def wrapper(self, *args, **kwargs):
        # Sort out where to find url: the decorator is applied both to
        # AsyncDatabase methods (self IS the database) and to proxy objects
        # that hold the database in `self.db`.
        if hasattr(self, "db"):
            _db = self.db
        else:
            _db = self

        if not _db.driver:
            if hasattr(config, "DATABASE_URL") and config.DATABASE_URL:
                await _db.set_connection(url=config.DATABASE_URL)
            elif hasattr(config, "DRIVER") and config.DRIVER:
                await _db.set_connection(driver=config.DRIVER)

        return await func(self, *args, **kwargs)

    return wrapper
+
+
+class AsyncDatabase(local):
+ """
+ A singleton object via which all operations from neomodel to the Neo4j backend are handled with.
+ """
+
+ _NODE_CLASS_REGISTRY = {}
+
    def __init__(self):
        # Currently open explicit transaction, if any
        self._active_transaction = None
        # Connection URL; reused to rebuild the driver on SessionExpired retry
        self.url = None
        self.driver = None
        # Session owning the active explicit transaction
        self._session = None
        # pid that created the connection (instances are thread-local)
        self._pid = None
        self._database_name = DEFAULT_DATABASE
        self.protocol_version = None
        # Server metadata, fetched lazily by _update_database_version
        self._database_version = None
        self._database_edition = None
        # User to impersonate on subsequent queries (see impersonate())
        self.impersonated_user = None
+
    async def set_connection(self, url: str = None, driver: AsyncDriver = None):
        """
        Sets the connection up and relevant internal. This can be done using a Neo4j URL or a driver instance.

        A provided driver takes precedence over a URL.

        Args:
            url (str): Optionally, Neo4j URL in the form protocol://username:password@hostname:port/dbname.
                When provided, a Neo4j driver instance will be created by neomodel.

            driver (neo4j.Driver): Optionally, a pre-created driver instance.
                When provided, neomodel will not create a driver instance but use this one instead.
        """
        if driver:
            self.driver = driver
            if hasattr(config, "DATABASE_NAME") and config.DATABASE_NAME:
                self._database_name = config.DATABASE_NAME
        elif url:
            self._parse_driver_from_url(url=url)

        self._pid = os.getpid()
        self._active_transaction = None
        # Set to default database if it hasn't been set before
        if self._database_name is None:
            self._database_name = DEFAULT_DATABASE

        # Getting the information about the database version requires a connection to the database
        self._database_version = None
        self._database_edition = None
        await self._update_database_version()
+
    def _parse_driver_from_url(self, url: str) -> None:
        """Parse the driver information from the given URL and initialize the driver.

        Args:
            url (str): The URL to parse.

        Raises:
            ValueError: If the URL format is not as expected.

        Returns:
            None - Sets the driver and database_name as class properties
        """
        # Percent-encode the password so special characters in it do not break
        # urlparse. The password sits between the second ":" and the last "@".
        p_start = url.replace(":", "", 1).find(":") + 2
        p_end = url.rfind("@")
        password = url[p_start:p_end]
        url = url.replace(password, quote(password))
        parsed_url = urlparse(url)

        valid_schemas = [
            "bolt",
            "bolt+s",
            "bolt+ssc",
            "bolt+routing",
            "neo4j",
            "neo4j+s",
            "neo4j+ssc",
        ]

        if parsed_url.netloc.find("@") > -1 and parsed_url.scheme in valid_schemas:
            credentials, hostname = parsed_url.netloc.rsplit("@", 1)
            username, password = credentials.split(":")
            password = unquote(password)
            database_name = parsed_url.path.strip("/")
        else:
            raise ValueError(
                f"Expecting url format: bolt://user:password@localhost:7687 got {url}"
            )

        options = {
            "auth": basic_auth(username, password),
            "connection_acquisition_timeout": config.CONNECTION_ACQUISITION_TIMEOUT,
            "connection_timeout": config.CONNECTION_TIMEOUT,
            "keep_alive": config.KEEP_ALIVE,
            "max_connection_lifetime": config.MAX_CONNECTION_LIFETIME,
            "max_connection_pool_size": config.MAX_CONNECTION_POOL_SIZE,
            "max_transaction_retry_time": config.MAX_TRANSACTION_RETRY_TIME,
            "resolver": config.RESOLVER,
            "user_agent": config.USER_AGENT,
        }

        # Secure ("+s"/"+ssc") schemes already imply encryption settings, so
        # only pass explicit encryption options for plain schemes.
        if "+s" not in parsed_url.scheme:
            options["encrypted"] = config.ENCRYPTED
            options["trusted_certificates"] = config.TRUSTED_CERTIFICATES

        self.driver = AsyncGraphDatabase.driver(
            parsed_url.scheme + "://" + hostname, **options
        )
        self.url = url
        # The database name can be provided through the url or the config
        if database_name == "":
            if hasattr(config, "DATABASE_NAME") and config.DATABASE_NAME:
                self._database_name = config.DATABASE_NAME
        else:
            self._database_name = database_name
+
    async def close_connection(self):
        """
        Closes the currently open driver.
        The driver should always be closed at the end of the application's lifecycle.
        """
        # Reset cached server metadata so a future connection re-fetches it.
        self._database_version = None
        self._database_edition = None
        self._database_name = None
        await self.driver.close()
        self.driver = None
+
    @property
    async def database_version(self):
        """Server version string, fetched lazily on first access and cached."""
        if self._database_version is None:
            await self._update_database_version()

        return self._database_version
+
    @property
    async def database_edition(self):
        """Server edition string, fetched lazily on first access and cached."""
        if self._database_edition is None:
            await self._update_database_version()

        return self._database_edition
+
+ @property
+ def transaction(self):
+ """
+ Returns the current transaction object
+ """
+ return AsyncTransactionProxy(self)
+
+ @property
+ def write_transaction(self):
+ return AsyncTransactionProxy(self, access_mode="WRITE")
+
+ @property
+ def read_transaction(self):
+ return AsyncTransactionProxy(self, access_mode="READ")
+
+ async def impersonate(self, user: str) -> "ImpersonationHandler":
+ """All queries executed within this context manager will be executed as impersonated user
+
+ Args:
+ user (str): User to impersonate
+
+ Returns:
+ ImpersonationHandler: Context manager to set/unset the user to impersonate
+ """
+ db_edition = await self.database_edition
+ if db_edition != "enterprise":
+ raise FeatureNotSupported(
+ "Impersonation is only available in Neo4j Enterprise edition"
+ )
+ return ImpersonationHandler(self, impersonated_user=user)
+
+ @ensure_connection
+ async def begin(self, access_mode=None, **parameters):
+ """
+ Begins a new transaction. Raises SystemError if a transaction is already active.
+ """
+ if (
+ hasattr(self, "_active_transaction")
+ and self._active_transaction is not None
+ ):
+ raise SystemError("Transaction in progress")
+ self._session: AsyncSession = self.driver.session(
+ default_access_mode=access_mode,
+ database=self._database_name,
+ impersonated_user=self.impersonated_user,
+ **parameters,
+ )
+ self._active_transaction: AsyncTransaction = (
+ await self._session.begin_transaction()
+ )
+
+ @ensure_connection
+ async def commit(self):
+ """
+ Commits the current transaction and closes its session
+
+ :return: last_bookmarks
+ """
+ try:
+ await self._active_transaction.commit()
+ last_bookmarks: Bookmarks = await self._session.last_bookmarks()
+ finally:
+ # In case when something went wrong during
+ # committing changes to the database
+ # we have to close an active transaction and session.
+ await self._active_transaction.close()
+ await self._session.close()
+ self._active_transaction = None
+ self._session = None
+
+ return last_bookmarks
+
+ @ensure_connection
+ async def rollback(self):
+ """
+ Rolls back the current transaction and closes its session
+ """
+ try:
+ await self._active_transaction.rollback()
+ finally:
+ # In case when something went wrong during changes rollback,
+ # we have to close an active transaction and session
+ await self._active_transaction.close()
+ await self._session.close()
+ self._active_transaction = None
+ self._session = None
+
+ async def _update_database_version(self):
+ """
+ Updates the database server information when it is required
+ """
+ try:
+ results = await self.cypher_query(
+ "CALL dbms.components() yield versions, edition return versions[0], edition"
+ )
+ self._database_version = results[0][0][0]
+ self._database_edition = results[0][0][1]
+ except ServiceUnavailable:
+ # The database server is not running yet
+ pass
+
+ def _object_resolution(self, object_to_resolve):
+ """
+ Performs in place automatic object resolution on a result
+ returned by cypher_query.
+
+ The function operates recursively in order to be able to resolve Nodes
+ within nested list structures and Path objects. Not meant to be called
+ directly, used primarily by _result_resolution.
+
+ :param object_to_resolve: A result as returned by cypher_query.
+ :type Any:
+
+ :return: An instantiated object.
+ """
+ # Below is the original comment that came with the code extracted in
+ # this method. It is not very clear but I decided to keep it just in
+ # case
+ #
+ #
+ # For some reason, while the type of `a_result_attribute[1]`
+ # as reported by the neo4j driver is `Node` for Node-type data
+ # retrieved from the database.
+ # When the retrieved data are Relationship-Type,
+ # the returned type is `abc.[REL_LABEL]` which is however
+ # a descendant of Relationship.
+ # Consequently, the type checking was changed for both
+ # Node, Relationship objects
+ if isinstance(object_to_resolve, Node):
+ return self._NODE_CLASS_REGISTRY[
+ frozenset(object_to_resolve.labels)
+ ].inflate(object_to_resolve)
+
+ if isinstance(object_to_resolve, Relationship):
+ rel_type = frozenset([object_to_resolve.type])
+ return self._NODE_CLASS_REGISTRY[rel_type].inflate(object_to_resolve)
+
+ if isinstance(object_to_resolve, Path):
+ from neomodel.async_.path import AsyncNeomodelPath
+
+ return AsyncNeomodelPath(object_to_resolve)
+
+ if isinstance(object_to_resolve, list):
+ return self._result_resolution([object_to_resolve])
+
+ return object_to_resolve
+
+ def _result_resolution(self, result_list):
+ """
+ Performs in place automatic object resolution on a set of results
+ returned by cypher_query.
+
+ The function operates recursively in order to be able to resolve Nodes
+ within nested list structures. Not meant to be called directly,
+ used primarily by cypher_query.
+
+ :param result_list: A list of results as returned by cypher_query.
+ :type list:
+
+ :return: A list of instantiated objects.
+ """
+
+ # Object resolution occurs in-place
+ for a_result_item in enumerate(result_list):
+ for a_result_attribute in enumerate(a_result_item[1]):
+ try:
+ # Primitive types should remain primitive types,
+ # Nodes to be resolved to native objects
+ resolved_object = a_result_attribute[1]
+
+ resolved_object = self._object_resolution(resolved_object)
+
+ result_list[a_result_item[0]][
+ a_result_attribute[0]
+ ] = resolved_object
+
+ except KeyError as exc:
+ # Not being able to match the label set of a node with a known object results
+ # in a KeyError in the internal dictionary used for resolution. If it is impossible
+ # to match, then raise an exception with more details about the error.
+ if isinstance(a_result_attribute[1], Node):
+ raise NodeClassNotDefined(
+ a_result_attribute[1], self._NODE_CLASS_REGISTRY
+ ) from exc
+
+ if isinstance(a_result_attribute[1], Relationship):
+ raise RelationshipClassNotDefined(
+ a_result_attribute[1], self._NODE_CLASS_REGISTRY
+ ) from exc
+
+ return result_list
+
    @ensure_connection
    async def cypher_query(
        self,
        query,
        params=None,
        handle_unique=True,
        retry_on_session_expire=False,
        resolve_objects=False,
    ):
        """
        Runs a query on the database and returns a list of results and their headers.

        :param query: A CYPHER query
        :type: str
        :param params: Dictionary of parameters
        :type: dict
        :param handle_unique: Whether or not to raise UniqueProperty exception on Cypher's ConstraintValidation errors
        :type: bool
        :param retry_on_session_expire: Whether or not to attempt the same query again if the transaction has expired.
        If you use neomodel with your own driver, you must catch SessionExpired exceptions yourself and retry with a new driver instance.
        :type: bool
        :param resolve_objects: Whether to attempt to resolve the returned nodes to data model objects automatically
        :type: bool

        :return: A tuple containing a list of results and a tuple of headers.
        """

        if self._active_transaction:
            # Use the current session if a transaction is currently active
            results, meta = await self._run_cypher_query(
                self._active_transaction,
                query,
                params,
                handle_unique,
                retry_on_session_expire,
                resolve_objects,
            )
        else:
            # Otherwise create a new session in a with block to dispose of it after it has been run
            async with self.driver.session(
                database=self._database_name, impersonated_user=self.impersonated_user
            ) as session:
                results, meta = await self._run_cypher_query(
                    session,
                    query,
                    params,
                    handle_unique,
                    retry_on_session_expire,
                    resolve_objects,
                )

        return results, meta
+
+ async def _run_cypher_query(
+ self,
+ session: AsyncSession,
+ query,
+ params,
+ handle_unique,
+ retry_on_session_expire,
+ resolve_objects,
+ ):
+ try:
+ # Retrieve the data
+ start = time.time()
+ response: AsyncResult = await session.run(query, params)
+ results, meta = [list(r.values()) async for r in response], response.keys()
+ end = time.time()
+
+ if resolve_objects:
+ # Do any automatic resolution required
+ results = self._result_resolution(results)
+
+ except ClientError as e:
+ if e.code == "Neo.ClientError.Schema.ConstraintValidationFailed":
+ if "already exists with label" in e.message and handle_unique:
+ raise UniqueProperty(e.message) from e
+
+ raise ConstraintValidationFailed(e.message) from e
+ exc_info = sys.exc_info()
+ raise exc_info[1].with_traceback(exc_info[2])
+ except SessionExpired:
+ if retry_on_session_expire:
+ await self.set_connection(url=self.url)
+ return await self.cypher_query(
+ query=query,
+ params=params,
+ handle_unique=handle_unique,
+ retry_on_session_expire=False,
+ )
+ raise
+
+ tte = end - start
+ if os.environ.get("NEOMODEL_CYPHER_DEBUG", False) and tte > float(
+ os.environ.get("NEOMODEL_SLOW_QUERIES", 0)
+ ):
+ logger.debug(
+ "query: "
+ + query
+ + "\nparams: "
+ + repr(params)
+ + f"\ntook: {tte:.2g}s\n"
+ )
+
+ return results, meta
+
+ async def get_id_method(self) -> str:
+ db_version = await self.database_version
+ if db_version.startswith("4"):
+ return "id"
+ else:
+ return "elementId"
+
+ async def parse_element_id(self, element_id: str):
+ db_version = await self.database_version
+ return int(element_id) if db_version.startswith("4") else element_id
+
+ async def list_indexes(self, exclude_token_lookup=False) -> Sequence[dict]:
+ """Returns all indexes existing in the database
+
+ Arguments:
+ exclude_token_lookup[bool]: Exclude automatically create token lookup indexes
+
+ Returns:
+ Sequence[dict]: List of dictionaries, each entry being an index definition
+ """
+ indexes, meta_indexes = await self.cypher_query("SHOW INDEXES")
+ indexes_as_dict = [dict(zip(meta_indexes, row)) for row in indexes]
+
+ if exclude_token_lookup:
+ indexes_as_dict = [
+ obj for obj in indexes_as_dict if obj["type"] != "LOOKUP"
+ ]
+
+ return indexes_as_dict
+
+ async def list_constraints(self) -> Sequence[dict]:
+ """Returns all constraints existing in the database
+
+ Returns:
+ Sequence[dict]: List of dictionaries, each entry being a constraint definition
+ """
+ constraints, meta_constraints = await self.cypher_query("SHOW CONSTRAINTS")
+ constraints_as_dict = [dict(zip(meta_constraints, row)) for row in constraints]
+
+ return constraints_as_dict
+
+ @ensure_connection
+ async def version_is_higher_than(self, version_tag: str) -> bool:
+ """Returns true if the database version is higher or equal to a given tag
+
+ Args:
+ version_tag (str): The version to compare against
+
+ Returns:
+ bool: True if the database version is higher or equal to the given version
+ """
+ db_version = await self.database_version
+ return version_tag_to_integer(db_version) >= version_tag_to_integer(version_tag)
+
+ @ensure_connection
+ async def edition_is_enterprise(self) -> bool:
+ """Returns true if the database edition is enterprise
+
+ Returns:
+ bool: True if the database edition is enterprise
+ """
+ edition = await self.database_edition
+ return edition == "enterprise"
+
+ async def change_neo4j_password(self, user, new_password):
+ await self.cypher_query(f"ALTER USER {user} SET PASSWORD '{new_password}'")
+
+ async def clear_neo4j_database(self, clear_constraints=False, clear_indexes=False):
+ await self.cypher_query(
+ """
+ MATCH (a)
+ CALL { WITH a DETACH DELETE a }
+ IN TRANSACTIONS OF 5000 rows
+ """
+ )
+ if clear_constraints:
+ await drop_constraints()
+ if clear_indexes:
+ await drop_indexes()
+
+ async def drop_constraints(self, quiet=True, stdout=None):
+ """
+ Discover and drop all constraints.
+
+ :type: bool
+ :return: None
+ """
+ if not stdout or stdout is None:
+ stdout = sys.stdout
+
+ results, meta = await self.cypher_query("SHOW CONSTRAINTS")
+
+ results_as_dict = [dict(zip(meta, row)) for row in results]
+ for constraint in results_as_dict:
+ await self.cypher_query("DROP CONSTRAINT " + constraint["name"])
+ if not quiet:
+ stdout.write(
+ (
+ " - Dropping unique constraint and index"
+ f" on label {constraint['labelsOrTypes'][0]}"
+ f" with property {constraint['properties'][0]}.\n"
+ )
+ )
+ if not quiet:
+ stdout.write("\n")
+
+ async def drop_indexes(self, quiet=True, stdout=None):
+ """
+ Discover and drop all indexes, except the automatically created token lookup indexes.
+
+ :type: bool
+ :return: None
+ """
+ if not stdout or stdout is None:
+ stdout = sys.stdout
+
+ indexes = await self.list_indexes(exclude_token_lookup=True)
+ for index in indexes:
+ await self.cypher_query("DROP INDEX " + index["name"])
+ if not quiet:
+ stdout.write(
+ f' - Dropping index on labels {",".join(index["labelsOrTypes"])} with properties {",".join(index["properties"])}.\n'
+ )
+ if not quiet:
+ stdout.write("\n")
+
+ async def remove_all_labels(self, stdout=None):
+ """
+ Calls functions for dropping constraints and indexes.
+
+ :param stdout: output stream
+ :return: None
+ """
+
+ if not stdout:
+ stdout = sys.stdout
+
+ stdout.write("Dropping constraints...\n")
+ await self.drop_constraints(quiet=False, stdout=stdout)
+
+ stdout.write("Dropping indexes...\n")
+ await self.drop_indexes(quiet=False, stdout=stdout)
+
+ async def install_all_labels(self, stdout=None):
+ """
+ Discover all subclasses of StructuredNode in your application and execute install_labels on each.
+ Note: code must be loaded (imported) in order for a class to be discovered.
+
+ :param stdout: output stream
+ :return: None
+ """
+
+ if not stdout or stdout is None:
+ stdout = sys.stdout
+
+ def subsub(cls): # recursively return all subclasses
+ subclasses = cls.__subclasses__()
+ if not subclasses: # base case: no more subclasses
+ return []
+ return subclasses + [g for s in cls.__subclasses__() for g in subsub(s)]
+
+ stdout.write("Setting up indexes and constraints...\n\n")
+
+ i = 0
+ for cls in subsub(AsyncStructuredNode):
+ stdout.write(f"Found {cls.__module__}.{cls.__name__}\n")
+ await install_labels(cls, quiet=False, stdout=stdout)
+ i += 1
+
+ if i:
+ stdout.write("\n")
+
+ stdout.write(f"Finished {i} classes.\n")
+
+ async def install_labels(self, cls, quiet=True, stdout=None):
+ """
+ Setup labels with indexes and constraints for a given class
+
+ :param cls: StructuredNode class
+ :type: class
+ :param quiet: (default true) enable standard output
+ :param stdout: stdout stream
+ :type: bool
+ :return: None
+ """
+ if not stdout or stdout is None:
+ stdout = sys.stdout
+
+ if not hasattr(cls, "__label__"):
+ if not quiet:
+ stdout.write(
+ f" ! Skipping class {cls.__module__}.{cls.__name__} is abstract\n"
+ )
+ return
+
+ for name, property in cls.defined_properties(aliases=False, rels=False).items():
+ await self._install_node(cls, name, property, quiet, stdout)
+
+ for _, relationship in cls.defined_properties(
+ aliases=False, rels=True, properties=False
+ ).items():
+ await self._install_relationship(cls, relationship, quiet, stdout)
+
+ async def _create_node_index(self, label: str, property_name: str, stdout):
+ try:
+ await self.cypher_query(
+ f"CREATE INDEX index_{label}_{property_name} FOR (n:{label}) ON (n.{property_name}); "
+ )
+ except ClientError as e:
+ if e.code in (
+ RULE_ALREADY_EXISTS,
+ INDEX_ALREADY_EXISTS,
+ ):
+ stdout.write(f"{str(e)}\n")
+ else:
+ raise
+
+ async def _create_node_constraint(self, label: str, property_name: str, stdout):
+ try:
+ await self.cypher_query(
+ f"""CREATE CONSTRAINT constraint_unique_{label}_{property_name}
+ FOR (n:{label}) REQUIRE n.{property_name} IS UNIQUE"""
+ )
+ except ClientError as e:
+ if e.code in (
+ RULE_ALREADY_EXISTS,
+ CONSTRAINT_ALREADY_EXISTS,
+ ):
+ stdout.write(f"{str(e)}\n")
+ else:
+ raise
+
+ async def _create_relationship_index(
+ self, relationship_type: str, property_name: str, stdout
+ ):
+ try:
+ await self.cypher_query(
+ f"CREATE INDEX index_{relationship_type}_{property_name} FOR ()-[r:{relationship_type}]-() ON (r.{property_name}); "
+ )
+ except ClientError as e:
+ if e.code in (
+ RULE_ALREADY_EXISTS,
+ INDEX_ALREADY_EXISTS,
+ ):
+ stdout.write(f"{str(e)}\n")
+ else:
+ raise
+
+ async def _create_relationship_constraint(
+ self, relationship_type: str, property_name: str, stdout
+ ):
+ if await self.version_is_higher_than("5.7"):
+ try:
+ await self.cypher_query(
+ f"""CREATE CONSTRAINT constraint_unique_{relationship_type}_{property_name}
+ FOR ()-[r:{relationship_type}]-() REQUIRE r.{property_name} IS UNIQUE"""
+ )
+ except ClientError as e:
+ if e.code in (
+ RULE_ALREADY_EXISTS,
+ CONSTRAINT_ALREADY_EXISTS,
+ ):
+ stdout.write(f"{str(e)}\n")
+ else:
+ raise
+ else:
+ raise FeatureNotSupported(
+ f"Unique indexes on relationships are not supported in Neo4j version {await self.database_version}. Please upgrade to Neo4j 5.7 or higher."
+ )
+
+ async def _install_node(self, cls, name, property, quiet, stdout):
+ # Create indexes and constraints for node property
+ db_property = property.db_property or name
+ if property.index:
+ if not quiet:
+ stdout.write(
+ f" + Creating node index {name} on label {cls.__label__} for class {cls.__module__}.{cls.__name__}\n"
+ )
+ await self._create_node_index(
+ label=cls.__label__, property_name=db_property, stdout=stdout
+ )
+
+ elif property.unique_index:
+ if not quiet:
+ stdout.write(
+ f" + Creating node unique constraint for {name} on label {cls.__label__} for class {cls.__module__}.{cls.__name__}\n"
+ )
+ await self._create_node_constraint(
+ label=cls.__label__, property_name=db_property, stdout=stdout
+ )
+
+ async def _install_relationship(self, cls, relationship, quiet, stdout):
+ # Create indexes and constraints for relationship property
+ relationship_cls = relationship.definition["model"]
+ if relationship_cls is not None:
+ relationship_type = relationship.definition["relation_type"]
+ for prop_name, property in relationship_cls.defined_properties(
+ aliases=False, rels=False
+ ).items():
+ db_property = property.db_property or prop_name
+ if property.index:
+ if not quiet:
+ stdout.write(
+ f" + Creating relationship index {prop_name} on relationship type {relationship_type} for relationship model {cls.__module__}.{relationship_cls.__name__}\n"
+ )
+ await self._create_relationship_index(
+ relationship_type=relationship_type,
+ property_name=db_property,
+ stdout=stdout,
+ )
+ elif property.unique_index:
+ if not quiet:
+ stdout.write(
+ f" + Creating relationship unique constraint for {prop_name} on relationship type {relationship_type} for relationship model {cls.__module__}.{relationship_cls.__name__}\n"
+ )
+ await self._create_relationship_constraint(
+ relationship_type=relationship_type,
+ property_name=db_property,
+ stdout=stdout,
+ )
+
+
+# Create a singleton instance of the database object
+adb = AsyncDatabase()
+
+
+# Deprecated methods
+async def change_neo4j_password(db: AsyncDatabase, user, new_password):
+ deprecated(
+ """
+ This method has been moved to the Database singleton (db for sync, adb for async).
+ Please use adb.change_neo4j_password(user, new_password) instead.
+ This direct call will be removed in an upcoming version.
+ """
+ )
+ await db.change_neo4j_password(user, new_password)
+
+
+async def clear_neo4j_database(
+ db: AsyncDatabase, clear_constraints=False, clear_indexes=False
+):
+ deprecated(
+ """
+ This method has been moved to the Database singleton (db for sync, adb for async).
+ Please use adb.clear_neo4j_database(clear_constraints, clear_indexes) instead.
+ This direct call will be removed in an upcoming version.
+ """
+ )
+ await db.clear_neo4j_database(clear_constraints, clear_indexes)
+
+
+async def drop_constraints(quiet=True, stdout=None):
+ deprecated(
+ """
+ This method has been moved to the Database singleton (db for sync, adb for async).
+ Please use adb.drop_constraints(quiet, stdout) instead.
+ This direct call will be removed in an upcoming version.
+ """
+ )
+ await adb.drop_constraints(quiet, stdout)
+
+
+async def drop_indexes(quiet=True, stdout=None):
+ deprecated(
+ """
+ This method has been moved to the Database singleton (db for sync, adb for async).
+ Please use adb.drop_indexes(quiet, stdout) instead.
+ This direct call will be removed in an upcoming version.
+ """
+ )
+ await adb.drop_indexes(quiet, stdout)
+
+
+async def remove_all_labels(stdout=None):
+ deprecated(
+ """
+ This method has been moved to the Database singleton (db for sync, adb for async).
+ Please use adb.remove_all_labels(stdout) instead.
+ This direct call will be removed in an upcoming version.
+ """
+ )
+ await adb.remove_all_labels(stdout)
+
+
+async def install_labels(cls, quiet=True, stdout=None):
+ deprecated(
+ """
+ This method has been moved to the Database singleton (db for sync, adb for async).
+ Please use adb.install_labels(cls, quiet, stdout) instead.
+ This direct call will be removed in an upcoming version.
+ """
+ )
+ await adb.install_labels(cls, quiet, stdout)
+
+
+async def install_all_labels(stdout=None):
+ deprecated(
+ """
+ This method has been moved to the Database singleton (db for sync, adb for async).
+ Please use adb.install_all_labels(stdout) instead.
+ This direct call will be removed in an upcoming version.
+ """
+ )
+ await adb.install_all_labels(stdout)
+
+
class AsyncTransactionProxy:
    """
    Async context manager / decorator that wraps work in an explicit
    database transaction: committed on clean exit, rolled back on error.

    Defect fixed: leftover debug print() calls removed from
    __aenter__/__aexit__/__call__.
    """

    # Bookmarks to pass to the next begin() for causal chaining;
    # consumed (reset to None) once the transaction starts.
    bookmarks: Optional[Bookmarks] = None

    def __init__(self, db: "AsyncDatabase", access_mode=None):
        self.db = db
        self.access_mode = access_mode

    @ensure_connection
    async def __aenter__(self):
        await self.db.begin(access_mode=self.access_mode, bookmarks=self.bookmarks)
        self.bookmarks = None
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        if exc_value:
            await self.db.rollback()

        # Translate constraint violations into neomodel's UniqueProperty.
        if (
            exc_type is ClientError
            and exc_value.code == "Neo.ClientError.Schema.ConstraintValidationFailed"
        ):
            raise UniqueProperty(exc_value.message)

        if not exc_value:
            self.last_bookmark = await self.db.commit()

    def __call__(self, func):
        """Use the proxy as a decorator on a coroutine function."""
        if AsyncUtil.is_async_code and not iscoroutinefunction(func):
            raise TypeError(NOT_COROUTINE_ERROR)

        async def wrapper(*args, **kwargs):
            async with self:
                return await func(*args, **kwargs)

        return wrapper

    @property
    def with_bookmark(self):
        """Return a bookmark-aware variant of this proxy."""
        return BookmarkingAsyncTransactionProxy(self.db, self.access_mode)
+
+
class BookmarkingAsyncTransactionProxy(AsyncTransactionProxy):
    """Transaction proxy that threads Neo4j bookmarks through calls.

    The decorated coroutine may be called with a ``bookmarks`` kwarg; it is
    consumed here and the wrapped call returns ``(result, last_bookmark)``.
    """

    def __call__(self, func):
        if AsyncUtil.is_async_code and not iscoroutinefunction(func):
            raise TypeError(NOT_COROUTINE_ERROR)

        async def wrapper(*args, **kwargs):
            # Pop the kwarg so the wrapped function never sees it.
            self.bookmarks = kwargs.pop("bookmarks", None)

            async with self:
                result = await func(*args, **kwargs)
                # Reset before __aexit__ assigns the bookmark of this commit.
                self.last_bookmark = None

            return result, self.last_bookmark

        return wrapper
+
+
class ImpersonationHandler:
    """
    Context manager / decorator that runs all queries on the wrapped
    database as the impersonated user, restoring normal identity on exit.

    Defect fixed: leftover debug print() calls in __exit__ (they dumped
    exception info to stdout on every exit) removed.
    """

    def __init__(self, db: "AsyncDatabase", impersonated_user: str):
        self.db = db
        self.impersonated_user = impersonated_user

    def __enter__(self):
        self.db.impersonated_user = self.impersonated_user
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback):
        # Always restore normal identity. Returning None lets any exception
        # raised inside the block propagate to the caller.
        self.db.impersonated_user = None

    def __call__(self, func):
        """Use the handler as a decorator on a plain (sync) function."""

        def wrapper(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return wrapper
+
+
class NodeMeta(type):
    """Metaclass that wires neomodel bookkeeping into node class creation:
    per-class DoesNotExist exception, reserved-name checks, property caches,
    labels, and registration in the global class registry."""

    def __new__(mcs, name, bases, namespace):
        # Give every model its own DoesNotExist subclass so callers can
        # catch `MyNode.DoesNotExist` specifically.
        namespace["DoesNotExist"] = type(name + "DoesNotExist", (DoesNotExist,), {})
        cls = super().__new__(mcs, name, bases, namespace)
        cls.DoesNotExist._model_class = cls

        if hasattr(cls, "__abstract_node__"):
            # Abstract classes are not validated or registered.
            delattr(cls, "__abstract_node__")
        else:
            # Reject property names that clash with neomodel / Neo4j internals.
            if "deleted" in namespace:
                raise ValueError(
                    "Property name 'deleted' is not allowed as it conflicts with neomodel internals."
                )
            elif "id" in namespace:
                raise ValueError(
                    """
                        Property name 'id' is not allowed as it conflicts with neomodel internals.
                        Consider using 'uid' or 'identifier' as id is also a Neo4j internal.
                    """
                )
            elif "element_id" in namespace:
                raise ValueError(
                    """
                        Property name 'element_id' is not allowed as it conflicts with neomodel internals.
                        Consider using 'uid' or 'identifier' as element_id is also a Neo4j internal.
                    """
                )
            # Bind each declared Property to its name/owner and let it
            # finish its own setup.
            for key, value in (
                (x, y) for x, y in namespace.items() if isinstance(y, Property)
            ):
                value.name, value.owner = key, cls
                if hasattr(value, "setup") and callable(value.setup):
                    value.setup()

            # cache various groups of properies
            cls.__required_properties__ = tuple(
                name
                for name, property in cls.defined_properties(
                    aliases=False, rels=False
                ).items()
                if property.required or property.unique_index
            )
            cls.__all_properties__ = tuple(
                cls.defined_properties(aliases=False, rels=False).items()
            )
            cls.__all_aliases__ = tuple(
                cls.defined_properties(properties=False, rels=False).items()
            )
            cls.__all_relationships__ = tuple(
                cls.defined_properties(aliases=False, properties=False).items()
            )

            cls.__label__ = namespace.get("__label__", name)
            cls.__optional_labels__ = namespace.get("__optional_labels__", [])

            build_class_registry(cls)

        return cls
+
+
def build_class_registry(cls):
    """Register every label combination of *cls* in adb's class registry.

    The registry is used later to resolve raw Neo4j nodes back to their
    model class (see AsyncDatabase._object_resolution).

    :raises NodeClassAlreadyDefined: if any label combination is already
        claimed by another class.
    """
    base_label_set = frozenset(cls.inherited_labels())
    optional_label_set = set(cls.inherited_optional_labels())

    # Construct all possible combinations of labels + optional labels
    possible_label_combinations = [
        frozenset(set(x).union(base_label_set))
        for i in range(1, len(optional_label_set) + 1)
        for x in combinations(optional_label_set, i)
    ]
    possible_label_combinations.append(base_label_set)

    for label_set in possible_label_combinations:
        if label_set not in adb._NODE_CLASS_REGISTRY:
            adb._NODE_CLASS_REGISTRY[label_set] = cls
        else:
            raise NodeClassAlreadyDefined(cls, adb._NODE_CLASS_REGISTRY)
+
+
# Root of the node hierarchy: an abstract AsyncPropertyManager created through
# NodeMeta, so every subclass gets label/property bookkeeping at class creation.
NodeBase = NodeMeta("NodeBase", (AsyncPropertyManager,), {"__abstract_node__": True})
+
+
+class AsyncStructuredNode(NodeBase):
+ """
+ Base class for all node definitions to inherit from.
+
+ If you want to create your own abstract classes set:
+ __abstract_node__ = True
+ """
+
+ # static properties
+
+ __abstract_node__ = True
+
+ # magic methods
+
+ def __init__(self, *args, **kwargs):
+ if "deleted" in kwargs:
+ raise ValueError("deleted property is reserved for neomodel")
+
+ for key, val in self.__all_relationships__:
+ self.__dict__[key] = val.build_manager(self, key)
+
+ super().__init__(*args, **kwargs)
+
+ def __eq__(self, other):
+ if not isinstance(other, (AsyncStructuredNode,)):
+ return False
+ if hasattr(self, "element_id") and hasattr(other, "element_id"):
+ return self.element_id == other.element_id
+ return False
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __repr__(self):
+ return f"<{self.__class__.__name__}: {self}>"
+
+ def __str__(self):
+ return repr(self.__properties__)
+
+ # dynamic properties
+
    @classproperty
    def nodes(cls):
        """
        Returns a NodeSet object representing all nodes of the classes label
        :return: NodeSet
        :rtype: NodeSet
        """
        # Imported lazily to avoid a circular import with neomodel.async_.match.
        from neomodel.async_.match import AsyncNodeSet

        return AsyncNodeSet(cls)
+
+ @property
+ def element_id(self):
+ if hasattr(self, "element_id_property"):
+ return self.element_id_property
+ return None
+
+ # Version 4.4 support - id is deprecated in version 5.x
+ @property
+ def id(self):
+ try:
+ return int(self.element_id_property)
+ except (TypeError, ValueError):
+ raise ValueError(
+ "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()."
+ )
+
+ # methods
+
    @classmethod
    async def _build_merge_query(
        cls, merge_params, update_existing=False, lazy=False, relationship=None
    ):
        """
        Get a tuple of a CYPHER query and a params dict for the specified MERGE query.

        :param merge_params: The target node match parameters, each node must have a "create" key and optional "update".
        :type merge_params: list of dict
        :param update_existing: True to update properties of existing nodes, default False to keep existing values.
        :type update_existing: bool
        :param lazy: True to return only the node id instead of the full node.
        :param relationship: optional relationship to MERGE through, from its
            saved source node to the target nodes.
        :rtype: tuple
        """
        query_params = dict(merge_params=merge_params)
        n_merge_labels = ":".join(cls.inherited_labels())
        # MERGE matches on the required/unique-indexed properties only.
        n_merge_prm = ", ".join(
            (
                f"{getattr(cls, p).db_property or p}: params.create.{getattr(cls, p).db_property or p}"
                for p in cls.__required_properties__
            )
        )
        n_merge = f"n:{n_merge_labels} {{{n_merge_prm}}}"
        if relationship is None:
            # create "simple" unwind query
            query = f"UNWIND $merge_params as params\n MERGE ({n_merge})\n "
        else:
            # validate relationship
            if not isinstance(relationship.source, AsyncStructuredNode):
                raise ValueError(
                    f"relationship source [{repr(relationship.source)}] is not a StructuredNode"
                )
            relation_type = relationship.definition.get("relation_type")
            if not relation_type:
                raise ValueError(
                    "No relation_type is specified on provided relationship"
                )

            # Imported lazily to avoid a circular import with the match module.
            from neomodel.async_.match import _rel_helper

            query_params["source_id"] = await adb.parse_element_id(
                relationship.source.element_id
            )
            query = f"MATCH (source:{relationship.source.__label__}) WHERE {await adb.get_id_method()}(source) = $source_id\n "
            query += "WITH source\n UNWIND $merge_params as params \n "
            query += "MERGE "
            query += _rel_helper(
                lhs="source",
                rhs=n_merge,
                ident=None,
                relation_type=relation_type,
                direction=relationship.definition["direction"],
            )

        query += "ON CREATE SET n = params.create\n "
        # if update_existing, write properties on match as well
        if update_existing is True:
            query += "ON MATCH SET n += params.update\n"

        # close query
        if lazy:
            query += f"RETURN {await adb.get_id_method()}(n)"
        else:
            query += "RETURN n"

        return query, query_params
+
    @classmethod
    async def create(cls, *props, **kwargs):
        """
        Call to CREATE with parameters map. A new instance will be created and saved.

        :param props: dict of properties to create the nodes.
        :type props: tuple
        :param lazy: False by default, specify True to get nodes with id only without the parameters.
        :type: bool
        :rtype: list
        """

        if "streaming" in kwargs:
            warnings.warn(
                STREAMING_WARNING,
                category=DeprecationWarning,
                stacklevel=1,
            )

        lazy = kwargs.get("lazy", False)
        # create mapped query
        query = f"CREATE (n:{':'.join(cls.inherited_labels())} $create_params)"

        # close query
        if lazy:
            query += f" RETURN {await adb.get_id_method()}(n)"
        else:
            query += " RETURN n"

        results = []
        # One CREATE round-trip per node; properties are deflated first.
        for item in [
            cls.deflate(p, obj=_UnsavedNode(), skip_empty=True) for p in props
        ]:
            node, _ = await adb.cypher_query(query, {"create_params": item})
            results.extend(node[0])

        nodes = [cls.inflate(node) for node in results]

        # The post_create hook only makes sense on fully inflated nodes.
        if not lazy and hasattr(cls, "post_create"):
            for node in nodes:
                node.post_create()

        return nodes
+
    @classmethod
    async def create_or_update(cls, *props, **kwargs):
        """
        Call to MERGE with parameters map. A new instance will be created and saved if does not already exists,
        this is an atomic operation. If an instance already exists all optional properties specified will be updated.

        Note that the post_create hook isn't called after create_or_update

        :param props: List of dict arguments to get or create the entities with.
        :type props: tuple
        :param relationship: Optional, relationship to get/create on when new entity is created.
        :param lazy: False by default, specify True to get nodes with id only without the parameters.
        :rtype: list
        """
        lazy = kwargs.get("lazy", False)
        relationship = kwargs.get("relationship")

        # build merge query, make sure to update only explicitly specified properties
        create_or_update_params = []
        for specified, deflated in [
            (p, cls.deflate(p, skip_empty=True)) for p in props
        ]:
            create_or_update_params.append(
                {
                    "create": deflated,
                    # Only properties the caller actually passed are updated
                    # on MATCH; defaults generated by deflate are create-only.
                    "update": dict(
                        (k, v) for k, v in deflated.items() if k in specified
                    ),
                }
            )
        query, params = await cls._build_merge_query(
            create_or_update_params,
            update_existing=True,
            relationship=relationship,
            lazy=lazy,
        )

        if "streaming" in kwargs:
            warnings.warn(
                STREAMING_WARNING,
                category=DeprecationWarning,
                stacklevel=1,
            )

        # fetch and build instance for each result
        results = await adb.cypher_query(query, params)
        return [cls.inflate(r[0]) for r in results[0]]
+
    async def cypher(self, query, params=None):
        """
        Execute a cypher query with the param 'self' pre-populated with the nodes neo4j id.

        :param query: cypher query string
        :type: string
        :param params: query parameters
        :type: dict
        :return: list containing query results
        :rtype: list
        """
        self._pre_action_check("cypher")
        params = params or {}
        # $self is always bound to this node's (parsed) element id.
        element_id = await adb.parse_element_id(self.element_id)
        params.update({"self": element_id})
        return await adb.cypher_query(query, params)
+
    @hooks
    async def delete(self):
        """
        Delete a node and its relationships

        :return: True
        """
        self._pre_action_check("delete")
        await self.cypher(
            f"MATCH (self) WHERE {await adb.get_id_method()}(self)=$self DETACH DELETE self"
        )
        # Drop the stored id and flag the instance so later actions fail fast
        # in _pre_action_check.
        delattr(self, "element_id_property")
        self.deleted = True
        return True
+
    @classmethod
    async def get_or_create(cls, *props, **kwargs):
        """
        Call to MERGE with parameters map. A new instance will be created and saved if does not already exist,
        this is an atomic operation.
        Parameters must contain all required properties, any non required properties with defaults will be generated.

        Note that the post_create hook isn't called after get_or_create

        :param props: Arguments to get_or_create as tuple of dict with property names and values to get or create
                      the entities with.
        :type props: tuple
        :param relationship: Optional, relationship to get/create on when new entity is created.
        :param lazy: False by default, specify True to get nodes with id only without the parameters.
        :rtype: list
        """
        lazy = kwargs.get("lazy", False)
        relationship = kwargs.get("relationship")

        # build merge query (no "update" key: existing nodes are left as-is)
        get_or_create_params = [
            {"create": cls.deflate(p, skip_empty=True)} for p in props
        ]
        query, params = await cls._build_merge_query(
            get_or_create_params, relationship=relationship, lazy=lazy
        )

        if "streaming" in kwargs:
            warnings.warn(
                STREAMING_WARNING,
                category=DeprecationWarning,
                stacklevel=1,
            )

        # fetch and build instance for each result
        results = await adb.cypher_query(query, params)
        return [cls.inflate(r[0]) for r in results[0]]
+
+ @classmethod
+ def inflate(cls, node):
+ """
+ Inflate a raw neo4j_driver node to a neomodel node
+ :param node:
+ :return: node object
+ """
+ # support lazy loading
+ if isinstance(node, str) or isinstance(node, int):
+ snode = cls()
+ snode.element_id_property = node
+ else:
+ node_properties = _get_node_properties(node)
+ props = {}
+ for key, prop in cls.__all_properties__:
+ # map property name from database to object property
+ db_property = prop.db_property or key
+
+ if db_property in node_properties:
+ props[key] = prop.inflate(node_properties[db_property], node)
+ elif prop.has_default:
+ props[key] = prop.default_value()
+ else:
+ props[key] = None
+
+ snode = cls(**props)
+ snode.element_id_property = node.element_id
+
+ return snode
+
+ @classmethod
+ def inherited_labels(cls):
+ """
+ Return list of labels from nodes class hierarchy.
+
+ :return: list
+ """
+ return [
+ scls.__label__
+ for scls in cls.mro()
+ if hasattr(scls, "__label__") and not hasattr(scls, "__abstract_node__")
+ ]
+
+ @classmethod
+ def inherited_optional_labels(cls):
+ """
+ Return list of optional labels from nodes class hierarchy.
+
+ :return: list
+ :rtype: list
+ """
+ return [
+ label
+ for scls in cls.mro()
+ for label in getattr(scls, "__optional_labels__", [])
+ if not hasattr(scls, "__abstract_node__")
+ ]
+
    async def labels(self):
        """
        Returns list of labels tied to the node from neo4j.

        :return: list of labels
        :rtype: list
        """
        self._pre_action_check("labels")
        result = await self.cypher(
            f"MATCH (n) WHERE {await adb.get_id_method()}(n)=$self " "RETURN labels(n)"
        )
        # cypher() returns (results, meta); first row / first column holds
        # the labels list.
        return result[0][0][0]
+
+ def _pre_action_check(self, action):
+ if hasattr(self, "deleted") and self.deleted:
+ raise ValueError(
+ f"{self.__class__.__name__}.{action}() attempted on deleted node"
+ )
+ if not hasattr(self, "element_id"):
+ raise ValueError(
+ f"{self.__class__.__name__}.{action}() attempted on unsaved node"
+ )
+
    async def refresh(self):
        """
        Reload the node from neo4j
        """
        self._pre_action_check("refresh")
        # NOTE(review): 'element_id' is a class-level property, so this
        # hasattr is always True; the unsaved case is expected to be caught
        # by _pre_action_check above — confirm.
        if hasattr(self, "element_id"):
            results = await self.cypher(
                f"MATCH (n) WHERE {await adb.get_id_method()}(n)=$self RETURN n"
            )
            request = results[0]
            if not request or not request[0]:
                raise self.__class__.DoesNotExist("Can't refresh non existent node")
            node = self.inflate(request[0][0])
            # Copy the fresh property values onto this instance in place.
            for key, val in node.__properties__.items():
                setattr(self, key, val)
        else:
            raise ValueError("Can't refresh unsaved node")
+
    @hooks
    async def save(self):
        """
        Save the node to neo4j or raise an exception

        :return: the node instance
        """

        # create or update instance node
        if hasattr(self, "element_id_property"):
            # update
            params = self.deflate(self.__properties__, self)
            query = f"MATCH (n) WHERE {await adb.get_id_method()}(n)=$self\n"

            if params:
                query += "SET "
                query += ",\n".join([f"n.{key} = ${key}" for key in params])
                query += "\n"
            # Re-apply all labels from the class hierarchy on every save.
            if self.inherited_labels():
                query += "\n".join(
                    [f"SET n:`{label}`" for label in self.inherited_labels()]
                )
            await self.cypher(query, params)
        elif hasattr(self, "deleted") and self.deleted:
            raise ValueError(
                f"{self.__class__.__name__}.save() attempted on deleted node"
            )
        else:  # create
            result = await self.create(self.__properties__)
            created_node = result[0]
            # Adopt the element id of the node just created in the database.
            self.element_id_property = created_node.element_id
        return self
diff --git a/neomodel/async_/match.py b/neomodel/async_/match.py
new file mode 100644
index 00000000..e7e82adb
--- /dev/null
+++ b/neomodel/async_/match.py
@@ -0,0 +1,1068 @@
+import inspect
+import re
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Optional
+
+from neomodel.async_.core import AsyncStructuredNode, adb
+from neomodel.exceptions import MultipleNodesReturned
+from neomodel.match_q import Q, QBase
+from neomodel.properties import AliasProperty
+from neomodel.util import INCOMING, OUTGOING
+
+
def _rel_helper(
    lhs,
    rhs,
    ident=None,
    relation_type=None,
    direction=None,
    relation_properties=None,
    **kwargs,  # NOSONAR
):
    """
    Generate a relationship matching string, with specified parameters.

    Examples:
    relation_direction = OUTGOING: (lhs)-[relation_ident:relation_type]->(rhs)
    relation_direction = INCOMING: (lhs)<-[relation_ident:relation_type]-(rhs)
    relation_direction = EITHER: (lhs)-[relation_ident:relation_type]-(rhs)

    :param lhs: The left hand statement.
    :type lhs: str
    :param rhs: The right hand statement.
    :type rhs: str
    :param ident: A specific identity to name the relationship, or None.
    :type ident: str
    :param relation_type: None for all direct rels, * for all of any length, or a name of an explicit rel.
    :type relation_type: str
    :param direction: None or EITHER for all OUTGOING,INCOMING,EITHER. Otherwise OUTGOING or INCOMING.
    :param relation_properties: dictionary of relationship properties to match
    :returns: string
    """
    # Inline property map, e.g. " {since: 2020}".
    if relation_properties:
        pairs = ", ".join(f"{key}: {value}" for key, value in relation_properties.items())
        rel_props = f" {{{pairs}}}"
    else:
        rel_props = ""

    # The bracketed part between the dashes.
    if relation_type is None:
        rel_def = ""
    elif relation_type == "*":
        rel_def = "[*]"
    else:
        rel_def = f"[{ident if ident else ''}:`{relation_type}`{rel_props}]"

    if direction == OUTGOING:
        stmt = f"-{rel_def}->"
    elif direction == INCOMING:
        stmt = f"<-{rel_def}-"
    else:
        stmt = f"-{rel_def}-"

    # Make sure not to add parenthesis when they are already present
    if lhs[-1] != ")":
        lhs = f"({lhs})"
    if rhs[-1] != ")":
        rhs = f"({rhs})"

    return f"{lhs}{stmt}{rhs}"
+
+
def _rel_merge_helper(
    lhs,
    rhs,
    ident="neomodelident",
    relation_type=None,
    direction=None,
    relation_properties=None,
    **kwargs,  # NOSONAR
):
    """
    Generate a relationship merging string, with specified parameters.
    Examples:
    relation_direction = OUTGOING: (lhs)-[relation_ident:relation_type]->(rhs)
    relation_direction = INCOMING: (lhs)<-[relation_ident:relation_type]-(rhs)
    relation_direction = EITHER: (lhs)-[relation_ident:relation_type]-(rhs)

    :param lhs: The left hand statement.
    :type lhs: str
    :param rhs: The right hand statement.
    :type rhs: str
    :param ident: A specific identity to name the relationship, or None.
    :type ident: str
    :param relation_type: None for all direct rels, * for all of any length, or a name of an explicit rel.
    :type relation_type: str
    :param direction: None or EITHER for all OUTGOING,INCOMING,EITHER. Otherwise OUTGOING or INCOMING.
    :param relation_properties: dictionary of relationship properties to merge
    :returns: string
    """

    if direction == OUTGOING:
        stmt = "-{0}->"
    elif direction == INCOMING:
        stmt = "<-{0}-"
    else:
        stmt = "-{0}-"

    rel_props = ""
    rel_none_props = ""

    if relation_properties:
        # Non-None values are merged inline in the relationship pattern.
        rel_props_str = ", ".join(
            (
                f"{key}: {value}"
                for key, value in relation_properties.items()
                if value is not None
            )
        )
        rel_props = f" {{{rel_props_str}}}"
        # None values become "$key" parameters set via ON CREATE/ON MATCH —
        # the caller is expected to supply those query parameters.
        if None in relation_properties.values():
            rel_prop_val_str = ", ".join(
                (
                    f"{ident}.{key}=${key!s}"
                    for key, value in relation_properties.items()
                    if value is None
                )
            )
            rel_none_props = (
                f" ON CREATE SET {rel_prop_val_str} ON MATCH SET {rel_prop_val_str}"
            )
    # relation_type is unspecified
    if relation_type is None:
        stmt = stmt.format("")
    # all("*" wildcard) relation_type
    elif relation_type == "*":
        stmt = stmt.format("[*]")
    else:
        # explicit relation_type
        stmt = stmt.format(f"[{ident}:`{relation_type}`{rel_props}]")

    return f"({lhs}){stmt}({rhs}){rel_none_props}"
+
+
# special operators (cypher fragments used in WHERE clauses)
_SPECIAL_OPERATOR_IN = "IN"
_SPECIAL_OPERATOR_INSENSITIVE = "(?i)"
_SPECIAL_OPERATOR_ISNULL = "IS NULL"
_SPECIAL_OPERATOR_ISNOTNULL = "IS NOT NULL"
_SPECIAL_OPERATOR_REGEX = "=~"

# operators that take no right-hand parameter
_UNARY_OPERATORS = (_SPECIAL_OPERATOR_ISNULL, _SPECIAL_OPERATOR_ISNOTNULL)

# NOTE(review): identifier is misspelled ("INSESITIVE"); renaming would touch
# every use in this module (and possibly beyond this view), so it is kept.
_REGEX_INSESITIVE = _SPECIAL_OPERATOR_INSENSITIVE + "{}"
_REGEX_CONTAINS = ".*{}.*"
_REGEX_STARTSWITH = "{}.*"
_REGEX_ENDSWITH = ".*{}"

# regex operations that require escaping
_STRING_REGEX_OPERATOR_TABLE = {
    "iexact": _REGEX_INSESITIVE,
    "contains": _REGEX_CONTAINS,
    "icontains": _SPECIAL_OPERATOR_INSENSITIVE + _REGEX_CONTAINS,
    "startswith": _REGEX_STARTSWITH,
    "istartswith": _SPECIAL_OPERATOR_INSENSITIVE + _REGEX_STARTSWITH,
    "endswith": _REGEX_ENDSWITH,
    "iendswith": _SPECIAL_OPERATOR_INSENSITIVE + _REGEX_ENDSWITH,
}
# regex operations that do not require escaping
_REGEX_OPERATOR_TABLE = {
    "iregex": _REGEX_INSESITIVE,
}
# list all regex operations, these will require formatting of the value
_REGEX_OPERATOR_TABLE.update(_STRING_REGEX_OPERATOR_TABLE)

# list all supported operators (filter suffix -> cypher operator)
OPERATOR_TABLE = {
    "lt": "<",
    "gt": ">",
    "lte": "<=",
    "gte": ">=",
    "ne": "<>",
    "in": _SPECIAL_OPERATOR_IN,
    "isnull": _SPECIAL_OPERATOR_ISNULL,
    "regex": _SPECIAL_OPERATOR_REGEX,
    "exact": "=",
}
# add all regex operators
OPERATOR_TABLE.update(_REGEX_OPERATOR_TABLE)
+
+
def install_traversals(cls, node_set):
    """
    For a StructuredNode class install Traversal objects for each
    relationship definition on a NodeSet instance

    :param cls: the StructuredNode subclass whose relationships to expose
    :param node_set: the NodeSet instance to attach Traversal objects to
    :raises ValueError: if a relationship name collides with an existing
        NodeSet attribute
    """
    rels = cls.defined_properties(rels=True, aliases=False, properties=False)

    for key in rels.keys():
        if hasattr(node_set, key):
            raise ValueError(f"Cannot install traversal '{key}' exists on NodeSet")

        rel = getattr(cls, key)
        # Resolve the target node class lazily (it may be a string reference).
        rel.lookup_node_class()

        traversal = AsyncTraversal(source=node_set, name=key, definition=rel.definition)
        setattr(node_set, key, traversal)
+
+
def process_filter_args(cls, kwargs):
    """
    Loop through filter parameters, check they match the class definition,
    deflate the values and convert them into something easy to generate
    cypher from.

    :param cls: node class whose properties are being filtered
    :param kwargs: mapping of "prop" or "prop__operator" keys to filter values
    :return: dict mapping db property names to (operator, deflated value)
    :raises ValueError: on unknown properties or unknown operator suffixes
    """

    output = {}

    for key, value in kwargs.items():
        if "__" in key:
            # Bug fix: rsplit without maxsplit raised an unpacking ValueError
            # for keys containing more than one "__"; only the last "__"
            # separates the property from the operator suffix.
            prop, operator_suffix = key.rsplit("__", 1)
            if operator_suffix not in OPERATOR_TABLE:
                # Explicit error instead of a bare KeyError.
                raise ValueError(
                    f"No such operator {operator_suffix} in filter {key}"
                )
            operator = OPERATOR_TABLE[operator_suffix]
        else:
            prop = key
            operator = "="

        if prop not in cls.defined_properties(rels=False):
            raise ValueError(
                f"No such property {prop} on {cls.__name__}. Note that Neo4j internals like id or element_id are not allowed for use in this operation."
            )

        property_obj = getattr(cls, prop)
        if isinstance(property_obj, AliasProperty):
            # Resolve the alias to the property it points at.
            prop = property_obj.aliased_to()
            deflated_value = getattr(cls, prop).deflate(value)
        else:
            operator, deflated_value = transform_operator_to_filter(
                operator=operator,
                filter_key=key,
                filter_value=value,
                property_obj=property_obj,
            )

        # map property to correct property name in the database
        db_property = cls.defined_properties(rels=False)[prop].db_property or prop

        output[db_property] = (operator, deflated_value)

    return output
+
+
def transform_operator_to_filter(operator, filter_key, filter_value, property_obj):
    """
    Normalise an (operator, value) pair for use in a cypher WHERE clause:
    validates the value for special operators, deflates it via the property,
    and rewrites regex-style operators into an explicit regex match.

    :return: tuple of (operator, deflated value)
    """
    # Membership test: value must be a sequence, each element deflated.
    if operator == _SPECIAL_OPERATOR_IN:
        if not isinstance(filter_value, (tuple, list)):
            raise ValueError(
                f"Value must be a tuple or list for IN operation {filter_key}={filter_value}"
            )
        return operator, [property_obj.deflate(v) for v in filter_value]

    # Null check: a boolean flag selects IS NULL vs IS NOT NULL.
    if operator == _SPECIAL_OPERATOR_ISNULL:
        if not isinstance(filter_value, bool):
            raise ValueError(
                f"Value must be a bool for isnull operation on {filter_key}"
            )
        return ("IS NULL" if filter_value else "IS NOT NULL"), None

    # Regex-style operators: deflate, escape if required, wrap in the pattern.
    if operator in _REGEX_OPERATOR_TABLE.values():
        deflated = property_obj.deflate(filter_value)
        if not isinstance(deflated, str):
            raise ValueError(f"Must be a string value for {filter_key}")
        if operator in _STRING_REGEX_OPERATOR_TABLE.values():
            deflated = re.escape(deflated)
        return _SPECIAL_OPERATOR_REGEX, operator.format(deflated)

    # Plain comparison operators: just deflate the value.
    return operator, property_obj.deflate(filter_value)
+
+
def process_has_args(cls, kwargs):
    """
    Loop through has() parameters, check they correspond to relationships
    defined on the class, and split them into required and forbidden matches.

    :return: tuple of (match, dont_match) dicts keyed by relationship name
    """
    rel_definitions = cls.defined_properties(properties=False, rels=True, aliases=False)

    match = {}
    dont_match = {}

    for key, value in kwargs.items():
        if key not in rel_definitions:
            raise ValueError(f"No such relation {key} defined on a {cls.__name__}")

        # Resolve the relationship's target class lazily.
        rel_definitions[key].lookup_node_class()

        if value is True:
            match[key] = rel_definitions[key].definition
        elif value is False:
            dont_match[key] = rel_definitions[key].definition
        elif isinstance(value, AsyncNodeSet):
            raise NotImplementedError("Not implemented yet")
        else:
            raise ValueError("Expecting True / False / NodeSet got: " + repr(value))

    return match, dont_match
+
+
class QueryAST:
    """
    Mutable container for the pieces of a cypher query as it is being built:
    match/where clauses, pagination, ordering and result metadata.
    """

    match: Optional[list]
    optional_match: Optional[list]
    where: Optional[list]
    with_clause: Optional[str]
    return_clause: Optional[str]
    order_by: Optional[str]
    skip: Optional[int]
    limit: Optional[int]
    result_class: Optional[type]
    lookup: Optional[str]
    additional_return: Optional[list]
    is_count: Optional[bool]

    def __init__(
        self,
        match: Optional[list] = None,
        optional_match: Optional[list] = None,
        where: Optional[list] = None,
        with_clause: Optional[str] = None,
        return_clause: Optional[str] = None,
        order_by: Optional[str] = None,
        skip: Optional[int] = None,
        limit: Optional[int] = None,
        result_class: Optional[type] = None,
        lookup: Optional[str] = None,
        additional_return: Optional[list] = None,
        is_count: Optional[bool] = False,
    ):
        # List-valued fields default to fresh empty lists (never shared).
        self.match = match or []
        self.optional_match = optional_match or []
        self.where = where or []
        self.additional_return = additional_return or []
        # Scalar fields are stored as given.
        self.with_clause = with_clause
        self.return_clause = return_clause
        self.order_by = order_by
        self.skip = skip
        self.limit = limit
        self.result_class = result_class
        self.lookup = lookup
        self.is_count = is_count
+
+
+class AsyncQueryBuilder:
    def __init__(self, node_set):
        """Create a builder bound to *node_set*, with empty query state."""
        self.node_set = node_set  # source NodeSet / Traversal / node instance
        self._ast = QueryAST()  # accumulating query AST
        self._query_params = {}  # cypher parameters collected while building
        self._place_holder_registry = {}  # per-name placeholder counters
        self._ident_count = 0  # counter used to mint unique rel identifiers
        self._node_counters = defaultdict(int)  # per-relation node counters
+
    async def build_ast(self):
        """
        Build the query AST from the bound node set: requested relations,
        the source match itself, then pagination.

        :return: self, for chaining
        """
        # Relations requested via fetch_relations() become extra MATCH clauses.
        if hasattr(self.node_set, "relations_to_fetch"):
            for relation in self.node_set.relations_to_fetch:
                self.build_traversal_from_path(relation, self.node_set.source)

        await self.build_source(self.node_set)

        if hasattr(self.node_set, "skip"):
            self._ast.skip = self.node_set.skip
        if hasattr(self.node_set, "limit"):
            self._ast.limit = self.node_set.limit

        return self
+
    async def build_source(self, source):
        """
        Dispatch on the source type (Traversal, NodeSet or node instance)
        and build the matching part of the AST.

        :return: the cypher identifier bound to the source
        :rtype: str
        """
        if isinstance(source, AsyncTraversal):
            return await self.build_traversal(source)
        if isinstance(source, AsyncNodeSet):
            if inspect.isclass(source.source) and issubclass(
                source.source, AsyncStructuredNode
            ):
                ident = self.build_label(source.source.__label__.lower(), source.source)
            else:
                # Nested source (e.g. a Traversal): build it recursively.
                ident = await self.build_source(source.source)

            self.build_additional_match(ident, source)

            if hasattr(source, "order_by_elements"):
                self.build_order_by(ident, source)

            if source.filters or source.q_filters:
                self.build_where_stmt(
                    ident,
                    source.filters,
                    source.q_filters,
                    source_class=source.source_class,
                )

            return ident
        if isinstance(source, AsyncStructuredNode):
            return await self.build_node(source)
        raise ValueError("Unknown source type " + repr(source))
+
+ def create_ident(self):
+ self._ident_count += 1
+ return "r" + str(self._ident_count)
+
+ def build_order_by(self, ident, source):
+ if "?" in source.order_by_elements:
+ self._ast.with_clause = f"{ident}, rand() as r"
+ self._ast.order_by = "r"
+ else:
+ self._ast.order_by = [f"{ident}.{p}" for p in source.order_by_elements]
+
    async def build_traversal(self, traversal):
        """
        traverse a relationship from a node to a set of nodes

        :return: the cypher identifier bound to the traversed-to nodes
        :rtype: str
        """
        # label of the nodes on the right-hand side of the relationship
        rhs_label = ":" + traversal.target_class.__label__

        # build source
        rel_ident = self.create_ident()
        lhs_ident = await self.build_source(traversal.source)
        traversal_ident = f"{traversal.name}_{rel_ident}"
        rhs_ident = traversal_ident + rhs_label
        self._ast.return_clause = traversal_ident
        self._ast.result_class = traversal.target_class

        stmt = _rel_helper(
            lhs=lhs_ident,
            rhs=rhs_ident,
            ident=rel_ident,
            **traversal.definition,
        )
        self._ast.match.append(stmt)

        # Filters on a traversal apply to the relationship, not the nodes.
        if traversal.filters:
            self.build_where_stmt(rel_ident, traversal.filters)

        return traversal_ident
+
+ def _additional_return(self, name):
+ if name not in self._ast.additional_return and name != self._ast.return_clause:
+ self._ast.additional_return.append(name)
+
    def build_traversal_from_path(self, relation: dict, source_class) -> str:
        """
        Build MATCH (or OPTIONAL MATCH) clauses for a "__"-separated relation
        path (e.g. "friends__city"), registering every traversed node and
        relationship identifier for the RETURN clause.

        :param relation: dict with a "path" string and optional "optional" flag
        :param source_class: node class the path starts from
        :return: name of the right-most node identifier in the path
        """
        path: str = relation["path"]
        stmt: str = ""
        source_class_iterator = source_class
        for index, part in enumerate(path.split("__")):
            relationship = getattr(source_class_iterator, part)
            # build source
            if "node_class" not in relationship.definition:
                relationship.lookup_node_class()
            rhs_label = relationship.definition["node_class"].__label__
            # Counter keeps identifiers unique when the same relation is
            # traversed more than once in a query.
            rel_reference = f'{relationship.definition["node_class"]}_{part}'
            self._node_counters[rel_reference] += 1
            rhs_name = (
                f"{rhs_label.lower()}_{part}_{self._node_counters[rel_reference]}"
            )
            rhs_ident = f"{rhs_name}:{rhs_label}"
            self._additional_return(rhs_name)
            if not stmt:
                lhs_label = source_class_iterator.__label__
                lhs_name = lhs_label.lower()
                lhs_ident = f"{lhs_name}:{lhs_label}"
                if not index:
                    # This is the first one, we make sure that 'return'
                    # contains the primary node so _contains() works
                    # as usual
                    self._ast.return_clause = lhs_name
                else:
                    self._additional_return(lhs_name)
            else:
                # Chain onto the pattern built so far.
                lhs_ident = stmt

            rel_ident = self.create_ident()
            self._additional_return(rel_ident)
            stmt = _rel_helper(
                lhs=lhs_ident,
                rhs=rhs_ident,
                ident=rel_ident,
                direction=relationship.definition["direction"],
                relation_type=relationship.definition["relation_type"],
            )
            source_class_iterator = relationship.definition["node_class"]

        if relation.get("optional"):
            self._ast.optional_match.append(stmt)
        else:
            self._ast.match.append(stmt)
        return rhs_name
+
    async def build_node(self, node):
        """
        Match a single concrete node by element id and make it the primary
        return identifier.

        :return: the cypher identifier bound to the node
        :rtype: str
        """
        ident = node.__class__.__name__.lower()
        place_holder = self._register_place_holder(ident)

        # Hack to emulate START to lookup a node by id
        _node_lookup = f"MATCH ({ident}) WHERE {await adb.get_id_method()}({ident})=${place_holder} WITH {ident}"
        self._ast.lookup = _node_lookup

        self._query_params[place_holder] = await adb.parse_element_id(node.element_id)

        self._ast.return_clause = ident
        self._ast.result_class = node.__class__
        return ident
+
+ def build_label(self, ident, cls):
+ """
+ match nodes by a label
+ """
+ ident_w_label = ident + ":" + cls.__label__
+
+ if not self._ast.return_clause and (
+ not self._ast.additional_return or ident not in self._ast.additional_return
+ ):
+ self._ast.match.append(f"({ident_w_label})")
+ self._ast.return_clause = ident
+ self._ast.result_class = cls
+ return ident
+
    def build_additional_match(self, ident, node_set):
        """
        handle additional matches supplied by 'has()' calls

        Required relationships become plain WHERE patterns; forbidden ones
        are wrapped in NOT.
        """
        source_ident = ident

        for _, value in node_set.must_match.items():
            if isinstance(value, dict):
                label = ":" + value["node_class"].__label__
                stmt = _rel_helper(lhs=source_ident, rhs=label, ident="", **value)
                self._ast.where.append(stmt)
            else:
                raise ValueError("Expecting dict got: " + repr(value))

        for _, val in node_set.dont_match.items():
            if isinstance(val, dict):
                label = ":" + val["node_class"].__label__
                stmt = _rel_helper(lhs=source_ident, rhs=label, ident="", **val)
                self._ast.where.append("NOT " + stmt)
            else:
                raise ValueError("Expecting dict got: " + repr(val))
+
+ def _register_place_holder(self, key):
+ if key in self._place_holder_registry:
+ self._place_holder_registry[key] += 1
+ else:
+ self._place_holder_registry[key] = 1
+ return key + "_" + str(self._place_holder_registry[key])
+
    def _parse_q_filters(self, ident, q, source_class):
        """
        Recursively render a Q object tree into a cypher boolean expression,
        registering query parameters as it goes.

        :return: the rendered boolean expression string
        """
        target = []
        for child in q.children:
            if isinstance(child, QBase):
                # Nested Q node: render recursively, parenthesise OR groups.
                q_childs = self._parse_q_filters(ident, child, source_class)
                if child.connector == Q.OR:
                    q_childs = "(" + q_childs + ")"
                target.append(q_childs)
            else:
                # Leaf: a (key, value) filter pair.
                kwargs = {child[0]: child[1]}
                filters = process_filter_args(source_class, kwargs)
                for prop, op_and_val in filters.items():
                    operator, val = op_and_val
                    if operator in _UNARY_OPERATORS:
                        # unary operators do not have a parameter
                        statement = f"{ident}.{prop} {operator}"
                    else:
                        place_holder = self._register_place_holder(ident + "_" + prop)
                        statement = f"{ident}.{prop} {operator} ${place_holder}"
                        self._query_params[place_holder] = val
                    target.append(statement)
        ret = f" {q.connector} ".join(target)
        if q.negated:
            ret = f"NOT ({ret})"
        return ret
+
    def build_where_stmt(self, ident, filters, q_filters=None, source_class=None):
        """
        construct a where statement from some filters

        Either renders a Q-object tree (q_filters) or a list of plain filter
        dicts; the resulting expressions are appended to the AST's WHERE list.
        """
        if q_filters is not None:
            stmts = self._parse_q_filters(ident, q_filters, source_class)
            if stmts:
                self._ast.where.append(stmts)
        else:
            stmts = []
            for row in filters:
                negate = False

                # pre-process NOT cases as they are nested dicts
                if "__NOT__" in row and len(row) == 1:
                    negate = True
                    row = row["__NOT__"]

                for prop, operator_and_val in row.items():
                    operator, val = operator_and_val
                    if operator in _UNARY_OPERATORS:
                        # unary operators do not have a parameter
                        statement = (
                            f"{'NOT' if negate else ''} {ident}.{prop} {operator}"
                        )
                    else:
                        place_holder = self._register_place_holder(ident + "_" + prop)
                        statement = f"{'NOT' if negate else ''} {ident}.{prop} {operator} ${place_holder}"
                        self._query_params[place_holder] = val
                    stmts.append(statement)

            self._ast.where.append(" AND ".join(stmts))
+
    def build_query(self):
        """
        Render the accumulated AST into a cypher query string.

        :return: the cypher query
        :rtype: str
        """
        query = ""

        if self._ast.lookup:
            query += self._ast.lookup

        # Instead of using only one MATCH statement for every relation
        # to follow, we use one MATCH per relation (to avoid cartesian
        # product issues...).
        # There might be optimizations to be done, using projections,
        # or pushing patterns instead of a chain of OPTIONAL MATCH.
        if self._ast.match:
            query += " MATCH "
            query += " MATCH ".join(i for i in self._ast.match)

        if self._ast.optional_match:
            query += " OPTIONAL MATCH "
            query += " OPTIONAL MATCH ".join(i for i in self._ast.optional_match)

        if self._ast.where:
            query += " WHERE "
            query += " AND ".join(self._ast.where)

        if self._ast.with_clause:
            query += " WITH "
            query += self._ast.with_clause

        query += " RETURN "
        if self._ast.return_clause:
            query += self._ast.return_clause
        if self._ast.additional_return:
            if self._ast.return_clause:
                query += ", "
            query += ", ".join(self._ast.additional_return)

        if self._ast.order_by:
            query += " ORDER BY "
            query += ", ".join(self._ast.order_by)

        # If we return a count with pagination, pagination has to happen before RETURN
        # It will then be included in the WITH clause already
        if self._ast.skip and not self._ast.is_count:
            query += f" SKIP {self._ast.skip}"

        if self._ast.limit and not self._ast.is_count:
            query += f" LIMIT {self._ast.limit}"

        return query
+
+ async def _count(self):
+ self._ast.is_count = True
+ # If we return a count with pagination, pagination has to happen before RETURN
+ # Like : WITH my_var SKIP 10 LIMIT 10 RETURN count(my_var)
+ self._ast.with_clause = f"{self._ast.return_clause}"
+ if self._ast.skip:
+ self._ast.with_clause += f" SKIP {self._ast.skip}"
+
+ if self._ast.limit:
+ self._ast.with_clause += f" LIMIT {self._ast.limit}"
+
+ self._ast.return_clause = f"count({self._ast.return_clause})"
+ # drop order_by, results in an invalid query
+ self._ast.order_by = None
+ # drop additional_return to avoid unexpected result
+ self._ast.additional_return = None
+ query = self.build_query()
+ results, _ = await adb.cypher_query(query, self._query_params)
+ return int(results[0][0])
+
+ async def _contains(self, node_element_id):
+ # inject id = into ast
+ if not self._ast.return_clause:
+ print(self._ast.additional_return)
+ self._ast.return_clause = self._ast.additional_return[0]
+ ident = self._ast.return_clause
+ place_holder = self._register_place_holder(ident + "_contains")
+ self._ast.where.append(
+ f"{await adb.get_id_method()}({ident}) = ${place_holder}"
+ )
+ self._query_params[place_holder] = node_element_id
+ return await self._count() >= 1
+
+ async def _execute(self, lazy=False):
+ if lazy:
+ # inject id() into return or return_set
+ if self._ast.return_clause:
+ self._ast.return_clause = (
+ f"{await adb.get_id_method()}({self._ast.return_clause})"
+ )
+ else:
+ self._ast.additional_return = [
+ f"{await adb.get_id_method()}({item})"
+ for item in self._ast.additional_return
+ ]
+ query = self.build_query()
+ results, _ = await adb.cypher_query(
+ query, self._query_params, resolve_objects=True
+ )
+ # The following is not as elegant as it could be but had to be copied from the
+ # version prior to cypher_query with the resolve_objects capability.
+ # It seems that certain calls are only supposed to be focusing to the first
+ # result item returned (?)
+ if results and len(results[0]) == 1:
+ return [n[0] for n in results]
+ return results
+
+
class AsyncBaseSet:
    """
    Base class for all node sets.

    Contains common python magic methods, __len__, __contains__ etc
    """

    query_cls = AsyncQueryBuilder

    async def all(self, lazy=False):
        """
        Return all nodes belonging to the set
        :param lazy: False by default, specify True to get nodes with id only without the parameters.
        :return: list of nodes
        :rtype: list
        """
        ast = await self.query_cls(self).build_ast()
        return await ast._execute(lazy)

    async def __aiter__(self):
        ast = await self.query_cls(self).build_ast()
        # _execute returns a plain list, so iterate it synchronously and
        # yield each item; "async for" over a list raises a TypeError.
        for i in await ast._execute():
            yield i

    async def get_len(self):
        """Return the number of nodes in the set (async __len__ stand-in)."""
        ast = await self.query_cls(self).build_ast()
        return await ast._count()

    async def check_bool(self):
        """
        Override for __bool__ dunder method.
        :return: True if the set contains any nodes, False otherwise
        :rtype: bool
        """
        ast = await self.query_cls(self).build_ast()
        _count = await ast._count()
        return _count > 0

    async def check_nonzero(self):
        """
        Override for __bool__ dunder method.
        :return: True if the set contains any node, False otherwise
        :rtype: bool
        """
        return await self.check_bool()

    async def check_contains(self, obj):
        # Async __contains__ stand-in; only saved StructuredNode instances
        # can be tested for membership.
        if isinstance(obj, AsyncStructuredNode):
            if hasattr(obj, "element_id") and obj.element_id is not None:
                ast = await self.query_cls(self).build_ast()
                obj_element_id = await adb.parse_element_id(obj.element_id)
                return await ast._contains(obj_element_id)
            raise ValueError("Unsaved node: " + repr(obj))

        raise ValueError("Expecting StructuredNode instance")

    async def get_item(self, key):
        # Async __getitem__ stand-in supporting int indexing and slices.
        if isinstance(key, slice):
            if key.stop and key.start:
                self.limit = key.stop - key.start
                self.skip = key.start
            elif key.stop:
                self.limit = key.stop
            elif key.start:
                self.skip = key.start

            return self

        if isinstance(key, int):
            self.skip = key
            self.limit = 1

            ast = await self.query_cls(self).build_ast()
            # _execute is a coroutine: it must be awaited before indexing,
            # otherwise _items[0] subscripts a coroutine object.
            _items = await ast._execute()
            return _items[0]

        return None
+
+
@dataclass
class Optional:
    """Simple relation qualifier."""

    # Name of the relation to mark as optional (used by
    # NodeSet.fetch_relations() to emit an OPTIONAL MATCH for it).
    relation: str
+
+
class AsyncNodeSet(AsyncBaseSet):
    """
    A class representing as set of nodes matching common query parameters
    """

    def __init__(self, source):
        self.source = source  # could be a Traverse object or a node class
        # Work out the node class this set yields, depending on the kind of
        # source supplied.
        if isinstance(source, AsyncTraversal):
            self.source_class = source.target_class
        elif inspect.isclass(source) and issubclass(source, AsyncStructuredNode):
            self.source_class = source
        elif isinstance(source, AsyncStructuredNode):
            self.source_class = source.__class__
        else:
            raise ValueError("Bad source for nodeset " + repr(source))

        # setup Traversal objects using relationship definitions
        install_traversals(self.source_class, self)

        # Plain relationship filters and Q-object property filters are kept
        # separately; both feed into the query builder.
        self.filters = []
        self.q_filters = Q()

        # used by has()
        self.must_match = {}
        self.dont_match = {}

        # Relations to return alongside the nodes, see fetch_relations().
        self.relations_to_fetch: list = []

    def __await__(self):
        # Awaiting the set directly is shorthand for awaiting .all().
        return self.all().__await__()

    async def _get(self, limit=None, lazy=False, **kwargs):
        # Shared helper for get()/first(): apply filters, optionally cap the
        # result size, then build and execute the query.
        self.filter(**kwargs)
        if limit:
            self.limit = limit
        ast = await self.query_cls(self).build_ast()
        return await ast._execute(lazy)

    async def get(self, lazy=False, **kwargs):
        """
        Retrieve one node from the set matching supplied parameters
        :param lazy: False by default, specify True to get nodes with id only without the parameters.
        :param kwargs: same syntax as `filter()`
        :return: node
        """
        # limit=2 is enough to detect multiple matches without fetching all.
        result = await self._get(limit=2, lazy=lazy, **kwargs)
        if len(result) > 1:
            raise MultipleNodesReturned(repr(kwargs))
        if not result:
            raise self.source_class.DoesNotExist(repr(kwargs))
        return result[0]

    async def get_or_none(self, **kwargs):
        """
        Retrieve a node from the set matching supplied parameters or return none

        :param kwargs: same syntax as `filter()`
        :return: node or none
        """
        try:
            return await self.get(**kwargs)
        except self.source_class.DoesNotExist:
            return None

    async def first(self, **kwargs):
        """
        Retrieve the first node from the set matching supplied parameters

        :param kwargs: same syntax as `filter()`
        :return: node
        :raises source_class.DoesNotExist: when the set is empty
        """
        result = await self._get(limit=1, **kwargs)
        if result:
            return result[0]
        else:
            raise self.source_class.DoesNotExist(repr(kwargs))

    async def first_or_none(self, **kwargs):
        """
        Retrieve the first node from the set matching supplied parameters or return none

        :param kwargs: same syntax as `filter()`
        :return: node or none
        """
        try:
            return await self.first(**kwargs)
        except self.source_class.DoesNotExist:
            pass
        return None

    def filter(self, *args, **kwargs):
        """
        Apply filters to the existing nodes in the set.

        :param args: a Q object

            e.g `.filter(Q(salary__lt=10000) | Q(salary__gt=20000))`.

        :param kwargs: filter parameters

            Filters mimic Django's syntax with the double '__' to separate field and operators.

            e.g `.filter(salary__gt=20000)` results in `salary > 20000`.

            The following operators are available:

             * 'lt': less than
             * 'gt': greater than
             * 'lte': less than or equal to
             * 'gte': greater than or equal to
             * 'ne': not equal to
             * 'in': matches one of list (or tuple)
             * 'isnull': is null
             * 'regex': matches supplied regex (neo4j regex format)
             * 'exact': exactly match string (just '=')
             * 'iexact': case insensitive match string
             * 'contains': contains string
             * 'icontains': case insensitive contains
             * 'startswith': string starts with
             * 'istartswith': case insensitive string starts with
             * 'endswith': string ends with
             * 'iendswith': case insensitive string ends with

        :return: self
        """
        # Filters accumulate: each call ANDs the new conditions onto the
        # existing Q tree.
        if args or kwargs:
            self.q_filters = Q(self.q_filters & Q(*args, **kwargs))
        return self

    def exclude(self, *args, **kwargs):
        """
        Exclude nodes from the NodeSet via filters.

        :param kwargs: filter parameters see syntax for the filter method
        :return: self
        """
        # Same as filter() but the new conditions are negated before ANDing.
        if args or kwargs:
            self.q_filters = Q(self.q_filters & ~Q(*args, **kwargs))
        return self

    def has(self, **kwargs):
        # Restrict the set to nodes that have (or lack) given relationships.
        must_match, dont_match = process_has_args(self.source_class, kwargs)
        self.must_match.update(must_match)
        self.dont_match.update(dont_match)
        return self

    def order_by(self, *props):
        """
        Order by properties. Prepend with minus to do descending. Pass None to
        remove ordering.
        """
        should_remove = len(props) == 1 and props[0] is None
        if not hasattr(self, "order_by_elements") or should_remove:
            self.order_by_elements = []
        if should_remove:
            return self
        if "?" in props:
            # "?" requests random ordering.
            self.order_by_elements.append("?")
        else:
            for prop in props:
                prop = prop.strip()
                if prop.startswith("-"):
                    prop = prop[1:]
                    desc = True
                else:
                    desc = False

                if prop not in self.source_class.defined_properties(rels=False):
                    raise ValueError(
                        f"No such property {prop} on {self.source_class.__name__}. Note that Neo4j internals like id or element_id are not allowed for use in this operation."
                    )

                property_obj = getattr(self.source_class, prop)
                if isinstance(property_obj, AliasProperty):
                    # Order on the real property the alias points at.
                    prop = property_obj.aliased_to()

                self.order_by_elements.append(prop + (" DESC" if desc else ""))

        return self

    def fetch_relations(self, *relation_names):
        """Specify a set of relations to return."""
        relations = []
        for relation_name in relation_names:
            if isinstance(relation_name, Optional):
                # The Optional wrapper marks the relation for OPTIONAL MATCH.
                item = {"path": relation_name.relation, "optional": True}
            else:
                item = {"path": relation_name}
            relations.append(item)
        self.relations_to_fetch = relations
        return self
+
+
class AsyncTraversal(AsyncBaseSet):
    """
    Models a traversal from a node to another.

    :param source: Starting of the traversal.
    :type source: A :class:`~neomodel.core.StructuredNode` subclass, an
        instance of such, a :class:`~neomodel.match.NodeSet` instance
        or a :class:`~neomodel.match.Traversal` instance.
    :param name: A name for the traversal.
    :type name: :class:`str`
    :param definition: A relationship definition that most certainly deserves
        a documentation here.
    :type definition: :class:`dict`
    """

    def __await__(self):
        # Awaiting the traversal directly is shorthand for awaiting .all().
        return self.all().__await__()

    def __init__(self, source, name, definition):
        """
        Create a traversal

        """
        self.source = source

        # Resolve the class this traversal starts from.
        if isinstance(source, AsyncTraversal):
            self.source_class = source.target_class
        elif inspect.isclass(source) and issubclass(source, AsyncStructuredNode):
            self.source_class = source
        elif isinstance(source, AsyncStructuredNode):
            self.source_class = source.__class__
        elif isinstance(source, AsyncNodeSet):
            self.source_class = source.source_class
        else:
            raise TypeError(f"Bad source for traversal: {type(source)}")

        # Only these keys may appear in a relationship definition.
        invalid_keys = set(definition) - {
            "direction",
            "model",
            "node_class",
            "relation_type",
        }
        if invalid_keys:
            raise ValueError(f"Prohibited keys in Traversal definition: {invalid_keys}")

        self.definition = definition
        self.target_class = definition["node_class"]
        self.name = name
        self.filters = []

    def match(self, **kwargs):
        """
        Traverse relationships with properties matching the given parameters.

        e.g: `.match(price__lt=10)`

        :param kwargs: see `NodeSet.filter()` for syntax
        :return: self
        """
        if kwargs:
            # Relationship property filters require a relationship model to
            # know how to deflate the supplied values.
            if self.definition.get("model") is None:
                raise ValueError(
                    "match() with filter only available on relationships with a model"
                )
            output = process_filter_args(self.definition["model"], kwargs)
            if output:
                self.filters.append(output)
        return self
diff --git a/neomodel/async_/path.py b/neomodel/async_/path.py
new file mode 100644
index 00000000..6128347e
--- /dev/null
+++ b/neomodel/async_/path.py
@@ -0,0 +1,53 @@
+from neo4j.graph import Path
+
+from neomodel.async_.core import adb
+from neomodel.async_.relationship import AsyncStructuredRel
+
+
class AsyncNeomodelPath(Path):
    """
    Represents paths within neomodel.

    This object is instantiated when you include whole paths in your ``cypher_query()``
    result sets and turn ``resolve_objects`` to True.

    That is, any query of the form:
    ::

        MATCH p=(:SOME_NODE_LABELS)-[:SOME_REL_LABELS]-(:SOME_OTHER_NODE_LABELS) return p

    ``NeomodelPath`` are simple objects that reference their nodes and relationships, each of which is already
    resolved to their neomodel objects if such mapping is possible.


    :param nodes: Neomodel nodes appearing in the path in order of appearance.
    :param relationships: Neomodel relationships appearing in the path in order of appearance.
    :type nodes: List[StructuredNode]
    :type relationships: List[StructuredRel]
    """

    def __init__(self, a_neopath):
        # Every node of the driver path is resolved to its neomodel class.
        self._nodes = [adb._object_resolution(node) for node in a_neopath.nodes]

        self._relationships = []
        for rel in a_neopath.relationships:
            # A relationship that bears no data has no entry in the class
            # registry; fall back to an "unspecified" AsyncStructuredRel.
            if frozenset([rel.type]) in adb._NODE_CLASS_REGISTRY:
                resolved = adb._object_resolution(rel)
            else:
                resolved = AsyncStructuredRel.inflate(rel)
            self._relationships.append(resolved)

    @property
    def nodes(self):
        """Resolved node objects, in order of appearance."""
        return self._nodes

    @property
    def relationships(self):
        """Resolved relationship objects, in order of appearance."""
        return self._relationships
diff --git a/neomodel/async_/property_manager.py b/neomodel/async_/property_manager.py
new file mode 100644
index 00000000..b9401dab
--- /dev/null
+++ b/neomodel/async_/property_manager.py
@@ -0,0 +1,109 @@
+import types
+
+from neomodel.exceptions import RequiredProperty
+from neomodel.properties import AliasProperty, Property
+
+
def display_for(key):
    """Build a ``get_<key>_display``-style method for a choices property.

    The returned callable looks up the property descriptor named *key* on the
    instance's class and maps the instance's current value through the
    property's ``choices`` dict.
    """

    def _display_choice(self):
        prop = getattr(type(self), key)
        return prop.choices[getattr(self, key)]

    return _display_choice
+
+
class AsyncPropertyManager:
    """
    Common methods for handling properties on node and relationship objects.
    """

    def __init__(self, **kwargs):
        # Prefer the precomputed __all_properties__ cache when the class
        # provides one; otherwise scan the MRO via defined_properties().
        properties = getattr(self, "__all_properties__", None)
        if properties is None:
            properties = self.defined_properties(rels=False, aliases=False).items()
        for name, property in properties:
            if kwargs.get(name) is None:
                # Not supplied (or explicitly None): fall back to the
                # declared default when there is one.
                if getattr(property, "has_default", False):
                    setattr(self, name, property.default_value())
                else:
                    setattr(self, name, None)
            else:
                setattr(self, name, kwargs[name])

            if getattr(property, "choices", None):
                # Expose a bound get_<name>_display() helper for choice
                # properties.
                setattr(
                    self,
                    f"get_{name}_display",
                    types.MethodType(display_for(name), self),
                )

            if name in kwargs:
                # Consumed: remove so the final catch-all loop below only
                # sees genuinely undefined names.
                del kwargs[name]

        # Same two-step lookup for alias properties.
        aliases = getattr(self, "__all_aliases__", None)
        if aliases is None:
            aliases = self.defined_properties(
                aliases=True, rels=False, properties=False
            ).items()
        for name, property in aliases:
            if name in kwargs:
                setattr(self, name, kwargs[name])
                del kwargs[name]

        # undefined properties (for magic @prop.setters etc)
        for name, property in kwargs.items():
            setattr(self, name, property)

    @property
    def __properties__(self):
        # All plain data attributes of the instance; private names,
        # callables, relationship managers and alias descriptors are not
        # persisted data and are filtered out.
        from neomodel.async_.relationship_manager import AsyncRelationshipManager

        return dict(
            (name, value)
            for name, value in vars(self).items()
            if not name.startswith("_")
            and not callable(value)
            and not isinstance(
                value,
                (
                    AsyncRelationshipManager,
                    AliasProperty,
                ),
            )
        )

    @classmethod
    def deflate(cls, properties, obj=None, skip_empty=False):
        # deflate dict ready to be stored
        deflated = {}
        for name, property in cls.defined_properties(aliases=False, rels=False).items():
            db_property = property.db_property or name
            if properties.get(name) is not None:
                deflated[db_property] = property.deflate(properties[name], obj)
            elif property.has_default:
                deflated[db_property] = property.deflate(property.default_value(), obj)
            elif property.required:
                # A required property with neither a value nor a default is
                # an error at deflate time.
                raise RequiredProperty(name, cls)
            elif not skip_empty:
                deflated[db_property] = None
        return deflated

    @classmethod
    def defined_properties(cls, aliases=True, properties=True, rels=True):
        # Collect the requested kinds of definitions, walking the MRO in
        # reverse so subclass definitions override base-class ones.
        from neomodel.async_.relationship_manager import AsyncRelationshipDefinition

        props = {}
        for baseclass in reversed(cls.__mro__):
            props.update(
                dict(
                    (name, property)
                    for name, property in vars(baseclass).items()
                    if (aliases and isinstance(property, AliasProperty))
                    or (
                        properties
                        and isinstance(property, Property)
                        and not isinstance(property, AliasProperty)
                    )
                    or (rels and isinstance(property, AsyncRelationshipDefinition))
                )
            )
        return props
diff --git a/neomodel/async_/relationship.py b/neomodel/async_/relationship.py
new file mode 100644
index 00000000..5653137a
--- /dev/null
+++ b/neomodel/async_/relationship.py
@@ -0,0 +1,168 @@
+from neomodel.async_.core import adb
+from neomodel.async_.property_manager import AsyncPropertyManager
+from neomodel.hooks import hooks
+from neomodel.properties import Property
+
# Shared message raised by the numeric-id accessors kept for Neo4j 4.4
# compatibility (id() was deprecated in favour of elementId() in Neo4j 5).
ELEMENT_ID_MIGRATION_NOTICE = "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()."
+
+
class RelationshipMeta(type):
    """Metaclass wiring up Property descriptors declared on relationship
    classes and rejecting property names that clash with internals."""

    def __new__(mcs, name, bases, dct):
        inst = super().__new__(mcs, name, bases, dct)
        for key, value in dct.items():
            if issubclass(value.__class__, Property):
                # 'source'/'target', 'id' and 'element_id' collide with
                # neomodel and/or Neo4j internals and are rejected outright.
                if key == "source" or key == "target":
                    raise ValueError(
                        "Property names 'source' and 'target' are not allowed as they conflict with neomodel internals."
                    )
                elif key == "id":
                    raise ValueError(
                        """
                        Property name 'id' is not allowed as it conflicts with neomodel internals.
                        Consider using 'uid' or 'identifier' as id is also a Neo4j internal.
                        """
                    )
                elif key == "element_id":
                    raise ValueError(
                        """
                        Property name 'element_id' is not allowed as it conflicts with neomodel internals.
                        Consider using 'uid' or 'identifier' as element_id is also a Neo4j internal.
                        """
                    )
                # Tell the property its attribute name and owning class.
                value.name = key
                value.owner = inst

                # support for 'magic' properties
                if hasattr(value, "setup") and hasattr(value.setup, "__call__"):
                    value.setup()
        return inst
+
+
# Base class for relationship models, built by invoking the metaclass
# directly (equivalent to declaring
# `class RelationshipBase(AsyncPropertyManager, metaclass=RelationshipMeta)`).
StructuredRelBase = RelationshipMeta("RelationshipBase", (AsyncPropertyManager,), {})
+
+
class AsyncStructuredRel(StructuredRelBase):
    """
    Base class for relationship objects
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def element_id(self):
        # Set by inflate(); returns None implicitly for an unsaved rel.
        if hasattr(self, "element_id_property"):
            return self.element_id_property

    @property
    def _start_node_element_id(self):
        # element_id of the relationship's start node (set by inflate()).
        if hasattr(self, "_start_node_element_id_property"):
            return self._start_node_element_id_property

    @property
    def _end_node_element_id(self):
        # element_id of the relationship's end node (set by inflate()).
        if hasattr(self, "_end_node_element_id_property"):
            return self._end_node_element_id_property

    # Version 4.4 support - id is deprecated in version 5.x
    @property
    def id(self):
        try:
            return int(self.element_id_property)
        except (TypeError, ValueError) as exc:
            # Neo4j 5 element ids are not plain ints; steer users to
            # element_id instead.
            raise ValueError(ELEMENT_ID_MIGRATION_NOTICE) from exc

    # Version 4.4 support - id is deprecated in version 5.x
    @property
    def _start_node_id(self):
        try:
            return int(self._start_node_element_id_property)
        except (TypeError, ValueError) as exc:
            raise ValueError(ELEMENT_ID_MIGRATION_NOTICE) from exc

    # Version 4.4 support - id is deprecated in version 5.x
    @property
    def _end_node_id(self):
        try:
            return int(self._end_node_element_id_property)
        except (TypeError, ValueError) as exc:
            raise ValueError(ELEMENT_ID_MIGRATION_NOTICE) from exc

    @hooks
    async def save(self):
        """
        Save the relationship

        :return: self
        """
        # Deflate current attribute values and SET them on the matched rel.
        props = self.deflate(self.__properties__)
        query = f"MATCH ()-[r]->() WHERE {await adb.get_id_method()}(r)=$self "
        query += "".join([f" SET r.{key} = ${key}" for key in props])
        props["self"] = await adb.parse_element_id(self.element_id)

        await adb.cypher_query(query, props)

        return self

    async def start_node(self):
        """
        Get start node

        :return: StructuredNode
        """
        results = await adb.cypher_query(
            f"""
            MATCH (aNode)
            WHERE {await adb.get_id_method()}(aNode)=$start_node_element_id
            RETURN aNode
            """,
            {
                "start_node_element_id": await adb.parse_element_id(
                    self._start_node_element_id
                )
            },
            resolve_objects=True,
        )
        # cypher_query returns (rows, meta); first row, first column.
        return results[0][0][0]

    async def end_node(self):
        """
        Get end node

        :return: StructuredNode
        """
        results = await adb.cypher_query(
            f"""
            MATCH (aNode)
            WHERE {await adb.get_id_method()}(aNode)=$end_node_element_id
            RETURN aNode
            """,
            {
                "end_node_element_id": await adb.parse_element_id(
                    self._end_node_element_id
                )
            },
            resolve_objects=True,
        )
        return results[0][0][0]

    @classmethod
    def inflate(cls, rel):
        """
        Inflate a neo4j_driver relationship object to a neomodel object
        :param rel:
        :return: StructuredRel
        """
        props = {}
        for key, prop in cls.defined_properties(aliases=False, rels=False).items():
            if key in rel:
                props[key] = prop.inflate(rel[key], obj=rel)
            elif prop.has_default:
                props[key] = prop.default_value()
            else:
                props[key] = None
        srel = cls(**props)
        # Record the driver-side ids so the element_id/start/end properties
        # above can resolve.
        srel._start_node_element_id_property = rel.start_node.element_id
        srel._end_node_element_id_property = rel.end_node.element_id
        srel.element_id_property = rel.element_id
        return srel
diff --git a/neomodel/async_/relationship_manager.py b/neomodel/async_/relationship_manager.py
new file mode 100644
index 00000000..9c5e7398
--- /dev/null
+++ b/neomodel/async_/relationship_manager.py
@@ -0,0 +1,561 @@
+import functools
+import inspect
+import sys
+from importlib import import_module
+
+from neomodel.async_.core import adb
+from neomodel.async_.match import (
+ AsyncNodeSet,
+ AsyncTraversal,
+ _rel_helper,
+ _rel_merge_helper,
+)
+from neomodel.async_.relationship import AsyncStructuredRel
+from neomodel.exceptions import NotConnected, RelationshipClassRedefined
+from neomodel.util import (
+ EITHER,
+ INCOMING,
+ OUTGOING,
+ _get_node_properties,
+ enumerate_traceback,
+)
+
# basestring python 3.x fallback: the name does not exist on Python 3, so
# alias it to str for isinstance checks written against Python 2.
try:
    basestring
except NameError:
    basestring = str
+
+
# check source node is saved and not deleted
def check_source(fn):
    """Decorator: run the source node's ``_pre_action_check`` (saved and not
    deleted) before invoking *fn* on a relationship manager."""
    fn_name = getattr(fn, "func_name", fn.__name__)

    @functools.wraps(fn)
    def checker(self, *args, **kwargs):
        self.source._pre_action_check(self.name + "." + fn_name)
        return fn(self, *args, **kwargs)

    return checker
+
+
# checks if obj is a direct subclass, 1 level
def is_direct_subclass(obj, classinfo):
    """Return True when *classinfo* appears among obj's immediate bases."""
    return any(base == classinfo for base in obj.__bases__)
+
+
class AsyncRelationshipManager(object):
    """
    Base class for all relationships managed through neomodel.

    I.e the 'friends' object in `user.friends.all()`
    """

    def __init__(self, source, key, definition):
        self.source = source
        self.source_class = source.__class__
        self.name = key
        self.definition = definition

    def __str__(self):
        direction = "either"
        if self.definition["direction"] == OUTGOING:
            direction = "a outgoing"
        elif self.definition["direction"] == INCOMING:
            direction = "a incoming"

        return f"{self.description} in {direction} direction of type {self.definition['relation_type']} on node ({self.source.element_id}) of class '{self.source_class.__name__}'"

    def __await__(self):
        # Awaiting the manager directly is shorthand for awaiting .all().
        return self.all().__await__()

    def _check_node(self, obj):
        """check for valid node i.e correct class and is saved"""
        if not issubclass(type(obj), self.definition["node_class"]):
            raise ValueError(
                "Expected node of class " + self.definition["node_class"].__name__
            )
        if not hasattr(obj, "element_id"):
            raise ValueError("Can't perform operation on unsaved node " + repr(obj))

    @check_source
    async def connect(self, node, properties=None):
        """
        Connect a node

        :param node:
        :param properties: for the new relationship
        :type: dict
        :return: the relationship instance if a relationship model is
            defined, otherwise True
        """
        self._check_node(node)

        if not self.definition["model"] and properties:
            raise NotImplementedError(
                "Relationship properties without using a relationship model "
                "is no longer supported."
            )

        params = {}
        rel_model = self.definition["model"]
        rel_prop = None

        if rel_model:
            rel_prop = {}
            # need to generate defaults etc to create fake instance
            tmp = rel_model(**properties) if properties else rel_model()
            # build params and place holders to pass to rel_helper
            for prop, val in rel_model.deflate(tmp.__properties__).items():
                if val is not None:
                    rel_prop[prop] = "$" + prop
                else:
                    rel_prop[prop] = None
                params[prop] = val

            if hasattr(tmp, "pre_save"):
                tmp.pre_save()

        new_rel = _rel_merge_helper(
            lhs="us",
            rhs="them",
            ident="r",
            relation_properties=rel_prop,
            **self.definition,
        )
        q = (
            f"MATCH (them), (us) WHERE {await adb.get_id_method()}(them)=$them and {await adb.get_id_method()}(us)=$self "
            "MERGE" + new_rel
        )

        params["them"] = await adb.parse_element_id(node.element_id)

        if not rel_model:
            await self.source.cypher(q, params)
            return True

        results = await self.source.cypher(q + " RETURN r", params)
        rel_ = results[0][0][0]
        rel_instance = self._set_start_end_cls(rel_model.inflate(rel_), node)

        if hasattr(rel_instance, "post_save"):
            rel_instance.post_save()

        return rel_instance

    @check_source
    async def replace(self, node, properties=None):
        """
        Disconnect all existing nodes and connect the supplied node

        :param node:
        :param properties: for the new relationship
        :type: dict
        :return:
        """
        await self.disconnect_all()
        await self.connect(node, properties)

    @check_source
    async def relationship(self, node):
        """
        Retrieve the relationship object for this first relationship between self and node.

        :param node:
        :return: StructuredRel
        """
        self._check_node(node)
        my_rel = _rel_helper(lhs="us", rhs="them", ident="r", **self.definition)
        q = (
            "MATCH "
            + my_rel
            + f" WHERE {await adb.get_id_method()}(them)=$them and {await adb.get_id_method()}(us)=$self RETURN r LIMIT 1"
        )
        results = await self.source.cypher(
            q, {"them": await adb.parse_element_id(node.element_id)}
        )
        rels = results[0]
        if not rels:
            return

        # Fall back to the generic rel class when no model is defined.
        rel_model = self.definition.get("model") or AsyncStructuredRel

        return self._set_start_end_cls(rel_model.inflate(rels[0][0]), node)

    @check_source
    async def all_relationships(self, node):
        """
        Retrieve all relationship objects between self and node.

        :param node:
        :return: [StructuredRel]
        """
        self._check_node(node)

        my_rel = _rel_helper(lhs="us", rhs="them", ident="r", **self.definition)
        q = f"MATCH {my_rel} WHERE {await adb.get_id_method()}(them)=$them and {await adb.get_id_method()}(us)=$self RETURN r "
        results = await self.source.cypher(
            q, {"them": await adb.parse_element_id(node.element_id)}
        )
        rels = results[0]
        if not rels:
            return []

        rel_model = self.definition.get("model") or AsyncStructuredRel
        return [
            self._set_start_end_cls(rel_model.inflate(rel[0]), node) for rel in rels
        ]

    def _set_start_end_cls(self, rel_instance, obj):
        # Record which neomodel classes sit at either end of the rel so it
        # can inflate its start/end nodes later.
        if self.definition["direction"] == INCOMING:
            rel_instance._start_node_class = obj.__class__
            rel_instance._end_node_class = self.source_class
        else:
            rel_instance._start_node_class = self.source_class
            rel_instance._end_node_class = obj.__class__
        return rel_instance

    @check_source
    async def reconnect(self, old_node, new_node):
        """
        Disconnect old_node and connect new_node copying over any properties on the original relationship.

        Useful for preventing cardinality violations

        :param old_node:
        :param new_node:
        :return: None
        """

        self._check_node(old_node)
        self._check_node(new_node)
        if old_node.element_id == new_node.element_id:
            return
        old_rel = _rel_helper(lhs="us", rhs="old", ident="r", **self.definition)

        # get list of properties on the existing rel
        old_node_element_id = await adb.parse_element_id(old_node.element_id)
        new_node_element_id = await adb.parse_element_id(new_node.element_id)
        result, _ = await self.source.cypher(
            f"""
                MATCH (us), (old) WHERE {await adb.get_id_method()}(us)=$self and {await adb.get_id_method()}(old)=$old
                MATCH {old_rel} RETURN r
            """,
            {"old": old_node_element_id},
        )
        if result:
            node_properties = _get_node_properties(result[0][0])
            existing_properties = node_properties.keys()
        else:
            raise NotConnected("reconnect", self.source, old_node)

        # remove old relationship and create new one
        new_rel = _rel_merge_helper(lhs="us", rhs="new", ident="r2", **self.definition)
        q = (
            "MATCH (us), (old), (new) "
            f"WHERE {await adb.get_id_method()}(us)=$self and {await adb.get_id_method()}(old)=$old and {await adb.get_id_method()}(new)=$new "
            "MATCH " + old_rel
        )
        q += " MERGE" + new_rel

        # copy over properties if we have
        q += "".join([f" SET r2.{prop} = r.{prop}" for prop in existing_properties])
        q += " WITH r DELETE r"

        await self.source.cypher(
            q, {"old": old_node_element_id, "new": new_node_element_id}
        )

    @check_source
    async def disconnect(self, node):
        """
        Disconnect a node

        :param node:
        :return:
        """
        rel = _rel_helper(lhs="a", rhs="b", ident="r", **self.definition)
        q = f"""
                MATCH (a), (b) WHERE {await adb.get_id_method()}(a)=$self and {await adb.get_id_method()}(b)=$them
                MATCH {rel} DELETE r
            """
        await self.source.cypher(
            q, {"them": await adb.parse_element_id(node.element_id)}
        )

    @check_source
    async def disconnect_all(self):
        """
        Disconnect all nodes

        :return:
        """
        rhs = "b:" + self.definition["node_class"].__label__
        rel = _rel_helper(lhs="a", rhs=rhs, ident="r", **self.definition)
        q = (
            f"MATCH (a) WHERE {await adb.get_id_method()}(a)=$self MATCH "
            + rel
            + " DELETE r"
        )
        await self.source.cypher(q)

    @check_source
    def _new_traversal(self):
        return AsyncTraversal(self.source, self.name, self.definition)

    # The methods below simply proxy the match engine.
    def get(self, **kwargs):
        """
        Retrieve a related node with the matching node properties.

        :param kwargs: same syntax as `NodeSet.filter()`
        :return: node
        """
        return AsyncNodeSet(self._new_traversal()).get(**kwargs)

    def get_or_none(self, **kwargs):
        """
        Retrieve a related node with the matching node properties or return None.

        :param kwargs: same syntax as `NodeSet.filter()`
        :return: node
        """
        return AsyncNodeSet(self._new_traversal()).get_or_none(**kwargs)

    def filter(self, *args, **kwargs):
        """
        Retrieve related nodes matching the provided properties.

        :param args: a Q object
        :param kwargs: same syntax as `NodeSet.filter()`
        :return: NodeSet
        """
        return AsyncNodeSet(self._new_traversal()).filter(*args, **kwargs)

    def order_by(self, *props):
        """
        Order related nodes by specified properties

        :param props:
        :return: NodeSet
        """
        return AsyncNodeSet(self._new_traversal()).order_by(*props)

    def exclude(self, *args, **kwargs):
        """
        Exclude nodes that match the provided properties.

        :param args: a Q object
        :param kwargs: same syntax as `NodeSet.filter()`
        :return: NodeSet
        """
        return AsyncNodeSet(self._new_traversal()).exclude(*args, **kwargs)

    async def is_connected(self, node):
        """
        Check if a node is connected with this relationship type
        :param node:
        :return: bool
        """
        return await self._new_traversal().check_contains(node)

    async def single(self):
        """
        Get a single related node or none.

        :return: StructuredNode
        """
        try:
            rels = await self
            return rels[0]
        except IndexError:
            pass

    def match(self, **kwargs):
        """
        Return set of nodes who's relationship properties match supplied args

        :param kwargs: same syntax as `NodeSet.filter()`
        :return: NodeSet
        """
        return self._new_traversal().match(**kwargs)

    async def all(self):
        """
        Return all related nodes.

        :return: list
        """
        return await self._new_traversal().all()

    def __aiter__(self):
        # Must be a plain (non-async) method: the traversal's __aiter__()
        # already returns an async generator, and wrapping it in a coroutine
        # would break "async for" over the manager.
        return self._new_traversal().__aiter__()

    async def get_len(self):
        return await self._new_traversal().get_len()

    async def check_bool(self):
        return await self._new_traversal().check_bool()

    async def check_nonzero(self):
        # check_nonzero is a coroutine: without the await this returned a
        # coroutine object instead of a bool.
        return await self._new_traversal().check_nonzero()

    async def check_contains(self, obj):
        # Same fix as check_nonzero: await the coroutine result.
        return await self._new_traversal().check_contains(obj)

    async def get_item(self, key):
        # Same fix: await so callers receive nodes, not a coroutine.
        return await self._new_traversal().get_item(key)
+
+
+class AsyncRelationshipDefinition:
+ def __init__(
+ self,
+ relation_type,
+ cls_name,
+ direction,
+ manager=AsyncRelationshipManager,
+ model=None,
+ ):
+ self._validate_class(cls_name, model)
+
+ current_frame = inspect.currentframe()
+
+ frame_number = 3
+ for i, frame in enumerate_traceback(current_frame):
+ if cls_name in frame.f_globals:
+ frame_number = i
+ break
+ self.module_name = sys._getframe(frame_number).f_globals["__name__"]
+ if "__file__" in sys._getframe(frame_number).f_globals:
+ self.module_file = sys._getframe(frame_number).f_globals["__file__"]
+ self._raw_class = cls_name
+ self.manager = manager
+ self.definition = {
+ "relation_type": relation_type,
+ "direction": direction,
+ "model": model,
+ }
+
+ if model is not None:
+ # Relationships are easier to instantiate because
+ # they cannot have multiple labels.
+ # So, a relationship's type determines the class that should be
+ # instantiated uniquely.
+ # Here however, we still use a `frozenset([relation_type])`
+ # to preserve the mapping type.
+ label_set = frozenset([relation_type])
+ try:
+ # If the relationship mapping exists then it is attempted
+ # to be redefined so that it applies to the same label.
+ # In this case, it has to be ensured that the class
+ # that is overriding the relationship is a descendant
+ # of the already existing class.
+ model_from_registry = adb._NODE_CLASS_REGISTRY[label_set]
+ if not issubclass(model, model_from_registry):
+ is_parent = issubclass(model_from_registry, model)
+ if is_direct_subclass(model, AsyncStructuredRel) and not is_parent:
+ raise RelationshipClassRedefined(
+ relation_type, adb._NODE_CLASS_REGISTRY, model
+ )
+ else:
+ adb._NODE_CLASS_REGISTRY[label_set] = model
+ except KeyError:
+ # If the mapping does not exist then it is simply created.
+ adb._NODE_CLASS_REGISTRY[label_set] = model
+
+ def _validate_class(self, cls_name, model):
+ if not isinstance(cls_name, (basestring, object)):
+ raise ValueError("Expected class name or class got " + repr(cls_name))
+
+ if model and not issubclass(model, (AsyncStructuredRel,)):
+ raise ValueError("model must be a StructuredRel")
+
+ def lookup_node_class(self):
+ if not isinstance(self._raw_class, basestring):
+ self.definition["node_class"] = self._raw_class
+ else:
+ name = self._raw_class
+ if name.find(".") == -1:
+ module = self.module_name
+ else:
+ module, _, name = name.rpartition(".")
+
+ if module not in sys.modules:
+ # yet another hack to get around python semantics
+ # __name__ is the namespace of the parent module for __init__.py files,
+ # and the namespace of the current module for other .py files,
+ # therefore there's a need to define the namespace differently for
+ # these two cases in order for . in relative imports to work correctly
+ # (i.e. to mean the same thing for both cases).
+ # For example in the comments below, namespace == myapp, always
+ if not hasattr(self, "module_file"):
+ raise ImportError(f"Couldn't lookup '{name}'")
+
+ if "__init__.py" in self.module_file:
+ # e.g. myapp/__init__.py -[__name__]-> myapp
+ namespace = self.module_name
+ else:
+ # e.g. myapp/models.py -[__name__]-> myapp.models
+ namespace = self.module_name.rpartition(".")[0]
+
+ # load a module from a namespace (e.g. models from myapp)
+ if module:
+ module = import_module(module, namespace).__name__
+ # load the namespace itself (e.g. myapp)
+ # (otherwise it would look like import . from myapp)
+ else:
+ module = import_module(namespace).__name__
+ self.definition["node_class"] = getattr(sys.modules[module], name)
+
+ def build_manager(self, source, name):
+ self.lookup_node_class()
+ return self.manager(source, name, self.definition)
+
+
+class AsyncZeroOrMore(AsyncRelationshipManager):
+ """
+ A relationship of zero or more nodes (the default)
+ """
+
+ description = "zero or more relationships"
+
+
+class AsyncRelationshipTo(AsyncRelationshipDefinition):
+ def __init__(
+ self,
+ cls_name,
+ relation_type,
+ cardinality=AsyncZeroOrMore,
+ model=None,
+ ):
+ super().__init__(
+ relation_type, cls_name, OUTGOING, manager=cardinality, model=model
+ )
+
+
+class AsyncRelationshipFrom(AsyncRelationshipDefinition):
+ def __init__(
+ self,
+ cls_name,
+ relation_type,
+ cardinality=AsyncZeroOrMore,
+ model=None,
+ ):
+ super().__init__(
+ relation_type, cls_name, INCOMING, manager=cardinality, model=model
+ )
+
+
+class AsyncRelationship(AsyncRelationshipDefinition):
+ def __init__(
+ self,
+ cls_name,
+ relation_type,
+ cardinality=AsyncZeroOrMore,
+ model=None,
+ ):
+ super().__init__(
+ relation_type, cls_name, EITHER, manager=cardinality, model=model
+ )
diff --git a/neomodel/config.py b/neomodel/config.py
index b54aa806..85c9ed8a 100644
--- a/neomodel/config.py
+++ b/neomodel/config.py
@@ -1,8 +1,6 @@
import neo4j
-from ._version import __version__
-
-AUTO_INSTALL_LABELS = False
+from neomodel._version import __version__
# Use this to connect with automatically created driver
# The following options are the default ones that will be used as driver config
@@ -24,3 +22,6 @@
# DRIVER = neo4j.GraphDatabase().driver(
# "bolt://localhost:7687", auth=("neo4j", "foobarbaz")
# )
+DRIVER = None
+# Use this to connect to a specific database when using the self-managed driver
+DATABASE_NAME = None
diff --git a/neomodel/contrib/__init__.py b/neomodel/contrib/__init__.py
index 3be00b41..a852965d 100644
--- a/neomodel/contrib/__init__.py
+++ b/neomodel/contrib/__init__.py
@@ -1 +1,2 @@
-from .semi_structured import SemiStructuredNode
+from neomodel.contrib.async_.semi_structured import AsyncSemiStructuredNode
+from neomodel.contrib.sync_.semi_structured import SemiStructuredNode
diff --git a/neomodel/contrib/async_/semi_structured.py b/neomodel/contrib/async_/semi_structured.py
new file mode 100644
index 00000000..c333ae0e
--- /dev/null
+++ b/neomodel/contrib/async_/semi_structured.py
@@ -0,0 +1,64 @@
+from neomodel.async_.core import AsyncStructuredNode
+from neomodel.exceptions import DeflateConflict, InflateConflict
+from neomodel.util import _get_node_properties
+
+
+class AsyncSemiStructuredNode(AsyncStructuredNode):
+ """
+ A base class allowing properties to be stored on a node that aren't
+ specified in its definition. Conflicting properties are signaled with the
+ :class:`DeflateConflict` exception::
+
+ class Person(AsyncSemiStructuredNode):
+ name = StringProperty()
+ age = IntegerProperty()
+
+ def hello(self):
+            print("Hi my name is " + self.name)
+
+ tim = await Person(name='Tim', age=8, weight=11).save()
+ tim.hello = "Hi"
+ await tim.save() # DeflateConflict
+ """
+
+ __abstract_node__ = True
+
+ @classmethod
+ def inflate(cls, node):
+ # support lazy loading
+ if isinstance(node, str) or isinstance(node, int):
+ snode = cls()
+ snode.element_id_property = node
+ else:
+ props = {}
+ node_properties = {}
+ for key, prop in cls.__all_properties__:
+ node_properties = _get_node_properties(node)
+ if key in node_properties:
+ props[key] = prop.inflate(node_properties[key], node)
+ elif prop.has_default:
+ props[key] = prop.default_value()
+ else:
+ props[key] = None
+ # handle properties not defined on the class
+ for free_key in (x for x in node_properties if x not in props):
+ if hasattr(cls, free_key):
+ raise InflateConflict(
+ cls, free_key, node_properties[free_key], node.element_id
+ )
+ props[free_key] = node_properties[free_key]
+
+ snode = cls(**props)
+ snode.element_id_property = node.element_id
+
+ return snode
+
+ @classmethod
+ def deflate(cls, node_props, obj=None, skip_empty=False):
+ deflated = super().deflate(node_props, obj, skip_empty=skip_empty)
+ for key in [k for k in node_props if k not in deflated]:
+ if hasattr(cls, key) and (getattr(cls, key).required or not skip_empty):
+ raise DeflateConflict(cls, key, deflated[key], obj.element_id)
+
+ node_props.update(deflated)
+ return node_props
diff --git a/neomodel/contrib/spatial_properties.py b/neomodel/contrib/spatial_properties.py
index 7a48018b..982e6921 100644
--- a/neomodel/contrib/spatial_properties.py
+++ b/neomodel/contrib/spatial_properties.py
@@ -25,9 +25,9 @@
# If shapely is not installed, its import will fail and the spatial properties will not be available
try:
- from shapely.geometry import Point as ShapelyPoint
from shapely import __version__ as shapely_version
- from shapely.coords import CoordinateSequence
+ from shapely.coords import CoordinateSequence
+ from shapely.geometry import Point as ShapelyPoint
except ImportError as exc:
raise ImportError(
"NEOMODEL ERROR: Shapely not found. If required, you can install Shapely via "
@@ -53,10 +53,11 @@
# Taking into account the Shapely 2.0.0 changes in the way POINT objects are initialisd.
if int("".join(shapely_version.split(".")[0:3])) < 200:
+
class NeomodelPoint(ShapelyPoint):
"""
Abstracts the Point spatial data type of Neo4j.
-
+
Note:
At the time of writing, Neo4j supports 2 main variants of Point:
1. A generic point defined over a Cartesian plane
@@ -65,12 +66,12 @@ class NeomodelPoint(ShapelyPoint):
* The minimum data to define a point is longitude, latitude [,Height] and the crs is then assumed
to be "wgs-84".
"""
-
+
# def __init__(self, *args, crs=None, x=None, y=None, z=None, latitude=None, longitude=None, height=None, **kwargs):
def __init__(self, *args, **kwargs):
"""
Creates a NeomodelPoint.
-
+
:param args: Positional arguments to emulate the behaviour of Shapely's Point (and specifically the copy
constructor)
:type args: list
@@ -91,7 +92,7 @@ def __init__(self, *args, **kwargs):
:param kwargs: Dictionary of keyword arguments
:type kwargs: dict
"""
-
+
# Python2.7 Workaround for the order that the arguments get passed to the functions
crs = kwargs.pop("crs", None)
x = kwargs.pop("x", None)
@@ -100,14 +101,16 @@ def __init__(self, *args, **kwargs):
longitude = kwargs.pop("longitude", None)
latitude = kwargs.pop("latitude", None)
height = kwargs.pop("height", None)
-
+
_x, _y, _z = None, None, None
-
+
# CRS validity check is common to both types of constructors that follow
if crs is not None and crs not in ACCEPTABLE_CRS:
- raise ValueError(f"Invalid CRS({crs}). Expected one of {','.join(ACCEPTABLE_CRS)}")
+ raise ValueError(
+ f"Invalid CRS({crs}). Expected one of {','.join(ACCEPTABLE_CRS)}"
+ )
self._crs = crs
-
+
# If positional arguments have been supplied, then this is a possible call to the copy constructor or
# initialisation by a coordinate iterable as per ShapelyPoint constructor.
if len(args) > 0:
@@ -115,7 +118,9 @@ def __init__(self, *args, **kwargs):
if isinstance(args[0], (tuple, list)):
# Check dimensionality of tuple
if len(args[0]) < 2 or len(args[0]) > 3:
- raise ValueError(f"Invalid vector dimensions. Expected 2 or 3, received {len(args[0])}")
+ raise ValueError(
+ f"Invalid vector dimensions. Expected 2 or 3, received {len(args[0])}"
+ )
x = args[0][0]
y = args[0][1]
if len(args[0]) == 3:
@@ -143,11 +148,15 @@ def __init__(self, *args, **kwargs):
if crs is None:
self._crs = "cartesian-3d"
else:
- raise ValueError(f"Invalid vector dimensions. Expected 2 or 3, received {len(args[0].coords[0])}")
+ raise ValueError(
+ f"Invalid vector dimensions. Expected 2 or 3, received {len(args[0].coords[0])}"
+ )
return
else:
- raise TypeError(f"Invalid object passed to copy constructor. Expected NeomodelPoint or shapely Point, received {type(args[0])}")
-
+ raise TypeError(
+ f"Invalid object passed to copy constructor. Expected NeomodelPoint or shapely Point, received {type(args[0])}"
+ )
+
# Initialisation is either via x,y[,z] XOR longitude,latitude[,height]. Specifying both leads to an error.
if any(i is not None for i in [x, y, z]) and any(
i is not None for i in [latitude, longitude, height]
@@ -157,14 +166,14 @@ def __init__(self, *args, **kwargs):
"A Point can be defined either by x,y,z coordinates OR latitude,longitude,height but not "
"a combination of these terms"
)
-
+
# Specifying no initialisation argument at this point in the constructor is flagged as an error
if all(i is None for i in [x, y, z, latitude, longitude, height]):
raise ValueError(
"Invalid instantiation via no arguments. "
"A Point needs default values either in x,y,z or longitude, latitude, height coordinates"
)
-
+
# Geographical Point Initialisation
if latitude is not None and longitude is not None:
if height is not None:
@@ -176,7 +185,7 @@ def __init__(self, *args, **kwargs):
self._crs = "wgs-84"
_x = longitude
_y = latitude
-
+
# Geometrical Point Initialisation
if x is not None and y is not None:
if z is not None:
@@ -188,22 +197,26 @@ def __init__(self, *args, **kwargs):
self._crs = "cartesian"
_x = x
_y = y
-
+
if _z is None:
if "-3d" not in self._crs:
super().__init__((float(_x), float(_y)), **kwargs)
else:
- raise ValueError(f"Invalid vector dimensions(2) for given CRS({self._crs}).")
+ raise ValueError(
+ f"Invalid vector dimensions(2) for given CRS({self._crs})."
+ )
else:
if "-3d" in self._crs:
super().__init__((float(_x), float(_y), float(_z)), **kwargs)
else:
- raise ValueError(f"Invalid vector dimensions(3) for given CRS({self._crs}).")
-
+ raise ValueError(
+ f"Invalid vector dimensions(3) for given CRS({self._crs})."
+ )
+
@property
def crs(self):
return self._crs
-
+
@property
def x(self):
if not self._crs.startswith("cartesian"):
@@ -211,7 +224,7 @@ def x(self):
f'Invalid coordinate ("x") for points defined over {self.crs}'
)
return super().x
-
+
@property
def y(self):
if not self._crs.startswith("cartesian"):
@@ -219,7 +232,7 @@ def y(self):
f'Invalid coordinate ("y") for points defined over {self.crs}'
)
return super().y
-
+
@property
def z(self):
if self._crs != "cartesian-3d":
@@ -227,7 +240,7 @@ def z(self):
f'Invalid coordinate ("z") for points defined over {self.crs}'
)
return super().z
-
+
@property
def latitude(self):
if not self._crs.startswith("wgs-84"):
@@ -235,7 +248,7 @@ def latitude(self):
f'Invalid coordinate ("latitude") for points defined over {self.crs}'
)
return super().y
-
+
@property
def longitude(self):
if not self._crs.startswith("wgs-84"):
@@ -243,7 +256,7 @@ def longitude(self):
f'Invalid coordinate ("longitude") for points defined over {self.crs}'
)
return super().x
-
+
@property
def height(self):
if self._crs != "wgs-84-3d":
@@ -251,21 +264,22 @@ def height(self):
f'Invalid coordinate ("height") for points defined over {self.crs}'
)
return super().z
-
+
# The following operations are necessary here due to the way queries (and more importantly their parameters) get
# combined and evaluated in neomodel. Specifically, query expressions get duplicated with deep copies and any valid
# datatype values should also implement these operations.
def __copy__(self):
return NeomodelPoint(self)
-
+
def __deepcopy__(self, memo):
return NeomodelPoint(self)
else:
+
class NeomodelPoint:
"""
Abstracts the Point spatial data type of Neo4j.
-
+
Note:
At the time of writing, Neo4j supports 2 main variants of Point:
1. A generic point defined over a Cartesian plane
@@ -274,12 +288,12 @@ class NeomodelPoint:
* The minimum data to define a point is longitude, latitude [,Height] and the crs is then assumed
to be "wgs-84".
"""
-
+
# def __init__(self, *args, crs=None, x=None, y=None, z=None, latitude=None, longitude=None, height=None, **kwargs):
def __init__(self, *args, **kwargs):
"""
Creates a NeomodelPoint.
-
+
:param args: Positional arguments to emulate the behaviour of Shapely's Point (and specifically the copy
constructor)
:type args: list
@@ -300,7 +314,7 @@ def __init__(self, *args, **kwargs):
:param kwargs: Dictionary of keyword arguments
:type kwargs: dict
"""
-
+
# Python2.7 Workaround for the order that the arguments get passed to the functions
crs = kwargs.pop("crs", None)
x = kwargs.pop("x", None)
@@ -309,14 +323,16 @@ def __init__(self, *args, **kwargs):
longitude = kwargs.pop("longitude", None)
latitude = kwargs.pop("latitude", None)
height = kwargs.pop("height", None)
-
+
_x, _y, _z = None, None, None
-
+
# CRS validity check is common to both types of constructors that follow
if crs is not None and crs not in ACCEPTABLE_CRS:
- raise ValueError(f"Invalid CRS({crs}). Expected one of {','.join(ACCEPTABLE_CRS)}")
+ raise ValueError(
+ f"Invalid CRS({crs}). Expected one of {','.join(ACCEPTABLE_CRS)}"
+ )
self._crs = crs
-
+
# If positional arguments have been supplied, then this is a possible call to the copy constructor or
# initialisation by a coordinate iterable as per ShapelyPoint constructor.
if len(args) > 0:
@@ -324,7 +340,9 @@ def __init__(self, *args, **kwargs):
if isinstance(args[0], (tuple, list)):
# Check dimensionality of tuple
if len(args[0]) < 2 or len(args[0]) > 3:
- raise ValueError(f"Invalid vector dimensions. Expected 2 or 3, received {len(args[0])}")
+ raise ValueError(
+ f"Invalid vector dimensions. Expected 2 or 3, received {len(args[0])}"
+ )
x = args[0][0]
y = args[0][1]
if len(args[0]) == 3:
@@ -354,11 +372,15 @@ def __init__(self, *args, **kwargs):
if crs is None:
self._crs = "cartesian-3d"
else:
- raise ValueError(f"Invalid vector dimensions. Expected 2 or 3, received {len(args[0].coords[0])}")
+ raise ValueError(
+ f"Invalid vector dimensions. Expected 2 or 3, received {len(args[0].coords[0])}"
+ )
return
else:
- raise TypeError(f"Invalid object passed to copy constructor. Expected NeomodelPoint or shapely Point, received {type(args[0])}")
-
+ raise TypeError(
+ f"Invalid object passed to copy constructor. Expected NeomodelPoint or shapely Point, received {type(args[0])}"
+ )
+
# Initialisation is either via x,y[,z] XOR longitude,latitude[,height]. Specifying both leads to an error.
if any(i is not None for i in [x, y, z]) and any(
i is not None for i in [latitude, longitude, height]
@@ -368,14 +390,14 @@ def __init__(self, *args, **kwargs):
"A Point can be defined either by x,y,z coordinates OR latitude,longitude,height but not "
"a combination of these terms"
)
-
+
# Specifying no initialisation argument at this point in the constructor is flagged as an error
if all(i is None for i in [x, y, z, latitude, longitude, height]):
raise ValueError(
"Invalid instantiation via no arguments. "
"A Point needs default values either in x,y,z or longitude, latitude, height coordinates"
)
-
+
# Geographical Point Initialisation
if latitude is not None and longitude is not None:
if height is not None:
@@ -387,7 +409,7 @@ def __init__(self, *args, **kwargs):
self._crs = "wgs-84"
_x = longitude
_y = latitude
-
+
# Geometrical Point Initialisation
if x is not None and y is not None:
if z is not None:
@@ -399,23 +421,28 @@ def __init__(self, *args, **kwargs):
self._crs = "cartesian"
_x = x
_y = y
-
+
if _z is None:
if "-3d" not in self._crs:
self._shapely_point = ShapelyPoint((float(_x), float(_y)))
else:
- raise ValueError(f"Invalid vector dimensions(2) for given CRS({self._crs}).")
+ raise ValueError(
+ f"Invalid vector dimensions(2) for given CRS({self._crs})."
+ )
else:
if "-3d" in self._crs:
- self._shapely_point = ShapelyPoint((float(_x), float(_y), float(_z)))
+ self._shapely_point = ShapelyPoint(
+ (float(_x), float(_y), float(_z))
+ )
else:
- raise ValueError(f"Invalid vector dimensions(3) for given CRS({self._crs}).")
+ raise ValueError(
+ f"Invalid vector dimensions(3) for given CRS({self._crs})."
+ )
-
@property
def crs(self):
return self._crs
-
+
@property
def x(self):
if not self._crs.startswith("cartesian"):
@@ -423,7 +450,7 @@ def x(self):
f'Invalid coordinate ("x") for points defined over {self.crs}'
)
return self._shapely_point.x
-
+
@property
def y(self):
if not self._crs.startswith("cartesian"):
@@ -431,7 +458,7 @@ def y(self):
f'Invalid coordinate ("y") for points defined over {self.crs}'
)
return self._shapely_point.y
-
+
@property
def z(self):
if self._crs != "cartesian-3d":
@@ -439,7 +466,7 @@ def z(self):
f'Invalid coordinate ("z") for points defined over {self.crs}'
)
return self._shapely_point.z
-
+
@property
def latitude(self):
if not self._crs.startswith("wgs-84"):
@@ -447,7 +474,7 @@ def latitude(self):
f'Invalid coordinate ("latitude") for points defined over {self.crs}'
)
return self._shapely_point.y
-
+
@property
def longitude(self):
if not self._crs.startswith("wgs-84"):
@@ -455,7 +482,7 @@ def longitude(self):
f'Invalid coordinate ("longitude") for points defined over {self.crs}'
)
return self._shapely_point.x
-
+
@property
def height(self):
if self._crs != "wgs-84-3d":
@@ -463,13 +490,13 @@ def height(self):
f'Invalid coordinate ("height") for points defined over {self.crs}'
)
return self._shapely_point.z
-
+
# The following operations are necessary here due to the way queries (and more importantly their parameters) get
# combined and evaluated in neomodel. Specifically, query expressions get duplicated with deep copies and any valid
# datatype values should also implement these operations.
def __copy__(self):
return NeomodelPoint(self)
-
+
def __deepcopy__(self, memo):
return NeomodelPoint(self)
@@ -484,7 +511,9 @@ def __eq__(self, other):
Compare objects by value
"""
if not isinstance(other, (ShapelyPoint, NeomodelPoint)):
- raise ValueException(f"NeomodelPoint equality comparison expected NeomodelPoint or Shapely Point, received {type(other)}")
+ raise ValueException(
+ f"NeomodelPoint equality comparison expected NeomodelPoint or Shapely Point, received {type(other)}"
+ )
else:
if isinstance(other, ShapelyPoint):
return self.coords[0] == other.coords[0]
@@ -517,12 +546,19 @@ def __init__(self, *args, **kwargs):
crs = None
if crs is None or (crs not in ACCEPTABLE_CRS):
- raise ValueError(f"Invalid CRS({crs}). Point properties require CRS to be one of {','.join(ACCEPTABLE_CRS)}")
+ raise ValueError(
+ f"Invalid CRS({crs}). Point properties require CRS to be one of {','.join(ACCEPTABLE_CRS)}"
+ )
# If a default value is passed and it is not a callable, then make sure it is in the right type
- if "default" in kwargs and not hasattr(kwargs["default"], "__call__") and not isinstance(kwargs["default"], NeomodelPoint):
- raise TypeError(f"Invalid default value. Expected NeomodelPoint, received {type(kwargs['default'])}"
- )
+ if (
+ "default" in kwargs
+ and not hasattr(kwargs["default"], "__call__")
+ and not isinstance(kwargs["default"], NeomodelPoint)
+ ):
+ raise TypeError(
+ f"Invalid default value. Expected NeomodelPoint, received {type(kwargs['default'])}"
+ )
super().__init__(*args, **kwargs)
self._crs = crs
@@ -544,10 +580,14 @@ def inflate(self, value):
try:
value_point_crs = SRID_TO_CRS[value.srid]
except KeyError as e:
- raise ValueError(f"Invalid SRID to inflate. Expected one of {SRID_TO_CRS.keys()}, received {value.srid}") from e
+ raise ValueError(
+ f"Invalid SRID to inflate. Expected one of {SRID_TO_CRS.keys()}, received {value.srid}"
+ ) from e
if self._crs != value_point_crs:
- raise ValueError(f"Invalid CRS. Expected POINT defined over {self._crs}, received {value_point_crs}")
+ raise ValueError(
+ f"Invalid CRS. Expected POINT defined over {self._crs}, received {value_point_crs}"
+ )
# cartesian
if value.srid == 7203:
return NeomodelPoint(x=value.x, y=value.y)
@@ -581,7 +621,9 @@ def deflate(self, value):
)
if value.crs != self._crs:
- raise ValueError(f"Invalid CRS. Expected NeomodelPoint defined over {self._crs}, received NeomodelPoint defined over {value.crs}")
+ raise ValueError(
+ f"Invalid CRS. Expected NeomodelPoint defined over {self._crs}, received NeomodelPoint defined over {value.crs}"
+ )
if value.crs == "cartesian-3d":
return neo4j.spatial.CartesianPoint((value.x, value.y, value.z))
diff --git a/neomodel/contrib/semi_structured.py b/neomodel/contrib/sync_/semi_structured.py
similarity index 94%
rename from neomodel/contrib/semi_structured.py
rename to neomodel/contrib/sync_/semi_structured.py
index 9c719983..86a5a140 100644
--- a/neomodel/contrib/semi_structured.py
+++ b/neomodel/contrib/sync_/semi_structured.py
@@ -1,5 +1,5 @@
-from neomodel.core import StructuredNode
from neomodel.exceptions import DeflateConflict, InflateConflict
+from neomodel.sync_.core import StructuredNode
from neomodel.util import _get_node_properties
@@ -57,7 +57,7 @@ def inflate(cls, node):
def deflate(cls, node_props, obj=None, skip_empty=False):
deflated = super().deflate(node_props, obj, skip_empty=skip_empty)
for key in [k for k in node_props if k not in deflated]:
- if hasattr(cls, key) and (getattr(cls,key).required or not skip_empty):
+ if hasattr(cls, key) and (getattr(cls, key).required or not skip_empty):
raise DeflateConflict(cls, key, deflated[key], obj.element_id)
node_props.update(deflated)
diff --git a/neomodel/core.py b/neomodel/core.py
deleted file mode 100644
index 415a97af..00000000
--- a/neomodel/core.py
+++ /dev/null
@@ -1,784 +0,0 @@
-import sys
-import warnings
-from itertools import combinations
-
-from neo4j.exceptions import ClientError
-
-from neomodel import config
-from neomodel.exceptions import (
- DoesNotExist,
- FeatureNotSupported,
- NodeClassAlreadyDefined,
-)
-from neomodel.hooks import hooks
-from neomodel.properties import Property, PropertyManager
-from neomodel.util import Database, _get_node_properties, _UnsavedNode, classproperty
-
-db = Database()
-
-RULE_ALREADY_EXISTS = "Neo.ClientError.Schema.EquivalentSchemaRuleAlreadyExists"
-INDEX_ALREADY_EXISTS = "Neo.ClientError.Schema.IndexAlreadyExists"
-CONSTRAINT_ALREADY_EXISTS = "Neo.ClientError.Schema.ConstraintAlreadyExists"
-STREAMING_WARNING = "streaming is not supported by bolt, please remove the kwarg"
-
-
-def drop_constraints(quiet=True, stdout=None):
- """
- Discover and drop all constraints.
-
- :type: bool
- :return: None
- """
- if not stdout or stdout is None:
- stdout = sys.stdout
-
- results, meta = db.cypher_query("SHOW CONSTRAINTS")
-
- results_as_dict = [dict(zip(meta, row)) for row in results]
- for constraint in results_as_dict:
- db.cypher_query("DROP CONSTRAINT " + constraint["name"])
- if not quiet:
- stdout.write(
- (
- " - Dropping unique constraint and index"
- f" on label {constraint['labelsOrTypes'][0]}"
- f" with property {constraint['properties'][0]}.\n"
- )
- )
- if not quiet:
- stdout.write("\n")
-
-
-def drop_indexes(quiet=True, stdout=None):
- """
- Discover and drop all indexes, except the automatically created token lookup indexes.
-
- :type: bool
- :return: None
- """
- if not stdout or stdout is None:
- stdout = sys.stdout
-
- indexes = db.list_indexes(exclude_token_lookup=True)
- for index in indexes:
- db.cypher_query("DROP INDEX " + index["name"])
- if not quiet:
- stdout.write(
- f' - Dropping index on labels {",".join(index["labelsOrTypes"])} with properties {",".join(index["properties"])}.\n'
- )
- if not quiet:
- stdout.write("\n")
-
-
-def remove_all_labels(stdout=None):
- """
- Calls functions for dropping constraints and indexes.
-
- :param stdout: output stream
- :return: None
- """
-
- if not stdout:
- stdout = sys.stdout
-
- stdout.write("Dropping constraints...\n")
- drop_constraints(quiet=False, stdout=stdout)
-
- stdout.write("Dropping indexes...\n")
- drop_indexes(quiet=False, stdout=stdout)
-
-
-def install_labels(cls, quiet=True, stdout=None):
- """
- Setup labels with indexes and constraints for a given class
-
- :param cls: StructuredNode class
- :type: class
- :param quiet: (default true) enable standard output
- :param stdout: stdout stream
- :type: bool
- :return: None
- """
- if not stdout or stdout is None:
- stdout = sys.stdout
-
- if not hasattr(cls, "__label__"):
- if not quiet:
- stdout.write(
- f" ! Skipping class {cls.__module__}.{cls.__name__} is abstract\n"
- )
- return
-
- for name, property in cls.defined_properties(aliases=False, rels=False).items():
- _install_node(cls, name, property, quiet, stdout)
-
- for _, relationship in cls.defined_properties(
- aliases=False, rels=True, properties=False
- ).items():
- _install_relationship(cls, relationship, quiet, stdout)
-
-
-def _create_node_index(label: str, property_name: str, stdout):
- try:
- db.cypher_query(
- f"CREATE INDEX index_{label}_{property_name} FOR (n:{label}) ON (n.{property_name}); "
- )
- except ClientError as e:
- if e.code in (
- RULE_ALREADY_EXISTS,
- INDEX_ALREADY_EXISTS,
- ):
- stdout.write(f"{str(e)}\n")
- else:
- raise
-
-
-def _create_node_constraint(label: str, property_name: str, stdout):
- try:
- db.cypher_query(
- f"""CREATE CONSTRAINT constraint_unique_{label}_{property_name}
- FOR (n:{label}) REQUIRE n.{property_name} IS UNIQUE"""
- )
- except ClientError as e:
- if e.code in (
- RULE_ALREADY_EXISTS,
- CONSTRAINT_ALREADY_EXISTS,
- ):
- stdout.write(f"{str(e)}\n")
- else:
- raise
-
-
-def _create_relationship_index(relationship_type: str, property_name: str, stdout):
- try:
- db.cypher_query(
- f"CREATE INDEX index_{relationship_type}_{property_name} FOR ()-[r:{relationship_type}]-() ON (r.{property_name}); "
- )
- except ClientError as e:
- if e.code in (
- RULE_ALREADY_EXISTS,
- INDEX_ALREADY_EXISTS,
- ):
- stdout.write(f"{str(e)}\n")
- else:
- raise
-
-
-def _create_relationship_constraint(relationship_type: str, property_name: str, stdout):
- if db.version_is_higher_than("5.7"):
- try:
- db.cypher_query(
- f"""CREATE CONSTRAINT constraint_unique_{relationship_type}_{property_name}
- FOR ()-[r:{relationship_type}]-() REQUIRE r.{property_name} IS UNIQUE"""
- )
- except ClientError as e:
- if e.code in (
- RULE_ALREADY_EXISTS,
- CONSTRAINT_ALREADY_EXISTS,
- ):
- stdout.write(f"{str(e)}\n")
- else:
- raise
- else:
- raise FeatureNotSupported(
- f"Unique indexes on relationships are not supported in Neo4j version {db.database_version}. Please upgrade to Neo4j 5.7 or higher."
- )
-
-
-def _install_node(cls, name, property, quiet, stdout):
- # Create indexes and constraints for node property
- db_property = property.db_property or name
- if property.index:
- if not quiet:
- stdout.write(
- f" + Creating node index {name} on label {cls.__label__} for class {cls.__module__}.{cls.__name__}\n"
- )
- _create_node_index(
- label=cls.__label__, property_name=db_property, stdout=stdout
- )
-
- elif property.unique_index:
- if not quiet:
- stdout.write(
- f" + Creating node unique constraint for {name} on label {cls.__label__} for class {cls.__module__}.{cls.__name__}\n"
- )
- _create_node_constraint(
- label=cls.__label__, property_name=db_property, stdout=stdout
- )
-
-
-def _install_relationship(cls, relationship, quiet, stdout):
- # Create indexes and constraints for relationship property
- relationship_cls = relationship.definition["model"]
- if relationship_cls is not None:
- relationship_type = relationship.definition["relation_type"]
- for prop_name, property in relationship_cls.defined_properties(
- aliases=False, rels=False
- ).items():
- db_property = property.db_property or prop_name
- if property.index:
- if not quiet:
- stdout.write(
- f" + Creating relationship index {prop_name} on relationship type {relationship_type} for relationship model {cls.__module__}.{relationship_cls.__name__}\n"
- )
- _create_relationship_index(
- relationship_type=relationship_type,
- property_name=db_property,
- stdout=stdout,
- )
- elif property.unique_index:
- if not quiet:
- stdout.write(
- f" + Creating relationship unique constraint for {prop_name} on relationship type {relationship_type} for relationship model {cls.__module__}.{relationship_cls.__name__}\n"
- )
- _create_relationship_constraint(
- relationship_type=relationship_type,
- property_name=db_property,
- stdout=stdout,
- )
-
-
-def install_all_labels(stdout=None):
- """
- Discover all subclasses of StructuredNode in your application and execute install_labels on each.
- Note: code must be loaded (imported) in order for a class to be discovered.
-
- :param stdout: output stream
- :return: None
- """
-
- if not stdout or stdout is None:
- stdout = sys.stdout
-
- def subsub(cls): # recursively return all subclasses
- subclasses = cls.__subclasses__()
- if not subclasses: # base case: no more subclasses
- return []
- return subclasses + [g for s in cls.__subclasses__() for g in subsub(s)]
-
- stdout.write("Setting up indexes and constraints...\n\n")
-
- i = 0
- for cls in subsub(StructuredNode):
- stdout.write(f"Found {cls.__module__}.{cls.__name__}\n")
- install_labels(cls, quiet=False, stdout=stdout)
- i += 1
-
- if i:
- stdout.write("\n")
-
- stdout.write(f"Finished {i} classes.\n")
-
-
-class NodeMeta(type):
- def __new__(mcs, name, bases, namespace):
- namespace["DoesNotExist"] = type(name + "DoesNotExist", (DoesNotExist,), {})
- cls = super().__new__(mcs, name, bases, namespace)
- cls.DoesNotExist._model_class = cls
-
- if hasattr(cls, "__abstract_node__"):
- delattr(cls, "__abstract_node__")
- else:
- if "deleted" in namespace:
- raise ValueError(
- "Property name 'deleted' is not allowed as it conflicts with neomodel internals."
- )
- elif "id" in namespace:
- raise ValueError(
- """
- Property name 'id' is not allowed as it conflicts with neomodel internals.
- Consider using 'uid' or 'identifier' as id is also a Neo4j internal.
- """
- )
- elif "element_id" in namespace:
- raise ValueError(
- """
- Property name 'element_id' is not allowed as it conflicts with neomodel internals.
- Consider using 'uid' or 'identifier' as element_id is also a Neo4j internal.
- """
- )
- for key, value in (
- (x, y) for x, y in namespace.items() if isinstance(y, Property)
- ):
- value.name, value.owner = key, cls
- if hasattr(value, "setup") and callable(value.setup):
- value.setup()
-
- # cache various groups of properies
- cls.__required_properties__ = tuple(
- name
- for name, property in cls.defined_properties(
- aliases=False, rels=False
- ).items()
- if property.required or property.unique_index
- )
- cls.__all_properties__ = tuple(
- cls.defined_properties(aliases=False, rels=False).items()
- )
- cls.__all_aliases__ = tuple(
- cls.defined_properties(properties=False, rels=False).items()
- )
- cls.__all_relationships__ = tuple(
- cls.defined_properties(aliases=False, properties=False).items()
- )
-
- cls.__label__ = namespace.get("__label__", name)
- cls.__optional_labels__ = namespace.get("__optional_labels__", [])
-
- if config.AUTO_INSTALL_LABELS:
- install_labels(cls, quiet=False)
-
- build_class_registry(cls)
-
- return cls
-
-
-def build_class_registry(cls):
- base_label_set = frozenset(cls.inherited_labels())
- optional_label_set = set(cls.inherited_optional_labels())
-
- # Construct all possible combinations of labels + optional labels
- possible_label_combinations = [
- frozenset(set(x).union(base_label_set))
- for i in range(1, len(optional_label_set) + 1)
- for x in combinations(optional_label_set, i)
- ]
- possible_label_combinations.append(base_label_set)
-
- for label_set in possible_label_combinations:
- if label_set not in db._NODE_CLASS_REGISTRY:
- db._NODE_CLASS_REGISTRY[label_set] = cls
- else:
- raise NodeClassAlreadyDefined(cls, db._NODE_CLASS_REGISTRY)
-
-
-NodeBase = NodeMeta("NodeBase", (PropertyManager,), {"__abstract_node__": True})
-
-
-class StructuredNode(NodeBase):
- """
- Base class for all node definitions to inherit from.
-
- If you want to create your own abstract classes set:
- __abstract_node__ = True
- """
-
- # static properties
-
- __abstract_node__ = True
-
- # magic methods
-
- def __init__(self, *args, **kwargs):
- if "deleted" in kwargs:
- raise ValueError("deleted property is reserved for neomodel")
-
- for key, val in self.__all_relationships__:
- self.__dict__[key] = val.build_manager(self, key)
-
- super().__init__(*args, **kwargs)
-
- def __eq__(self, other):
- if not isinstance(other, (StructuredNode,)):
- return False
- if hasattr(self, "element_id") and hasattr(other, "element_id"):
- return self.element_id == other.element_id
- return False
-
- def __ne__(self, other):
- return not self.__eq__(other)
-
- def __repr__(self):
- return f"<{self.__class__.__name__}: {self}>"
-
- def __str__(self):
- return repr(self.__properties__)
-
- # dynamic properties
-
- @classproperty
- def nodes(cls):
- """
- Returns a NodeSet object representing all nodes of the classes label
- :return: NodeSet
- :rtype: NodeSet
- """
- from .match import NodeSet
-
- return NodeSet(cls)
-
- @property
- def element_id(self):
- if hasattr(self, "element_id_property"):
- return (
- int(self.element_id_property)
- if db.database_version.startswith("4")
- else self.element_id_property
- )
- return None
-
- # Version 4.4 support - id is deprecated in version 5.x
- @property
- def id(self):
- try:
- return int(self.element_id_property)
- except (TypeError, ValueError):
- raise ValueError(
- "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()."
- )
-
- # methods
-
- @classmethod
- def _build_merge_query(
- cls, merge_params, update_existing=False, lazy=False, relationship=None
- ):
- """
- Get a tuple of a CYPHER query and a params dict for the specified MERGE query.
-
- :param merge_params: The target node match parameters, each node must have a "create" key and optional "update".
- :type merge_params: list of dict
- :param update_existing: True to update properties of existing nodes, default False to keep existing values.
- :type update_existing: bool
- :rtype: tuple
- """
- query_params = dict(merge_params=merge_params)
- n_merge_labels = ":".join(cls.inherited_labels())
- n_merge_prm = ", ".join(
- (
- f"{getattr(cls, p).db_property or p}: params.create.{getattr(cls, p).db_property or p}"
- for p in cls.__required_properties__
- )
- )
- n_merge = f"n:{n_merge_labels} {{{n_merge_prm}}}"
- if relationship is None:
- # create "simple" unwind query
- query = f"UNWIND $merge_params as params\n MERGE ({n_merge})\n "
- else:
- # validate relationship
- if not isinstance(relationship.source, StructuredNode):
- raise ValueError(
- f"relationship source [{repr(relationship.source)}] is not a StructuredNode"
- )
- relation_type = relationship.definition.get("relation_type")
- if not relation_type:
- raise ValueError(
- "No relation_type is specified on provided relationship"
- )
-
- from .match import _rel_helper
-
- query_params["source_id"] = relationship.source.element_id
- query = f"MATCH (source:{relationship.source.__label__}) WHERE {db.get_id_method()}(source) = $source_id\n "
- query += "WITH source\n UNWIND $merge_params as params \n "
- query += "MERGE "
- query += _rel_helper(
- lhs="source",
- rhs=n_merge,
- ident=None,
- relation_type=relation_type,
- direction=relationship.definition["direction"],
- )
-
- query += "ON CREATE SET n = params.create\n "
- # if update_existing, write properties on match as well
- if update_existing is True:
- query += "ON MATCH SET n += params.update\n"
-
- # close query
- if lazy:
- query += f"RETURN {db.get_id_method()}(n)"
- else:
- query += "RETURN n"
-
- return query, query_params
-
- @classmethod
- def create(cls, *props, **kwargs):
- """
- Call to CREATE with parameters map. A new instance will be created and saved.
-
- :param props: dict of properties to create the nodes.
- :type props: tuple
- :param lazy: False by default, specify True to get nodes with id only without the parameters.
- :type: bool
- :rtype: list
- """
-
- if "streaming" in kwargs:
- warnings.warn(
- STREAMING_WARNING,
- category=DeprecationWarning,
- stacklevel=1,
- )
-
- lazy = kwargs.get("lazy", False)
- # create mapped query
- query = f"CREATE (n:{':'.join(cls.inherited_labels())} $create_params)"
-
- # close query
- if lazy:
- query += f" RETURN {db.get_id_method()}(n)"
- else:
- query += " RETURN n"
-
- results = []
- for item in [
- cls.deflate(p, obj=_UnsavedNode(), skip_empty=True) for p in props
- ]:
- node, _ = db.cypher_query(query, {"create_params": item})
- results.extend(node[0])
-
- nodes = [cls.inflate(node) for node in results]
-
- if not lazy and hasattr(cls, "post_create"):
- for node in nodes:
- node.post_create()
-
- return nodes
-
- @classmethod
- def create_or_update(cls, *props, **kwargs):
- """
- Call to MERGE with parameters map. A new instance will be created and saved if does not already exists,
- this is an atomic operation. If an instance already exists all optional properties specified will be updated.
-
- Note that the post_create hook isn't called after create_or_update
-
- :param props: List of dict arguments to get or create the entities with.
- :type props: tuple
- :param relationship: Optional, relationship to get/create on when new entity is created.
- :param lazy: False by default, specify True to get nodes with id only without the parameters.
- :rtype: list
- """
- lazy = kwargs.get("lazy", False)
- relationship = kwargs.get("relationship")
-
- # build merge query, make sure to update only explicitly specified properties
- create_or_update_params = []
- for specified, deflated in [
- (p, cls.deflate(p, skip_empty=True)) for p in props
- ]:
- create_or_update_params.append(
- {
- "create": deflated,
- "update": dict(
- (k, v) for k, v in deflated.items() if k in specified
- ),
- }
- )
- query, params = cls._build_merge_query(
- create_or_update_params,
- update_existing=True,
- relationship=relationship,
- lazy=lazy,
- )
-
- if "streaming" in kwargs:
- warnings.warn(
- STREAMING_WARNING,
- category=DeprecationWarning,
- stacklevel=1,
- )
-
- # fetch and build instance for each result
- results = db.cypher_query(query, params)
- return [cls.inflate(r[0]) for r in results[0]]
-
- def cypher(self, query, params=None):
- """
- Execute a cypher query with the param 'self' pre-populated with the nodes neo4j id.
-
- :param query: cypher query string
- :type: string
- :param params: query parameters
- :type: dict
- :return: list containing query results
- :rtype: list
- """
- self._pre_action_check("cypher")
- params = params or {}
- params.update({"self": self.element_id})
- return db.cypher_query(query, params)
-
- @hooks
- def delete(self):
- """
- Delete a node and its relationships
-
- :return: True
- """
- self._pre_action_check("delete")
- self.cypher(
- f"MATCH (self) WHERE {db.get_id_method()}(self)=$self DETACH DELETE self"
- )
- delattr(self, "element_id_property")
- self.deleted = True
- return True
-
- @classmethod
- def get_or_create(cls, *props, **kwargs):
- """
- Call to MERGE with parameters map. A new instance will be created and saved if does not already exist,
- this is an atomic operation.
- Parameters must contain all required properties, any non required properties with defaults will be generated.
-
- Note that the post_create hook isn't called after get_or_create
-
- :param props: Arguments to get_or_create as tuple of dict with property names and values to get or create
- the entities with.
- :type props: tuple
- :param relationship: Optional, relationship to get/create on when new entity is created.
- :param lazy: False by default, specify True to get nodes with id only without the parameters.
- :rtype: list
- """
- lazy = kwargs.get("lazy", False)
- relationship = kwargs.get("relationship")
-
- # build merge query
- get_or_create_params = [
- {"create": cls.deflate(p, skip_empty=True)} for p in props
- ]
- query, params = cls._build_merge_query(
- get_or_create_params, relationship=relationship, lazy=lazy
- )
-
- if "streaming" in kwargs:
- warnings.warn(
- STREAMING_WARNING,
- category=DeprecationWarning,
- stacklevel=1,
- )
-
- # fetch and build instance for each result
- results = db.cypher_query(query, params)
- return [cls.inflate(r[0]) for r in results[0]]
-
- @classmethod
- def inflate(cls, node):
- """
- Inflate a raw neo4j_driver node to a neomodel node
- :param node:
- :return: node object
- """
- # support lazy loading
- if isinstance(node, str) or isinstance(node, int):
- snode = cls()
- snode.element_id_property = node
- else:
- node_properties = _get_node_properties(node)
- props = {}
- for key, prop in cls.__all_properties__:
- # map property name from database to object property
- db_property = prop.db_property or key
-
- if db_property in node_properties:
- props[key] = prop.inflate(node_properties[db_property], node)
- elif prop.has_default:
- props[key] = prop.default_value()
- else:
- props[key] = None
-
- snode = cls(**props)
- snode.element_id_property = node.element_id
-
- return snode
-
- @classmethod
- def inherited_labels(cls):
- """
- Return list of labels from nodes class hierarchy.
-
- :return: list
- """
- return [
- scls.__label__
- for scls in cls.mro()
- if hasattr(scls, "__label__") and not hasattr(scls, "__abstract_node__")
- ]
-
- @classmethod
- def inherited_optional_labels(cls):
- """
- Return list of optional labels from nodes class hierarchy.
-
- :return: list
- :rtype: list
- """
- return [
- label
- for scls in cls.mro()
- for label in getattr(scls, "__optional_labels__", [])
- if not hasattr(scls, "__abstract_node__")
- ]
-
- def labels(self):
- """
- Returns list of labels tied to the node from neo4j.
-
- :return: list of labels
- :rtype: list
- """
- self._pre_action_check("labels")
- return self.cypher(
- f"MATCH (n) WHERE {db.get_id_method()}(n)=$self " "RETURN labels(n)"
- )[0][0][0]
-
- def _pre_action_check(self, action):
- if hasattr(self, "deleted") and self.deleted:
- raise ValueError(
- f"{self.__class__.__name__}.{action}() attempted on deleted node"
- )
- if not hasattr(self, "element_id"):
- raise ValueError(
- f"{self.__class__.__name__}.{action}() attempted on unsaved node"
- )
-
- def refresh(self):
- """
- Reload the node from neo4j
- """
- self._pre_action_check("refresh")
- if hasattr(self, "element_id"):
- request = self.cypher(
- f"MATCH (n) WHERE {db.get_id_method()}(n)=$self RETURN n"
- )[0]
- if not request or not request[0]:
- raise self.__class__.DoesNotExist("Can't refresh non existent node")
- node = self.inflate(request[0][0])
- for key, val in node.__properties__.items():
- setattr(self, key, val)
- else:
- raise ValueError("Can't refresh unsaved node")
-
- @hooks
- def save(self):
- """
- Save the node to neo4j or raise an exception
-
- :return: the node instance
- """
-
- # create or update instance node
- if hasattr(self, "element_id_property"):
- # update
- params = self.deflate(self.__properties__, self)
- query = f"MATCH (n) WHERE {db.get_id_method()}(n)=$self\n"
-
- if params:
- query += "SET "
- query += ",\n".join([f"n.{key} = ${key}" for key in params])
- query += "\n"
- if self.inherited_labels():
- query += "\n".join(
- [f"SET n:`{label}`" for label in self.inherited_labels()]
- )
- self.cypher(query, params)
- elif hasattr(self, "deleted") and self.deleted:
- raise ValueError(
- f"{self.__class__.__name__}.save() attempted on deleted node"
- )
- else: # create
- created_node = self.create(self.__properties__)[0]
- self.element_id_property = created_node.element_id
- return self
diff --git a/neomodel/integration/numpy.py b/neomodel/integration/numpy.py
index a04508c4..14bae3df 100644
--- a/neomodel/integration/numpy.py
+++ b/neomodel/integration/numpy.py
@@ -7,7 +7,7 @@
Example:
- >>> from neomodel import db
+ >>> from neomodel.async_ import db
>>> from neomodel.integration.numpy import to_nparray
>>> db.set_connection('bolt://neo4j:secret@localhost:7687')
>>> df = to_nparray(db.cypher_query("MATCH (u:User) RETURN u.email AS email, u.name AS name"))
diff --git a/neomodel/integration/pandas.py b/neomodel/integration/pandas.py
index 1ad19871..2f809ade 100644
--- a/neomodel/integration/pandas.py
+++ b/neomodel/integration/pandas.py
@@ -7,7 +7,7 @@
Example:
- >>> from neomodel import db
+ >>> from neomodel.async_ import db
>>> from neomodel.integration.pandas import to_dataframe
>>> db.set_connection('bolt://neo4j:secret@localhost:7687')
>>> df = to_dataframe(db.cypher_query("MATCH (u:User) RETURN u.email AS email, u.name AS name"))
diff --git a/neomodel/match_q.py b/neomodel/match_q.py
index 7e76a23b..4e45588c 100644
--- a/neomodel/match_q.py
+++ b/neomodel/match_q.py
@@ -69,7 +69,11 @@ def _new_instance(cls, children=None, connector=None, negated=False):
return obj
def __str__(self):
- return f"(NOT ({self.connector}: {', '.join(str(c) for c in self.children)}))" if self.negated else f"({self.connector}: {', '.join(str(c) for c in self.children)})"
+ return (
+ f"(NOT ({self.connector}: {', '.join(str(c) for c in self.children)}))"
+ if self.negated
+ else f"({self.connector}: {', '.join(str(c) for c in self.children)})"
+ )
def __repr__(self):
return f"<{self.__class__.__name__}: {self}>"
diff --git a/neomodel/properties.py b/neomodel/properties.py
index e28b4ead..cac92dfe 100644
--- a/neomodel/properties.py
+++ b/neomodel/properties.py
@@ -2,7 +2,6 @@
import json
import re
import sys
-import types
import uuid
from datetime import date, datetime
@@ -10,117 +9,12 @@
import pytz
from neomodel import config
-from neomodel.exceptions import DeflateError, InflateError, RequiredProperty
+from neomodel.exceptions import DeflateError, InflateError
if sys.version_info >= (3, 0):
Unicode = str
-def display_for(key):
- def display_choice(self):
- return getattr(self.__class__, key).choices[getattr(self, key)]
-
- return display_choice
-
-
-class PropertyManager:
- """
- Common methods for handling properties on node and relationship objects.
- """
-
- def __init__(self, **kwargs):
- properties = getattr(self, "__all_properties__", None)
- if properties is None:
- properties = self.defined_properties(rels=False, aliases=False).items()
- for name, property in properties:
- if kwargs.get(name) is None:
- if getattr(property, "has_default", False):
- setattr(self, name, property.default_value())
- else:
- setattr(self, name, None)
- else:
- setattr(self, name, kwargs[name])
-
- if getattr(property, "choices", None):
- setattr(
- self,
- f"get_{name}_display",
- types.MethodType(display_for(name), self),
- )
-
- if name in kwargs:
- del kwargs[name]
-
- aliases = getattr(self, "__all_aliases__", None)
- if aliases is None:
- aliases = self.defined_properties(
- aliases=True, rels=False, properties=False
- ).items()
- for name, property in aliases:
- if name in kwargs:
- setattr(self, name, kwargs[name])
- del kwargs[name]
-
- # undefined properties (for magic @prop.setters etc)
- for name, property in kwargs.items():
- setattr(self, name, property)
-
- @property
- def __properties__(self):
- from .relationship_manager import RelationshipManager
-
- return dict(
- (name, value)
- for name, value in vars(self).items()
- if not name.startswith("_")
- and not callable(value)
- and not isinstance(
- value,
- (
- RelationshipManager,
- AliasProperty,
- ),
- )
- )
-
- @classmethod
- def deflate(cls, properties, obj=None, skip_empty=False):
- # deflate dict ready to be stored
- deflated = {}
- for name, property in cls.defined_properties(aliases=False, rels=False).items():
- db_property = property.db_property or name
- if properties.get(name) is not None:
- deflated[db_property] = property.deflate(properties[name], obj)
- elif property.has_default:
- deflated[db_property] = property.deflate(property.default_value(), obj)
- elif property.required:
- raise RequiredProperty(name, cls)
- elif not skip_empty:
- deflated[db_property] = None
- return deflated
-
- @classmethod
- def defined_properties(cls, aliases=True, properties=True, rels=True):
- from .relationship_manager import RelationshipDefinition
-
- props = {}
- for baseclass in reversed(cls.__mro__):
- props.update(
- dict(
- (name, property)
- for name, property in vars(baseclass).items()
- if (aliases and isinstance(property, AliasProperty))
- or (
- properties
- and isinstance(property, Property)
- and not isinstance(property, AliasProperty)
- )
- or (rels and isinstance(property, RelationshipDefinition))
- )
- )
- return props
-
-
def validator(fn):
fn_name = fn.func_name if hasattr(fn, "func_name") else fn.__name__
if fn_name == "inflate":
@@ -467,7 +361,7 @@ def deflate(self, value):
class DateTimeFormatProperty(Property):
"""
- Store a datetime by custome format
+ Store a datetime by custom format
:param default_now: If ``True``, the creation time (Local) will be used as default.
Defaults to ``False``.
:param format: Date format string, default is %Y-%m-%d
diff --git a/neomodel/scripts/neomodel_inspect_database.py b/neomodel/scripts/neomodel_inspect_database.py
index a8cf4c0e..3147ebdf 100644
--- a/neomodel/scripts/neomodel_inspect_database.py
+++ b/neomodel/scripts/neomodel_inspect_database.py
@@ -1,7 +1,7 @@
"""
.. _neomodel_inspect_database:
-``_neomodel_inspect_database``
+``neomodel_inspect_database``
---------------------------
::
@@ -17,6 +17,8 @@
If a file is specified, the tool will write the class definitions to that file.
If no file is specified, the tool will print the class definitions to stdout.
+
+ Note: this script only has a synchronous mode.
options:
-h, --help show this help message and exit
@@ -33,7 +35,7 @@
import textwrap
from os import environ
-from neomodel import db
+from neomodel.sync_.core import db
IMPORTS = []
diff --git a/neomodel/scripts/neomodel_install_labels.py b/neomodel/scripts/neomodel_install_labels.py
index 8bd5119f..8aa7a73b 100755
--- a/neomodel/scripts/neomodel_install_labels.py
+++ b/neomodel/scripts/neomodel_install_labels.py
@@ -14,6 +14,8 @@
If a connection URL is not specified, the tool will look up the environment
variable NEO4J_BOLT_URL. If that environment variable is not set, the tool
will attempt to connect to the default URL bolt://neo4j:neo4j@localhost:7687
+
+ Note: this script only has a synchronous mode.
positional arguments:
@@ -32,7 +34,7 @@
from importlib import import_module
from os import environ, path
-from .. import db, install_all_labels
+from neomodel.sync_.core import db
def load_python_module_or_file(name):
@@ -111,7 +113,7 @@ def main():
print(f"Connecting to {bolt_url}")
db.set_connection(url=bolt_url)
- install_all_labels()
+ db.install_all_labels()
if __name__ == "__main__":
diff --git a/neomodel/scripts/neomodel_remove_labels.py b/neomodel/scripts/neomodel_remove_labels.py
index 1ad6cc34..79e79390 100755
--- a/neomodel/scripts/neomodel_remove_labels.py
+++ b/neomodel/scripts/neomodel_remove_labels.py
@@ -14,6 +14,8 @@
If a connection URL is not specified, the tool will look up the environment
variable NEO4J_BOLT_URL. If that environment variable is not set, the tool
will attempt to connect to the default URL bolt://neo4j:neo4j@localhost:7687
+
+ Note: this script only has a synchronous mode.
options:
-h, --help show this help message and exit
@@ -27,7 +29,7 @@
from argparse import ArgumentParser, RawDescriptionHelpFormatter
from os import environ
-from .. import db, remove_all_labels
+from neomodel.sync_.core import db
def main():
@@ -63,7 +65,7 @@ def main():
print(f"Connecting to {bolt_url}")
db.set_connection(url=bolt_url)
- remove_all_labels()
+ db.remove_all_labels()
if __name__ == "__main__":
diff --git a/neomodel/sync_/__init__.py b/neomodel/sync_/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/neomodel/cardinality.py b/neomodel/sync_/cardinality.py
similarity index 95%
rename from neomodel/cardinality.py
rename to neomodel/sync_/cardinality.py
index 099bf578..716d173f 100644
--- a/neomodel/cardinality.py
+++ b/neomodel/sync_/cardinality.py
@@ -1,5 +1,5 @@
from neomodel.exceptions import AttemptedCardinalityViolation, CardinalityViolation
-from neomodel.relationship_manager import ( # pylint:disable=unused-import
+from neomodel.sync_.relationship_manager import ( # pylint:disable=unused-import
RelationshipManager,
ZeroOrMore,
)
@@ -37,7 +37,7 @@ def connect(self, node, properties=None):
:type: dict
:return: True / rel instance
"""
- if len(self):
+ if super().__len__():
raise AttemptedCardinalityViolation(
f"Node already has {self} can't connect more"
)
@@ -130,6 +130,6 @@ def connect(self, node, properties=None):
"""
if not hasattr(self.source, "element_id") or self.source.element_id is None:
raise ValueError("Node has not been saved cannot connect!")
- if len(self):
+ if super().__len__():
raise AttemptedCardinalityViolation("Node already has one relationship")
return super().connect(node, properties)
diff --git a/neomodel/sync_/core.py b/neomodel/sync_/core.py
new file mode 100644
index 00000000..8778adb0
--- /dev/null
+++ b/neomodel/sync_/core.py
@@ -0,0 +1,1519 @@
+import logging
+import os
+import sys
+import time
+import warnings
+from asyncio import iscoroutinefunction
+from itertools import combinations
+from threading import local
+from typing import Optional, Sequence
+from urllib.parse import quote, unquote, urlparse
+
+from neo4j import (
+ DEFAULT_DATABASE,
+ Driver,
+ GraphDatabase,
+ Result,
+ Session,
+ Transaction,
+ basic_auth,
+)
+from neo4j.api import Bookmarks
+from neo4j.exceptions import ClientError, ServiceUnavailable, SessionExpired
+from neo4j.graph import Node, Path, Relationship
+
+from neomodel import config
+from neomodel._async_compat.util import Util
+from neomodel.exceptions import (
+ ConstraintValidationFailed,
+ DoesNotExist,
+ FeatureNotSupported,
+ NodeClassAlreadyDefined,
+ NodeClassNotDefined,
+ RelationshipClassNotDefined,
+ UniqueProperty,
+)
+from neomodel.hooks import hooks
+from neomodel.properties import Property
+from neomodel.sync_.property_manager import PropertyManager
+from neomodel.util import (
+ _get_node_properties,
+ _UnsavedNode,
+ classproperty,
+ deprecated,
+ version_tag_to_integer,
+)
+
+logger = logging.getLogger(__name__)
+
+RULE_ALREADY_EXISTS = "Neo.ClientError.Schema.EquivalentSchemaRuleAlreadyExists"
+INDEX_ALREADY_EXISTS = "Neo.ClientError.Schema.IndexAlreadyExists"
+CONSTRAINT_ALREADY_EXISTS = "Neo.ClientError.Schema.ConstraintAlreadyExists"
+STREAMING_WARNING = "streaming is not supported by bolt, please remove the kwarg"
+NOT_COROUTINE_ERROR = "The decorated function must be a coroutine"
+
+
+# make sure the connection url has been set prior to executing the wrapped function
+def ensure_connection(func):
+ """Decorator that ensures a connection is established before executing the decorated function.
+
+ Args:
+ func (callable): The function to be decorated.
+
+ Returns:
+ callable: The decorated function.
+
+ """
+
+ def wrapper(self, *args, **kwargs):
+ # Sort out where to find url
+ if hasattr(self, "db"):
+ _db = self.db
+ else:
+ _db = self
+
+ if not _db.driver:
+ if hasattr(config, "DATABASE_URL") and config.DATABASE_URL:
+ _db.set_connection(url=config.DATABASE_URL)
+ elif hasattr(config, "DRIVER") and config.DRIVER:
+ _db.set_connection(driver=config.DRIVER)
+
+ return func(self, *args, **kwargs)
+
+ return wrapper
+
+
+class Database(local):
+ """
+ A singleton object through which all operations from neomodel to the Neo4j backend are handled.
+ """
+
+ _NODE_CLASS_REGISTRY = {}
+
+ def __init__(self):
+ self._active_transaction = None
+ self.url = None
+ self.driver = None
+ self._session = None
+ self._pid = None
+ self._database_name = DEFAULT_DATABASE
+ self.protocol_version = None
+ self._database_version = None
+ self._database_edition = None
+ self.impersonated_user = None
+
+ def set_connection(self, url: str = None, driver: Driver = None):
+ """
+ Sets the connection up and relevant internals. This can be done using a Neo4j URL or a driver instance.
+
+ Args:
+ url (str): Optionally, Neo4j URL in the form protocol://username:password@hostname:port/dbname.
+ When provided, a Neo4j driver instance will be created by neomodel.
+
+ driver (neo4j.Driver): Optionally, a pre-created driver instance.
+ When provided, neomodel will not create a driver instance but use this one instead.
+ """
+ if driver:
+ self.driver = driver
+ if hasattr(config, "DATABASE_NAME") and config.DATABASE_NAME:
+ self._database_name = config.DATABASE_NAME
+ elif url:
+ self._parse_driver_from_url(url=url)
+
+ self._pid = os.getpid()
+ self._active_transaction = None
+ # Set to default database if it hasn't been set before
+ if self._database_name is None:
+ self._database_name = DEFAULT_DATABASE
+
+ # Getting the information about the database version requires a connection to the database
+ self._database_version = None
+ self._database_edition = None
+ self._update_database_version()
+
+ def _parse_driver_from_url(self, url: str) -> None:
+ """Parse the driver information from the given URL and initialize the driver.
+
+ Args:
+ url (str): The URL to parse.
+
+ Raises:
+ ValueError: If the URL format is not as expected.
+
+ Returns:
+ None - Sets the driver and database_name as class properties
+ """
+ p_start = url.replace(":", "", 1).find(":") + 2
+ p_end = url.rfind("@")
+ password = url[p_start:p_end]
+ url = url.replace(password, quote(password))
+ parsed_url = urlparse(url)
+
+ valid_schemas = [
+ "bolt",
+ "bolt+s",
+ "bolt+ssc",
+ "bolt+routing",
+ "neo4j",
+ "neo4j+s",
+ "neo4j+ssc",
+ ]
+
+ if parsed_url.netloc.find("@") > -1 and parsed_url.scheme in valid_schemas:
+ credentials, hostname = parsed_url.netloc.rsplit("@", 1)
+ username, password = credentials.split(":")
+ password = unquote(password)
+ database_name = parsed_url.path.strip("/")
+ else:
+ raise ValueError(
+ f"Expecting url format: bolt://user:password@localhost:7687 got {url}"
+ )
+
+ options = {
+ "auth": basic_auth(username, password),
+ "connection_acquisition_timeout": config.CONNECTION_ACQUISITION_TIMEOUT,
+ "connection_timeout": config.CONNECTION_TIMEOUT,
+ "keep_alive": config.KEEP_ALIVE,
+ "max_connection_lifetime": config.MAX_CONNECTION_LIFETIME,
+ "max_connection_pool_size": config.MAX_CONNECTION_POOL_SIZE,
+ "max_transaction_retry_time": config.MAX_TRANSACTION_RETRY_TIME,
+ "resolver": config.RESOLVER,
+ "user_agent": config.USER_AGENT,
+ }
+
+ if "+s" not in parsed_url.scheme:
+ options["encrypted"] = config.ENCRYPTED
+ options["trusted_certificates"] = config.TRUSTED_CERTIFICATES
+
+ self.driver = GraphDatabase.driver(
+ parsed_url.scheme + "://" + hostname, **options
+ )
+ self.url = url
+ # The database name can be provided through the url or the config
+ if database_name == "":
+ if hasattr(config, "DATABASE_NAME") and config.DATABASE_NAME:
+ self._database_name = config.DATABASE_NAME
+ else:
+ self._database_name = database_name
+
+ def close_connection(self):
+ """
+ Closes the currently open driver.
+ The driver should always be closed at the end of the application's lifecycle.
+ """
+ self._database_version = None
+ self._database_edition = None
+ self._database_name = None
+ self.driver.close()
+ self.driver = None
+
+ @property
+ def database_version(self):
+ if self._database_version is None:
+ self._update_database_version()
+
+ return self._database_version
+
+ @property
+ def database_edition(self):
+ if self._database_edition is None:
+ self._update_database_version()
+
+ return self._database_edition
+
+ @property
+ def transaction(self):
+ """
+ Returns the current transaction object
+ """
+ return TransactionProxy(self)
+
+ @property
+ def write_transaction(self):
+ return TransactionProxy(self, access_mode="WRITE")
+
+ @property
+ def read_transaction(self):
+ return TransactionProxy(self, access_mode="READ")
+
+ def impersonate(self, user: str) -> "ImpersonationHandler":
+ """All queries executed within this context manager will be executed as impersonated user
+
+ Args:
+ user (str): User to impersonate
+
+ Returns:
+ ImpersonationHandler: Context manager to set/unset the user to impersonate
+ """
+ db_edition = self.database_edition
+ if db_edition != "enterprise":
+ raise FeatureNotSupported(
+ "Impersonation is only available in Neo4j Enterprise edition"
+ )
+ return ImpersonationHandler(self, impersonated_user=user)
+
    @ensure_connection
    def begin(self, access_mode=None, **parameters):
        """
        Begins a new transaction. Raises SystemError if a transaction is already active.
        """
        if (
            hasattr(self, "_active_transaction")
            and self._active_transaction is not None
        ):
            raise SystemError("Transaction in progress")
        # One dedicated session per explicit transaction; it is torn down
        # again in commit()/rollback().
        self._session: Session = self.driver.session(
            default_access_mode=access_mode,
            database=self._database_name,
            impersonated_user=self.impersonated_user,
            **parameters,
        )
        self._active_transaction: Transaction = self._session.begin_transaction()
+
    @ensure_connection
    def commit(self):
        """
        Commits the current transaction and closes its session

        :return: last_bookmarks
        """
        try:
            self._active_transaction.commit()
            last_bookmarks: Bookmarks = self._session.last_bookmarks()
        finally:
            # In case when something went wrong during
            # committing changes to the database
            # we have to close an active transaction and session.
            # If commit() raised, last_bookmarks is never read: the exception
            # propagates after this cleanup runs.
            self._active_transaction.close()
            self._session.close()
            self._active_transaction = None
            self._session = None

        return last_bookmarks
+
    @ensure_connection
    def rollback(self):
        """
        Rolls back the current transaction and closes its session
        """
        try:
            self._active_transaction.rollback()
        finally:
            # In case when something went wrong during changes rollback,
            # we have to close an active transaction and session
            self._active_transaction.close()
            self._session.close()
            self._active_transaction = None
            self._session = None
+
+ def _update_database_version(self):
+ """
+ Updates the database server information when it is required
+ """
+ try:
+ results = self.cypher_query(
+ "CALL dbms.components() yield versions, edition return versions[0], edition"
+ )
+ self._database_version = results[0][0][0]
+ self._database_edition = results[0][0][1]
+ except ServiceUnavailable:
+ # The database server is not running yet
+ pass
+
    def _object_resolution(self, object_to_resolve):
        """
        Performs in place automatic object resolution on a result
        returned by cypher_query.

        The function operates recursively in order to be able to resolve Nodes
        within nested list structures and Path objects. Not meant to be called
        directly, used primarily by _result_resolution.

        :param object_to_resolve: A result as returned by cypher_query.
        :type Any:

        :return: An instantiated object.
        """
        # Below is the original comment that came with the code extracted in
        # this method. It is not very clear but I decided to keep it just in
        # case
        #
        #
        # For some reason, while the type of `a_result_attribute[1]`
        # as reported by the neo4j driver is `Node` for Node-type data
        # retrieved from the database.
        # When the retrieved data are Relationship-Type,
        # the returned type is `abc.[REL_LABEL]` which is however
        # a descendant of Relationship.
        # Consequently, the type checking was changed for both
        # Node, Relationship objects
        if isinstance(object_to_resolve, Node):
            # Look up the model class registered for this exact label set;
            # a miss raises KeyError (handled by _result_resolution).
            return self._NODE_CLASS_REGISTRY[
                frozenset(object_to_resolve.labels)
            ].inflate(object_to_resolve)

        if isinstance(object_to_resolve, Relationship):
            rel_type = frozenset([object_to_resolve.type])
            return self._NODE_CLASS_REGISTRY[rel_type].inflate(object_to_resolve)

        if isinstance(object_to_resolve, Path):
            # Imported lazily to avoid a circular import at module load time.
            from neomodel.sync_.path import NeomodelPath

            return NeomodelPath(object_to_resolve)

        if isinstance(object_to_resolve, list):
            # Recurse into nested lists via the list-level resolver.
            return self._result_resolution([object_to_resolve])

        # Primitive values are returned unchanged.
        return object_to_resolve
+
+ def _result_resolution(self, result_list):
+ """
+ Performs in place automatic object resolution on a set of results
+ returned by cypher_query.
+
+ The function operates recursively in order to be able to resolve Nodes
+ within nested list structures. Not meant to be called directly,
+ used primarily by cypher_query.
+
+ :param result_list: A list of results as returned by cypher_query.
+ :type list:
+
+ :return: A list of instantiated objects.
+ """
+
+ # Object resolution occurs in-place
+ for a_result_item in enumerate(result_list):
+ for a_result_attribute in enumerate(a_result_item[1]):
+ try:
+ # Primitive types should remain primitive types,
+ # Nodes to be resolved to native objects
+ resolved_object = a_result_attribute[1]
+
+ resolved_object = self._object_resolution(resolved_object)
+
+ result_list[a_result_item[0]][
+ a_result_attribute[0]
+ ] = resolved_object
+
+ except KeyError as exc:
+ # Not being able to match the label set of a node with a known object results
+ # in a KeyError in the internal dictionary used for resolution. If it is impossible
+ # to match, then raise an exception with more details about the error.
+ if isinstance(a_result_attribute[1], Node):
+ raise NodeClassNotDefined(
+ a_result_attribute[1], self._NODE_CLASS_REGISTRY
+ ) from exc
+
+ if isinstance(a_result_attribute[1], Relationship):
+ raise RelationshipClassNotDefined(
+ a_result_attribute[1], self._NODE_CLASS_REGISTRY
+ ) from exc
+
+ return result_list
+
    @ensure_connection
    def cypher_query(
        self,
        query,
        params=None,
        handle_unique=True,
        retry_on_session_expire=False,
        resolve_objects=False,
    ):
        """
        Runs a query on the database and returns a list of results and their headers.

        :param query: A CYPHER query
        :type: str
        :param params: Dictionary of parameters
        :type: dict
        :param handle_unique: Whether or not to raise UniqueProperty exception on Cypher's ConstraintValidation errors
        :type: bool
        :param retry_on_session_expire: Whether or not to attempt the same query again if the transaction has expired.
        If you use neomodel with your own driver, you must catch SessionExpired exceptions yourself and retry with a new driver instance.
        :type: bool
        :param resolve_objects: Whether to attempt to resolve the returned nodes to data model objects automatically
        :type: bool

        :return: A tuple containing a list of results and a tuple of headers.
        """

        if self._active_transaction:
            # Use the current session if a transaction is currently active.
            results, meta = self._run_cypher_query(
                self._active_transaction,
                query,
                params,
                handle_unique,
                retry_on_session_expire,
                resolve_objects,
            )
        else:
            # Otherwise create a new session in a with to dispose of it after it has been run
            with self.driver.session(
                database=self._database_name, impersonated_user=self.impersonated_user
            ) as session:
                results, meta = self._run_cypher_query(
                    session,
                    query,
                    params,
                    handle_unique,
                    retry_on_session_expire,
                    resolve_objects,
                )

        return results, meta
+
+ def _run_cypher_query(
+ self,
+ session: Session,
+ query,
+ params,
+ handle_unique,
+ retry_on_session_expire,
+ resolve_objects,
+ ):
+ try:
+ # Retrieve the data
+ start = time.time()
+ response: Result = session.run(query, params)
+ results, meta = [list(r.values()) for r in response], response.keys()
+ end = time.time()
+
+ if resolve_objects:
+ # Do any automatic resolution required
+ results = self._result_resolution(results)
+
+ except ClientError as e:
+ if e.code == "Neo.ClientError.Schema.ConstraintValidationFailed":
+ if "already exists with label" in e.message and handle_unique:
+ raise UniqueProperty(e.message) from e
+
+ raise ConstraintValidationFailed(e.message) from e
+ exc_info = sys.exc_info()
+ raise exc_info[1].with_traceback(exc_info[2])
+ except SessionExpired:
+ if retry_on_session_expire:
+ self.set_connection(url=self.url)
+ return self.cypher_query(
+ query=query,
+ params=params,
+ handle_unique=handle_unique,
+ retry_on_session_expire=False,
+ )
+ raise
+
+ tte = end - start
+ if os.environ.get("NEOMODEL_CYPHER_DEBUG", False) and tte > float(
+ os.environ.get("NEOMODEL_SLOW_QUERIES", 0)
+ ):
+ logger.debug(
+ "query: "
+ + query
+ + "\nparams: "
+ + repr(params)
+ + f"\ntook: {tte:.2g}s\n"
+ )
+
+ return results, meta
+
+ def get_id_method(self) -> str:
+ db_version = self.database_version
+ if db_version.startswith("4"):
+ return "id"
+ else:
+ return "elementId"
+
+ def parse_element_id(self, element_id: str):
+ db_version = self.database_version
+ return int(element_id) if db_version.startswith("4") else element_id
+
+ def list_indexes(self, exclude_token_lookup=False) -> Sequence[dict]:
+ """Returns all indexes existing in the database
+
+ Arguments:
+ exclude_token_lookup[bool]: Exclude automatically create token lookup indexes
+
+ Returns:
+ Sequence[dict]: List of dictionaries, each entry being an index definition
+ """
+ indexes, meta_indexes = self.cypher_query("SHOW INDEXES")
+ indexes_as_dict = [dict(zip(meta_indexes, row)) for row in indexes]
+
+ if exclude_token_lookup:
+ indexes_as_dict = [
+ obj for obj in indexes_as_dict if obj["type"] != "LOOKUP"
+ ]
+
+ return indexes_as_dict
+
+ def list_constraints(self) -> Sequence[dict]:
+ """Returns all constraints existing in the database
+
+ Returns:
+ Sequence[dict]: List of dictionaries, each entry being a constraint definition
+ """
+ constraints, meta_constraints = self.cypher_query("SHOW CONSTRAINTS")
+ constraints_as_dict = [dict(zip(meta_constraints, row)) for row in constraints]
+
+ return constraints_as_dict
+
+ @ensure_connection
+ def version_is_higher_than(self, version_tag: str) -> bool:
+ """Returns true if the database version is higher or equal to a given tag
+
+ Args:
+ version_tag (str): The version to compare against
+
+ Returns:
+ bool: True if the database version is higher or equal to the given version
+ """
+ db_version = self.database_version
+ return version_tag_to_integer(db_version) >= version_tag_to_integer(version_tag)
+
+ @ensure_connection
+ def edition_is_enterprise(self) -> bool:
+ """Returns true if the database edition is enterprise
+
+ Returns:
+ bool: True if the database edition is enterprise
+ """
+ edition = self.database_edition
+ return edition == "enterprise"
+
+ def change_neo4j_password(self, user, new_password):
+ self.cypher_query(f"ALTER USER {user} SET PASSWORD '{new_password}'")
+
+ def clear_neo4j_database(self, clear_constraints=False, clear_indexes=False):
+ self.cypher_query(
+ """
+ MATCH (a)
+ CALL { WITH a DETACH DELETE a }
+ IN TRANSACTIONS OF 5000 rows
+ """
+ )
+ if clear_constraints:
+ drop_constraints()
+ if clear_indexes:
+ drop_indexes()
+
+ def drop_constraints(self, quiet=True, stdout=None):
+ """
+ Discover and drop all constraints.
+
+ :type: bool
+ :return: None
+ """
+ if not stdout or stdout is None:
+ stdout = sys.stdout
+
+ results, meta = self.cypher_query("SHOW CONSTRAINTS")
+
+ results_as_dict = [dict(zip(meta, row)) for row in results]
+ for constraint in results_as_dict:
+ self.cypher_query("DROP CONSTRAINT " + constraint["name"])
+ if not quiet:
+ stdout.write(
+ (
+ " - Dropping unique constraint and index"
+ f" on label {constraint['labelsOrTypes'][0]}"
+ f" with property {constraint['properties'][0]}.\n"
+ )
+ )
+ if not quiet:
+ stdout.write("\n")
+
+ def drop_indexes(self, quiet=True, stdout=None):
+ """
+ Discover and drop all indexes, except the automatically created token lookup indexes.
+
+ :type: bool
+ :return: None
+ """
+ if not stdout or stdout is None:
+ stdout = sys.stdout
+
+ indexes = self.list_indexes(exclude_token_lookup=True)
+ for index in indexes:
+ self.cypher_query("DROP INDEX " + index["name"])
+ if not quiet:
+ stdout.write(
+ f' - Dropping index on labels {",".join(index["labelsOrTypes"])} with properties {",".join(index["properties"])}.\n'
+ )
+ if not quiet:
+ stdout.write("\n")
+
+ def remove_all_labels(self, stdout=None):
+ """
+ Calls functions for dropping constraints and indexes.
+
+ :param stdout: output stream
+ :return: None
+ """
+
+ if not stdout:
+ stdout = sys.stdout
+
+ stdout.write("Dropping constraints...\n")
+ self.drop_constraints(quiet=False, stdout=stdout)
+
+ stdout.write("Dropping indexes...\n")
+ self.drop_indexes(quiet=False, stdout=stdout)
+
+ def install_all_labels(self, stdout=None):
+ """
+ Discover all subclasses of StructuredNode in your application and execute install_labels on each.
+ Note: code must be loaded (imported) in order for a class to be discovered.
+
+ :param stdout: output stream
+ :return: None
+ """
+
+ if not stdout or stdout is None:
+ stdout = sys.stdout
+
+ def subsub(cls): # recursively return all subclasses
+ subclasses = cls.__subclasses__()
+ if not subclasses: # base case: no more subclasses
+ return []
+ return subclasses + [g for s in cls.__subclasses__() for g in subsub(s)]
+
+ stdout.write("Setting up indexes and constraints...\n\n")
+
+ i = 0
+ for cls in subsub(StructuredNode):
+ stdout.write(f"Found {cls.__module__}.{cls.__name__}\n")
+ install_labels(cls, quiet=False, stdout=stdout)
+ i += 1
+
+ if i:
+ stdout.write("\n")
+
+ stdout.write(f"Finished {i} classes.\n")
+
+ def install_labels(self, cls, quiet=True, stdout=None):
+ """
+ Setup labels with indexes and constraints for a given class
+
+ :param cls: StructuredNode class
+ :type: class
+ :param quiet: (default true) enable standard output
+ :param stdout: stdout stream
+ :type: bool
+ :return: None
+ """
+ if not stdout or stdout is None:
+ stdout = sys.stdout
+
+ if not hasattr(cls, "__label__"):
+ if not quiet:
+ stdout.write(
+ f" ! Skipping class {cls.__module__}.{cls.__name__} is abstract\n"
+ )
+ return
+
+ for name, property in cls.defined_properties(aliases=False, rels=False).items():
+ self._install_node(cls, name, property, quiet, stdout)
+
+ for _, relationship in cls.defined_properties(
+ aliases=False, rels=True, properties=False
+ ).items():
+ self._install_relationship(cls, relationship, quiet, stdout)
+
+ def _create_node_index(self, label: str, property_name: str, stdout):
+ try:
+ self.cypher_query(
+ f"CREATE INDEX index_{label}_{property_name} FOR (n:{label}) ON (n.{property_name}); "
+ )
+ except ClientError as e:
+ if e.code in (
+ RULE_ALREADY_EXISTS,
+ INDEX_ALREADY_EXISTS,
+ ):
+ stdout.write(f"{str(e)}\n")
+ else:
+ raise
+
+ def _create_node_constraint(self, label: str, property_name: str, stdout):
+ try:
+ self.cypher_query(
+ f"""CREATE CONSTRAINT constraint_unique_{label}_{property_name}
+ FOR (n:{label}) REQUIRE n.{property_name} IS UNIQUE"""
+ )
+ except ClientError as e:
+ if e.code in (
+ RULE_ALREADY_EXISTS,
+ CONSTRAINT_ALREADY_EXISTS,
+ ):
+ stdout.write(f"{str(e)}\n")
+ else:
+ raise
+
+ def _create_relationship_index(
+ self, relationship_type: str, property_name: str, stdout
+ ):
+ try:
+ self.cypher_query(
+ f"CREATE INDEX index_{relationship_type}_{property_name} FOR ()-[r:{relationship_type}]-() ON (r.{property_name}); "
+ )
+ except ClientError as e:
+ if e.code in (
+ RULE_ALREADY_EXISTS,
+ INDEX_ALREADY_EXISTS,
+ ):
+ stdout.write(f"{str(e)}\n")
+ else:
+ raise
+
+ def _create_relationship_constraint(
+ self, relationship_type: str, property_name: str, stdout
+ ):
+ if self.version_is_higher_than("5.7"):
+ try:
+ self.cypher_query(
+ f"""CREATE CONSTRAINT constraint_unique_{relationship_type}_{property_name}
+ FOR ()-[r:{relationship_type}]-() REQUIRE r.{property_name} IS UNIQUE"""
+ )
+ except ClientError as e:
+ if e.code in (
+ RULE_ALREADY_EXISTS,
+ CONSTRAINT_ALREADY_EXISTS,
+ ):
+ stdout.write(f"{str(e)}\n")
+ else:
+ raise
+ else:
+ raise FeatureNotSupported(
+ f"Unique indexes on relationships are not supported in Neo4j version {self.database_version}. Please upgrade to Neo4j 5.7 or higher."
+ )
+
+ def _install_node(self, cls, name, property, quiet, stdout):
+ # Create indexes and constraints for node property
+ db_property = property.db_property or name
+ if property.index:
+ if not quiet:
+ stdout.write(
+ f" + Creating node index {name} on label {cls.__label__} for class {cls.__module__}.{cls.__name__}\n"
+ )
+ self._create_node_index(
+ label=cls.__label__, property_name=db_property, stdout=stdout
+ )
+
+ elif property.unique_index:
+ if not quiet:
+ stdout.write(
+ f" + Creating node unique constraint for {name} on label {cls.__label__} for class {cls.__module__}.{cls.__name__}\n"
+ )
+ self._create_node_constraint(
+ label=cls.__label__, property_name=db_property, stdout=stdout
+ )
+
+ def _install_relationship(self, cls, relationship, quiet, stdout):
+ # Create indexes and constraints for relationship property
+ relationship_cls = relationship.definition["model"]
+ if relationship_cls is not None:
+ relationship_type = relationship.definition["relation_type"]
+ for prop_name, property in relationship_cls.defined_properties(
+ aliases=False, rels=False
+ ).items():
+ db_property = property.db_property or prop_name
+ if property.index:
+ if not quiet:
+ stdout.write(
+ f" + Creating relationship index {prop_name} on relationship type {relationship_type} for relationship model {cls.__module__}.{relationship_cls.__name__}\n"
+ )
+ self._create_relationship_index(
+ relationship_type=relationship_type,
+ property_name=db_property,
+ stdout=stdout,
+ )
+ elif property.unique_index:
+ if not quiet:
+ stdout.write(
+ f" + Creating relationship unique constraint for {prop_name} on relationship type {relationship_type} for relationship model {cls.__module__}.{relationship_cls.__name__}\n"
+ )
+ self._create_relationship_constraint(
+ relationship_type=relationship_type,
+ property_name=db_property,
+ stdout=stdout,
+ )
+
+
+# Create a singleton instance of the database object
+db = Database()
+
+
+# Deprecated methods
def change_neo4j_password(db: Database, user, new_password):
    """Deprecated wrapper; delegates to db.change_neo4j_password()."""
    deprecated(
        """
        This method has been moved to the Database singleton (db for sync, db for async).
        Please use db.change_neo4j_password(user, new_password) instead.
        This direct call will be removed in an upcoming version.
        """
    )
    db.change_neo4j_password(user, new_password)
+
+
def clear_neo4j_database(db: Database, clear_constraints=False, clear_indexes=False):
    """Deprecated wrapper; delegates to db.clear_neo4j_database()."""
    deprecated(
        """
        This method has been moved to the Database singleton (db for sync, db for async).
        Please use db.clear_neo4j_database(clear_constraints, clear_indexes) instead.
        This direct call will be removed in an upcoming version.
        """
    )
    db.clear_neo4j_database(clear_constraints, clear_indexes)
+
+
def drop_constraints(quiet=True, stdout=None):
    """Deprecated wrapper; delegates to the global singleton's drop_constraints()."""
    deprecated(
        """
        This method has been moved to the Database singleton (db for sync, db for async).
        Please use db.drop_constraints(quiet, stdout) instead.
        This direct call will be removed in an upcoming version.
        """
    )
    db.drop_constraints(quiet, stdout)
+
+
def drop_indexes(quiet=True, stdout=None):
    """Deprecated wrapper; delegates to the global singleton's drop_indexes()."""
    deprecated(
        """
        This method has been moved to the Database singleton (db for sync, db for async).
        Please use db.drop_indexes(quiet, stdout) instead.
        This direct call will be removed in an upcoming version.
        """
    )
    db.drop_indexes(quiet, stdout)
+
+
def remove_all_labels(stdout=None):
    """Deprecated wrapper; delegates to the global singleton's remove_all_labels()."""
    deprecated(
        """
        This method has been moved to the Database singleton (db for sync, db for async).
        Please use db.remove_all_labels(stdout) instead.
        This direct call will be removed in an upcoming version.
        """
    )
    db.remove_all_labels(stdout)
+
+
def install_labels(cls, quiet=True, stdout=None):
    """Deprecated wrapper; delegates to the global singleton's install_labels()."""
    deprecated(
        """
        This method has been moved to the Database singleton (db for sync, db for async).
        Please use db.install_labels(cls, quiet, stdout) instead.
        This direct call will be removed in an upcoming version.
        """
    )
    db.install_labels(cls, quiet, stdout)
+
+
def install_all_labels(stdout=None):
    """Deprecated wrapper; delegates to the global singleton's install_all_labels()."""
    deprecated(
        """
        This method has been moved to the Database singleton (db for sync, db for async).
        Please use db.install_all_labels(stdout) instead.
        This direct call will be removed in an upcoming version.
        """
    )
    db.install_all_labels(stdout)
+
+
class TransactionProxy:
    """
    Context manager / decorator that wraps a unit of work in a database
    transaction: commits on clean exit, rolls back on exception.
    """

    bookmarks: Optional[Bookmarks] = None

    def __init__(self, db: Database, access_mode=None):
        """
        :param db: Database instance the transaction runs against.
        :param access_mode: Optional "READ"/"WRITE" routing hint.
        """
        self.db = db
        self.access_mode = access_mode

    @ensure_connection
    def __enter__(self):
        # Bug fix: removed leftover debug prints ("aenter called" etc. —
        # unasync residue) that polluted stdout on every transaction.
        self.db.begin(access_mode=self.access_mode, bookmarks=self.bookmarks)
        # Bookmarks are single-use: consume them when the transaction starts.
        self.bookmarks = None
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_value:
            self.db.rollback()

        if (
            exc_type is ClientError
            and exc_value.code == "Neo.ClientError.Schema.ConstraintValidationFailed"
        ):
            # Surface constraint violations as neomodel's UniqueProperty.
            raise UniqueProperty(exc_value.message)

        if not exc_value:
            self.last_bookmark = self.db.commit()

    def __call__(self, func):
        """Allow use as a decorator wrapping *func* in a transaction."""
        if Util.is_async_code and not iscoroutinefunction(func):
            raise TypeError(NOT_COROUTINE_ERROR)

        def wrapper(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return wrapper

    @property
    def with_bookmark(self):
        """Variant of this proxy that also returns the commit bookmark."""
        return BookmarkingAsyncTransactionProxy(self.db, self.access_mode)
+
+
class BookmarkingAsyncTransactionProxy(TransactionProxy):
    """
    Transaction proxy (decorator form) that also captures and returns the
    bookmark produced when the transaction commits.
    """

    def __call__(self, func):
        if Util.is_async_code and not iscoroutinefunction(func):
            raise TypeError(NOT_COROUTINE_ERROR)

        def wrapper(*args, **kwargs):
            # Bookmarks to chain on may be supplied by the caller; they are
            # consumed by __enter__ when the transaction starts.
            self.bookmarks = kwargs.pop("bookmarks", None)

            with self:
                result = func(*args, **kwargs)
                # Reset inside the block; __exit__ re-assigns last_bookmark
                # from the commit, so the returned value is the fresh one.
                self.last_bookmark = None

            return result, self.last_bookmark

        return wrapper
+
+
class ImpersonationHandler:
    """
    Context manager / decorator that runs all queries within its scope as
    the given impersonated user, unsetting impersonation on exit.
    """

    def __init__(self, db: "Database", impersonated_user: str):
        self.db = db
        self.impersonated_user = impersonated_user

    def __enter__(self):
        self.db.impersonated_user = self.impersonated_user
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback):
        # Always clear impersonation, even when the block raised.
        self.db.impersonated_user = None

        # Bug fix: only report exception details when an exception actually
        # occurred; previously "None" was printed on every clean exit.
        if exception_type is not None:
            print("\nException type:", exception_type)
            print("\nException value:", exception_value)
            print("\nTraceback:", exception_traceback)

    def __call__(self, func):
        def wrapper(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return wrapper
+
+
class NodeMeta(type):
    """
    Metaclass for node classes: wires up a per-class DoesNotExist exception,
    validates reserved property names, caches property groups and registers
    the class in the label registry.
    """

    def __new__(mcs, name, bases, namespace):
        # Each node class gets its own DoesNotExist subclass so callers can
        # catch it per-model.
        namespace["DoesNotExist"] = type(name + "DoesNotExist", (DoesNotExist,), {})
        cls = super().__new__(mcs, name, bases, namespace)
        cls.DoesNotExist._model_class = cls

        if hasattr(cls, "__abstract_node__"):
            # Abstract classes skip validation/registration entirely; the
            # marker is removed so subclasses are concrete by default.
            delattr(cls, "__abstract_node__")
        else:
            # Reject property names that clash with neomodel/Neo4j internals.
            if "deleted" in namespace:
                raise ValueError(
                    "Property name 'deleted' is not allowed as it conflicts with neomodel internals."
                )
            elif "id" in namespace:
                raise ValueError(
                    """
                    Property name 'id' is not allowed as it conflicts with neomodel internals.
                    Consider using 'uid' or 'identifier' as id is also a Neo4j internal.
                    """
                )
            elif "element_id" in namespace:
                raise ValueError(
                    """
                    Property name 'element_id' is not allowed as it conflicts with neomodel internals.
                    Consider using 'uid' or 'identifier' as element_id is also a Neo4j internal.
                    """
                )
            # Bind each declared Property to its name/owner and run any setup.
            for key, value in (
                (x, y) for x, y in namespace.items() if isinstance(y, Property)
            ):
                value.name, value.owner = key, cls
                if hasattr(value, "setup") and callable(value.setup):
                    value.setup()

            # cache various groups of properties
            cls.__required_properties__ = tuple(
                name
                for name, property in cls.defined_properties(
                    aliases=False, rels=False
                ).items()
                if property.required or property.unique_index
            )
            cls.__all_properties__ = tuple(
                cls.defined_properties(aliases=False, rels=False).items()
            )
            cls.__all_aliases__ = tuple(
                cls.defined_properties(properties=False, rels=False).items()
            )
            cls.__all_relationships__ = tuple(
                cls.defined_properties(aliases=False, properties=False).items()
            )

            cls.__label__ = namespace.get("__label__", name)
            cls.__optional_labels__ = namespace.get("__optional_labels__", [])

            # Register every label combination this class can appear with.
            build_class_registry(cls)

        return cls
+
+
def build_class_registry(cls):
    """
    Register *cls* in db._NODE_CLASS_REGISTRY under every label combination
    it can appear with: its base labels extended by each non-empty subset of
    its optional labels, plus the base labels alone.

    Raises NodeClassAlreadyDefined when a combination is already claimed.
    """
    base_label_set = frozenset(cls.inherited_labels())
    optional_label_set = set(cls.inherited_optional_labels())

    # Optional-label extensions first, then the bare base set (same order of
    # registration as before).
    possible_label_combinations = [
        base_label_set | frozenset(optional_subset)
        for subset_size in range(1, len(optional_label_set) + 1)
        for optional_subset in combinations(optional_label_set, subset_size)
    ]
    possible_label_combinations.append(base_label_set)

    for label_set in possible_label_combinations:
        if label_set in db._NODE_CLASS_REGISTRY:
            raise NodeClassAlreadyDefined(cls, db._NODE_CLASS_REGISTRY)
        db._NODE_CLASS_REGISTRY[label_set] = cls
+
+
+NodeBase = NodeMeta("NodeBase", (PropertyManager,), {"__abstract_node__": True})
+
+
class StructuredNode(NodeBase):
    """
    Base class for all node definitions to inherit from.

    If you want to create your own abstract classes set:
    __abstract_node__ = True
    """

    # static properties

    __abstract_node__ = True

    # magic methods

    def __init__(self, *args, **kwargs):
        # "deleted" is reserved: it flags instances whose node was removed.
        if "deleted" in kwargs:
            raise ValueError("deleted property is reserved for neomodel")

        # Bind one relationship manager per declared relationship so that
        # traversals from this instance are anchored to it.
        for key, val in self.__all_relationships__:
            self.__dict__[key] = val.build_manager(self, key)

        super().__init__(*args, **kwargs)

    def __eq__(self, other):
        # Equality means "same database entity", compared via element_id.
        if not isinstance(other, (StructuredNode,)):
            return False
        if hasattr(self, "element_id") and hasattr(other, "element_id"):
            return self.element_id == other.element_id
        return False

    def __ne__(self, other):
        return not self.__eq__(other)

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def __str__(self):
        return repr(self.__properties__)

    # dynamic properties

    @classproperty
    def nodes(cls):
        """
        Returns a NodeSet object representing all nodes of the classes label
        :return: NodeSet
        :rtype: NodeSet
        """
        # Imported lazily to avoid a circular import with the match module.
        from neomodel.sync_.match import NodeSet

        return NodeSet(cls)

    @property
    def element_id(self):
        # None until the node has been saved or inflated from the database.
        if hasattr(self, "element_id_property"):
            return self.element_id_property
        return None

    # Version 4.4 support - id is deprecated in version 5.x
    @property
    def id(self):
        try:
            return int(self.element_id_property)
        except (TypeError, ValueError):
            raise ValueError(
                "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()."
            )

    # methods

    @classmethod
    def _build_merge_query(
        cls, merge_params, update_existing=False, lazy=False, relationship=None
    ):
        """
        Get a tuple of a CYPHER query and a params dict for the specified MERGE query.

        :param merge_params: The target node match parameters, each node must have a "create" key and optional "update".
        :type merge_params: list of dict
        :param update_existing: True to update properties of existing nodes, default False to keep existing values.
        :type update_existing: bool
        :rtype: tuple
        """
        query_params = dict(merge_params=merge_params)
        n_merge_labels = ":".join(cls.inherited_labels())
        # MERGE matches on all required/unique properties so it's deterministic.
        n_merge_prm = ", ".join(
            (
                f"{getattr(cls, p).db_property or p}: params.create.{getattr(cls, p).db_property or p}"
                for p in cls.__required_properties__
            )
        )
        n_merge = f"n:{n_merge_labels} {{{n_merge_prm}}}"
        if relationship is None:
            # create "simple" unwind query
            query = f"UNWIND $merge_params as params\n MERGE ({n_merge})\n "
        else:
            # validate relationship
            if not isinstance(relationship.source, StructuredNode):
                raise ValueError(
                    f"relationship source [{repr(relationship.source)}] is not a StructuredNode"
                )
            relation_type = relationship.definition.get("relation_type")
            if not relation_type:
                raise ValueError(
                    "No relation_type is specified on provided relationship"
                )

            # Imported lazily to avoid a circular import with the match module.
            from neomodel.sync_.match import _rel_helper

            query_params["source_id"] = db.parse_element_id(
                relationship.source.element_id
            )
            query = f"MATCH (source:{relationship.source.__label__}) WHERE {db.get_id_method()}(source) = $source_id\n "
            query += "WITH source\n UNWIND $merge_params as params \n "
            query += "MERGE "
            query += _rel_helper(
                lhs="source",
                rhs=n_merge,
                ident=None,
                relation_type=relation_type,
                direction=relationship.definition["direction"],
            )

        query += "ON CREATE SET n = params.create\n "
        # if update_existing, write properties on match as well
        if update_existing is True:
            query += "ON MATCH SET n += params.update\n"

        # close query
        if lazy:
            # Lazy mode returns only the node id/elementId, not the node.
            query += f"RETURN {db.get_id_method()}(n)"
        else:
            query += "RETURN n"

        return query, query_params

    @classmethod
    def create(cls, *props, **kwargs):
        """
        Call to CREATE with parameters map. A new instance will be created and saved.

        :param props: dict of properties to create the nodes.
        :type props: tuple
        :param lazy: False by default, specify True to get nodes with id only without the parameters.
        :type: bool
        :rtype: list
        """

        if "streaming" in kwargs:
            warnings.warn(
                STREAMING_WARNING,
                category=DeprecationWarning,
                stacklevel=1,
            )

        lazy = kwargs.get("lazy", False)
        # create mapped query
        query = f"CREATE (n:{':'.join(cls.inherited_labels())} $create_params)"

        # close query
        if lazy:
            query += f" RETURN {db.get_id_method()}(n)"
        else:
            query += " RETURN n"

        results = []
        # One CREATE round-trip per property dict, preserving input order.
        for item in [
            cls.deflate(p, obj=_UnsavedNode(), skip_empty=True) for p in props
        ]:
            node, _ = db.cypher_query(query, {"create_params": item})
            results.extend(node[0])

        nodes = [cls.inflate(node) for node in results]

        # post_create hooks fire only for fully-inflated (non-lazy) nodes.
        if not lazy and hasattr(cls, "post_create"):
            for node in nodes:
                node.post_create()

        return nodes

    @classmethod
    def create_or_update(cls, *props, **kwargs):
        """
        Call to MERGE with parameters map. A new instance will be created and saved if does not already exists,
        this is an atomic operation. If an instance already exists all optional properties specified will be updated.

        Note that the post_create hook isn't called after create_or_update

        :param props: List of dict arguments to get or create the entities with.
        :type props: tuple
        :param relationship: Optional, relationship to get/create on when new entity is created.
        :param lazy: False by default, specify True to get nodes with id only without the parameters.
        :rtype: list
        """
        lazy = kwargs.get("lazy", False)
        relationship = kwargs.get("relationship")

        # build merge query, make sure to update only explicitly specified properties
        create_or_update_params = []
        for specified, deflated in [
            (p, cls.deflate(p, skip_empty=True)) for p in props
        ]:
            create_or_update_params.append(
                {
                    "create": deflated,
                    # Only keys the caller actually passed are written on match.
                    "update": dict(
                        (k, v) for k, v in deflated.items() if k in specified
                    ),
                }
            )
        query, params = cls._build_merge_query(
            create_or_update_params,
            update_existing=True,
            relationship=relationship,
            lazy=lazy,
        )

        if "streaming" in kwargs:
            warnings.warn(
                STREAMING_WARNING,
                category=DeprecationWarning,
                stacklevel=1,
            )

        # fetch and build instance for each result
        results = db.cypher_query(query, params)
        return [cls.inflate(r[0]) for r in results[0]]

    def cypher(self, query, params=None):
        """
        Execute a cypher query with the param 'self' pre-populated with the nodes neo4j id.

        :param query: cypher query string
        :type: string
        :param params: query parameters
        :type: dict
        :return: list containing query results
        :rtype: list
        """
        self._pre_action_check("cypher")
        params = params or {}
        # $self is bound to this node's (version-appropriate) id.
        element_id = db.parse_element_id(self.element_id)
        params.update({"self": element_id})
        return db.cypher_query(query, params)

    @hooks
    def delete(self):
        """
        Delete a node and its relationships

        :return: True
        """
        self._pre_action_check("delete")
        self.cypher(
            f"MATCH (self) WHERE {db.get_id_method()}(self)=$self DETACH DELETE self"
        )
        # Mark the instance as detached from the database.
        delattr(self, "element_id_property")
        self.deleted = True
        return True

    @classmethod
    def get_or_create(cls, *props, **kwargs):
        """
        Call to MERGE with parameters map. A new instance will be created and saved if does not already exist,
        this is an atomic operation.
        Parameters must contain all required properties, any non required properties with defaults will be generated.

        Note that the post_create hook isn't called after get_or_create

        :param props: Arguments to get_or_create as tuple of dict with property names and values to get or create
                      the entities with.
        :type props: tuple
        :param relationship: Optional, relationship to get/create on when new entity is created.
        :param lazy: False by default, specify True to get nodes with id only without the parameters.
        :rtype: list
        """
        lazy = kwargs.get("lazy", False)
        relationship = kwargs.get("relationship")

        # build merge query
        get_or_create_params = [
            {"create": cls.deflate(p, skip_empty=True)} for p in props
        ]
        query, params = cls._build_merge_query(
            get_or_create_params, relationship=relationship, lazy=lazy
        )

        if "streaming" in kwargs:
            warnings.warn(
                STREAMING_WARNING,
                category=DeprecationWarning,
                stacklevel=1,
            )

        # fetch and build instance for each result
        results = db.cypher_query(query, params)
        return [cls.inflate(r[0]) for r in results[0]]

    @classmethod
    def inflate(cls, node):
        """
        Inflate a raw neo4j_driver node to a neomodel node
        :param node: either a raw driver node, or a bare id/elementId for
            lazily-loaded nodes
        :return: node object
        """
        # support lazy loading
        if isinstance(node, str) or isinstance(node, int):
            snode = cls()
            snode.element_id_property = node
        else:
            node_properties = _get_node_properties(node)
            props = {}
            for key, prop in cls.__all_properties__:
                # map property name from database to object property
                db_property = prop.db_property or key

                if db_property in node_properties:
                    props[key] = prop.inflate(node_properties[db_property], node)
                elif prop.has_default:
                    props[key] = prop.default_value()
                else:
                    props[key] = None

            snode = cls(**props)
            snode.element_id_property = node.element_id

        return snode

    @classmethod
    def inherited_labels(cls):
        """
        Return list of labels from nodes class hierarchy.

        :return: list
        """
        return [
            scls.__label__
            for scls in cls.mro()
            if hasattr(scls, "__label__") and not hasattr(scls, "__abstract_node__")
        ]

    @classmethod
    def inherited_optional_labels(cls):
        """
        Return list of optional labels from nodes class hierarchy.

        :return: list
        :rtype: list
        """
        return [
            label
            for scls in cls.mro()
            for label in getattr(scls, "__optional_labels__", [])
            if not hasattr(scls, "__abstract_node__")
        ]

    def labels(self):
        """
        Returns list of labels tied to the node from neo4j.

        :return: list of labels
        :rtype: list
        """
        self._pre_action_check("labels")
        result = self.cypher(
            f"MATCH (n) WHERE {db.get_id_method()}(n)=$self " "RETURN labels(n)"
        )
        return result[0][0][0]

    def _pre_action_check(self, action):
        # Guard shared by all instance operations that hit the database.
        if hasattr(self, "deleted") and self.deleted:
            raise ValueError(
                f"{self.__class__.__name__}.{action}() attempted on deleted node"
            )
        # NOTE(review): element_id is a class-level property, so hasattr()
        # is always True here and this unsaved-node guard looks dead —
        # confirm whether it should check element_id_property instead.
        if not hasattr(self, "element_id"):
            raise ValueError(
                f"{self.__class__.__name__}.{action}() attempted on unsaved node"
            )

    def refresh(self):
        """
        Reload the node from neo4j
        """
        self._pre_action_check("refresh")
        # NOTE(review): element_id is a class-level property, so this
        # hasattr() check is always True — confirm intent.
        if hasattr(self, "element_id"):
            results = self.cypher(
                f"MATCH (n) WHERE {db.get_id_method()}(n)=$self RETURN n"
            )
            request = results[0]
            if not request or not request[0]:
                raise self.__class__.DoesNotExist("Can't refresh non existent node")
            node = self.inflate(request[0][0])
            for key, val in node.__properties__.items():
                setattr(self, key, val)
        else:
            raise ValueError("Can't refresh unsaved node")

    @hooks
    def save(self):
        """
        Save the node to neo4j or raise an exception

        :return: the node instance
        """

        # create or update instance node
        if hasattr(self, "element_id_property"):
            # update
            params = self.deflate(self.__properties__, self)
            query = f"MATCH (n) WHERE {db.get_id_method()}(n)=$self\n"

            if params:
                query += "SET "
                query += ",\n".join([f"n.{key} = ${key}" for key in params])
                query += "\n"
            # Re-assert all inherited labels on the node.
            if self.inherited_labels():
                query += "\n".join(
                    [f"SET n:`{label}`" for label in self.inherited_labels()]
                )
            self.cypher(query, params)
        elif hasattr(self, "deleted") and self.deleted:
            raise ValueError(
                f"{self.__class__.__name__}.save() attempted on deleted node"
            )
        else:  # create
            result = self.create(self.__properties__)
            created_node = result[0]
            # Adopt the freshly created node's identity.
            self.element_id_property = created_node.element_id
        return self
diff --git a/neomodel/match.py b/neomodel/sync_/match.py
similarity index 95%
rename from neomodel/match.py
rename to neomodel/sync_/match.py
index fb47f568..fad8b05f 100644
--- a/neomodel/match.py
+++ b/neomodel/sync_/match.py
@@ -4,12 +4,11 @@
from dataclasses import dataclass
from typing import Optional
-from .core import StructuredNode, db
-from .exceptions import MultipleNodesReturned
-from .match_q import Q, QBase
-from .properties import AliasProperty
-
-OUTGOING, INCOMING, EITHER = 1, -1, 0
+from neomodel.exceptions import MultipleNodesReturned
+from neomodel.match_q import Q, QBase
+from neomodel.properties import AliasProperty
+from neomodel.sync_.core import StructuredNode, db
+from neomodel.util import INCOMING, OUTGOING
def _rel_helper(
@@ -505,7 +504,7 @@ def build_node(self, node):
_node_lookup = f"MATCH ({ident}) WHERE {db.get_id_method()}({ident})=${place_holder} WITH {ident}"
self._ast.lookup = _node_lookup
- self._query_params[place_holder] = node.element_id
+ self._query_params[place_holder] = db.parse_element_id(node.element_id)
self._ast.return_clause = ident
self._ast.result_class = node.__class__
@@ -732,24 +731,42 @@ def all(self, lazy=False):
:return: list of nodes
:rtype: list
"""
- return self.query_cls(self).build_ast()._execute(lazy)
+ ast = self.query_cls(self).build_ast()
+ return ast._execute(lazy)
def __iter__(self):
- return (i for i in self.query_cls(self).build_ast()._execute())
+ ast = self.query_cls(self).build_ast()
+ for i in ast._execute():
+ yield i
def __len__(self):
- return self.query_cls(self).build_ast()._count()
+ ast = self.query_cls(self).build_ast()
+ return ast._count()
def __bool__(self):
- return self.query_cls(self).build_ast()._count() > 0
+ """
+ Override for __bool__ dunder method.
+ :return: True if the set contains any nodes, False otherwise
+ :rtype: bool
+ """
+ ast = self.query_cls(self).build_ast()
+ _count = ast._count()
+ return _count > 0
def __nonzero__(self):
- return self.query_cls(self).build_ast()._count() > 0
+ """
+ Override for __bool__ dunder method.
+ :return: True if the set contains any node, False otherwise
+ :rtype: bool
+ """
+ return self.__bool__()
def __contains__(self, obj):
if isinstance(obj, StructuredNode):
if hasattr(obj, "element_id") and obj.element_id is not None:
- return self.query_cls(self).build_ast()._contains(obj.element_id)
+ ast = self.query_cls(self).build_ast()
+ obj_element_id = db.parse_element_id(obj.element_id)
+ return ast._contains(obj_element_id)
raise ValueError("Unsaved node: " + repr(obj))
raise ValueError("Expecting StructuredNode instance")
@@ -770,7 +787,9 @@ def __getitem__(self, key):
self.skip = key
self.limit = 1
- return self.query_cls(self).build_ast()._execute()[0]
+ ast = self.query_cls(self).build_ast()
+ _items = ast._execute()
+ return _items[0]
return None
@@ -810,11 +829,15 @@ def __init__(self, source):
self.relations_to_fetch: list = []
+ def __await__(self):
+ return self.all().__await__()
+
def _get(self, limit=None, lazy=False, **kwargs):
self.filter(**kwargs)
if limit:
self.limit = limit
- return self.query_cls(self).build_ast()._execute(lazy)
+ ast = self.query_cls(self).build_ast()
+ return ast._execute(lazy)
def get(self, lazy=False, **kwargs):
"""
@@ -849,7 +872,7 @@ def first(self, **kwargs):
:param kwargs: same syntax as `filter()`
:return: node
"""
- result = result = self._get(limit=1, **kwargs)
+ result = self._get(limit=1, **kwargs)
if result:
return result[0]
else:
@@ -986,6 +1009,9 @@ class Traversal(BaseSet):
:type defintion: :class:`dict`
"""
+ def __await__(self):
+ return self.all().__await__()
+
def __init__(self, source, name, definition):
"""
Create a traversal
diff --git a/neomodel/path.py b/neomodel/sync_/path.py
similarity index 88%
rename from neomodel/path.py
rename to neomodel/sync_/path.py
index 5f063d11..62a49fe7 100644
--- a/neomodel/path.py
+++ b/neomodel/sync_/path.py
@@ -1,14 +1,14 @@
from neo4j.graph import Path
-from .core import db
-from .relationship import StructuredRel
-from .exceptions import RelationshipClassNotDefined
+
+from neomodel.sync_.core import db
+from neomodel.sync_.relationship import StructuredRel
class NeomodelPath(Path):
"""
Represents paths within neomodel.
- This object is instantiated when you include whole paths in your ``cypher_query()``
+ This object is instantiated when you include whole paths in your ``cypher_query()``
result sets and turn ``resolve_objects`` to True.
That is, any query of the form:
@@ -16,7 +16,7 @@ class NeomodelPath(Path):
MATCH p=(:SOME_NODE_LABELS)-[:SOME_REL_LABELS]-(:SOME_OTHER_NODE_LABELS) return p
- ``NeomodelPath`` are simple objects that reference their nodes and relationships, each of which is already
+ ``NeomodelPath`` are simple objects that reference their nodes and relationships, each of which is already
resolved to their neomodel objects if such mapping is possible.
@@ -25,8 +25,9 @@ class NeomodelPath(Path):
:type nodes: List[StructuredNode]
:type relationships: List[StructuredRel]
"""
+
def __init__(self, a_neopath):
- self._nodes=[]
+ self._nodes = []
self._relationships = []
for a_node in a_neopath.nodes:
@@ -42,6 +43,7 @@ def __init__(self, a_neopath):
else:
new_rel = StructuredRel.inflate(a_relationship)
self._relationships.append(new_rel)
+
@property
def nodes(self):
return self._nodes
@@ -49,5 +51,3 @@ def nodes(self):
@property
def relationships(self):
return self._relationships
-
-
diff --git a/neomodel/sync_/property_manager.py b/neomodel/sync_/property_manager.py
new file mode 100644
index 00000000..85452f0b
--- /dev/null
+++ b/neomodel/sync_/property_manager.py
@@ -0,0 +1,109 @@
+import types
+
+from neomodel.exceptions import RequiredProperty
+from neomodel.properties import AliasProperty, Property
+
+
+def display_for(key):
+ def display_choice(self):
+ return getattr(self.__class__, key).choices[getattr(self, key)]
+
+ return display_choice
+
+
+class PropertyManager:
+ """
+ Common methods for handling properties on node and relationship objects.
+ """
+
+ def __init__(self, **kwargs):
+ properties = getattr(self, "__all_properties__", None)
+ if properties is None:
+ properties = self.defined_properties(rels=False, aliases=False).items()
+ for name, property in properties:
+ if kwargs.get(name) is None:
+ if getattr(property, "has_default", False):
+ setattr(self, name, property.default_value())
+ else:
+ setattr(self, name, None)
+ else:
+ setattr(self, name, kwargs[name])
+
+ if getattr(property, "choices", None):
+ setattr(
+ self,
+ f"get_{name}_display",
+ types.MethodType(display_for(name), self),
+ )
+
+ if name in kwargs:
+ del kwargs[name]
+
+ aliases = getattr(self, "__all_aliases__", None)
+ if aliases is None:
+ aliases = self.defined_properties(
+ aliases=True, rels=False, properties=False
+ ).items()
+ for name, property in aliases:
+ if name in kwargs:
+ setattr(self, name, kwargs[name])
+ del kwargs[name]
+
+ # undefined properties (for magic @prop.setters etc)
+ for name, property in kwargs.items():
+ setattr(self, name, property)
+
+ @property
+ def __properties__(self):
+ from neomodel.sync_.relationship_manager import RelationshipManager
+
+ return dict(
+ (name, value)
+ for name, value in vars(self).items()
+ if not name.startswith("_")
+ and not callable(value)
+ and not isinstance(
+ value,
+ (
+ RelationshipManager,
+ AliasProperty,
+ ),
+ )
+ )
+
+ @classmethod
+ def deflate(cls, properties, obj=None, skip_empty=False):
+ # deflate dict ready to be stored
+ deflated = {}
+ for name, property in cls.defined_properties(aliases=False, rels=False).items():
+ db_property = property.db_property or name
+ if properties.get(name) is not None:
+ deflated[db_property] = property.deflate(properties[name], obj)
+ elif property.has_default:
+ deflated[db_property] = property.deflate(property.default_value(), obj)
+ elif property.required:
+ raise RequiredProperty(name, cls)
+ elif not skip_empty:
+ deflated[db_property] = None
+ return deflated
+
+ @classmethod
+ def defined_properties(cls, aliases=True, properties=True, rels=True):
+ from neomodel.sync_.relationship_manager import RelationshipDefinition
+
+ props = {}
+ for baseclass in reversed(cls.__mro__):
+ props.update(
+ dict(
+ (name, property)
+ for name, property in vars(baseclass).items()
+ if (aliases and isinstance(property, AliasProperty))
+ or (
+ properties
+ and isinstance(property, Property)
+ and not isinstance(property, AliasProperty)
+ )
+ or (rels and isinstance(property, RelationshipDefinition))
+ )
+ )
+ return props
diff --git a/neomodel/relationship.py b/neomodel/sync_/relationship.py
similarity index 71%
rename from neomodel/relationship.py
rename to neomodel/sync_/relationship.py
index 8df56c47..0a199575 100644
--- a/neomodel/relationship.py
+++ b/neomodel/sync_/relationship.py
@@ -1,8 +1,9 @@
-import warnings
+from neomodel.hooks import hooks
+from neomodel.properties import Property
+from neomodel.sync_.core import db
+from neomodel.sync_.property_manager import PropertyManager
-from .core import db
-from .hooks import hooks
-from .properties import Property, PropertyManager
+ELEMENT_ID_MIGRATION_NOTICE = "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()."
class RelationshipMeta(type):
@@ -50,57 +51,42 @@ def __init__(self, *args, **kwargs):
@property
def element_id(self):
- return (
- int(self.element_id_property)
- if db.database_version.startswith("4")
- else self.element_id_property
- )
+ if hasattr(self, "element_id_property"):
+ return self.element_id_property
@property
def _start_node_element_id(self):
- return (
- int(self._start_node_element_id_property)
- if db.database_version.startswith("4")
- else self._start_node_element_id_property
- )
+ if hasattr(self, "_start_node_element_id_property"):
+ return self._start_node_element_id_property
@property
def _end_node_element_id(self):
- return (
- int(self._end_node_element_id_property)
- if db.database_version.startswith("4")
- else self._end_node_element_id_property
- )
+ if hasattr(self, "_end_node_element_id_property"):
+ return self._end_node_element_id_property
# Version 4.4 support - id is deprecated in version 5.x
@property
def id(self):
try:
return int(self.element_id_property)
- except (TypeError, ValueError):
- raise ValueError(
- "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()."
- )
+ except (TypeError, ValueError) as exc:
+ raise ValueError(ELEMENT_ID_MIGRATION_NOTICE) from exc
# Version 4.4 support - id is deprecated in version 5.x
@property
def _start_node_id(self):
try:
return int(self._start_node_element_id_property)
- except (TypeError, ValueError):
- raise ValueError(
- "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()."
- )
+ except (TypeError, ValueError) as exc:
+ raise ValueError(ELEMENT_ID_MIGRATION_NOTICE) from exc
# Version 4.4 support - id is deprecated in version 5.x
@property
def _end_node_id(self):
try:
return int(self._end_node_element_id_property)
- except (TypeError, ValueError):
- raise ValueError(
- "id is deprecated in Neo4j version 5, please migrate to element_id. If you use the id in a Cypher query, replace id() by elementId()."
- )
+ except (TypeError, ValueError) as exc:
+ raise ValueError(ELEMENT_ID_MIGRATION_NOTICE) from exc
@hooks
def save(self):
@@ -112,7 +98,7 @@ def save(self):
props = self.deflate(self.__properties__)
query = f"MATCH ()-[r]->() WHERE {db.get_id_method()}(r)=$self "
query += "".join([f" SET r.{key} = ${key}" for key in props])
- props["self"] = self.element_id
+ props["self"] = db.parse_element_id(self.element_id)
db.cypher_query(query, props)
@@ -124,16 +110,16 @@ def start_node(self):
:return: StructuredNode
"""
- test = db.cypher_query(
+ results = db.cypher_query(
f"""
MATCH (aNode)
WHERE {db.get_id_method()}(aNode)=$start_node_element_id
RETURN aNode
""",
- {"start_node_element_id": self._start_node_element_id},
+ {"start_node_element_id": db.parse_element_id(self._start_node_element_id)},
resolve_objects=True,
)
- return test[0][0][0]
+ return results[0][0][0]
def end_node(self):
"""
@@ -141,15 +127,16 @@ def end_node(self):
:return: StructuredNode
"""
- return db.cypher_query(
+ results = db.cypher_query(
f"""
MATCH (aNode)
WHERE {db.get_id_method()}(aNode)=$end_node_element_id
RETURN aNode
""",
- {"end_node_element_id": self._end_node_element_id},
+ {"end_node_element_id": db.parse_element_id(self._end_node_element_id)},
resolve_objects=True,
- )[0][0][0]
+ )
+ return results[0][0][0]
@classmethod
def inflate(cls, rel):
diff --git a/neomodel/relationship_manager.py b/neomodel/sync_/relationship_manager.py
similarity index 93%
rename from neomodel/relationship_manager.py
rename to neomodel/sync_/relationship_manager.py
index 1e9cf79e..8b8ae96a 100644
--- a/neomodel/relationship_manager.py
+++ b/neomodel/sync_/relationship_manager.py
@@ -3,19 +3,17 @@
import sys
from importlib import import_module
-from .core import db
-from .exceptions import NotConnected, RelationshipClassRedefined
-from .match import (
+from neomodel.exceptions import NotConnected, RelationshipClassRedefined
+from neomodel.sync_.core import db
+from neomodel.sync_.match import NodeSet, Traversal, _rel_helper, _rel_merge_helper
+from neomodel.sync_.relationship import StructuredRel
+from neomodel.util import (
EITHER,
INCOMING,
OUTGOING,
- NodeSet,
- Traversal,
- _rel_helper,
- _rel_merge_helper,
+ _get_node_properties,
+ enumerate_traceback,
)
-from .relationship import StructuredRel
-from .util import _get_node_properties, enumerate_traceback
# basestring python 3.x fallback
try:
@@ -66,6 +64,9 @@ def __str__(self):
return f"{self.description} in {direction} direction of type {self.definition['relation_type']} on node ({self.source.element_id}) of class '{self.source_class.__name__}'"
+ def __await__(self):
+ return self.all().__await__()
+
def _check_node(self, obj):
"""check for valid node i.e correct class and is saved"""
if not issubclass(type(obj), self.definition["node_class"]):
@@ -124,13 +125,14 @@ def connect(self, node, properties=None):
"MERGE" + new_rel
)
- params["them"] = node.element_id
+ params["them"] = db.parse_element_id(node.element_id)
if not rel_model:
self.source.cypher(q, params)
return True
- rel_ = self.source.cypher(q + " RETURN r", params)[0][0][0]
+ results = self.source.cypher(q + " RETURN r", params)
+ rel_ = results[0][0][0]
rel_instance = self._set_start_end_cls(rel_model.inflate(rel_), node)
if hasattr(rel_instance, "post_save"):
@@ -166,7 +168,8 @@ def relationship(self, node):
+ my_rel
+ f" WHERE {db.get_id_method()}(them)=$them and {db.get_id_method()}(us)=$self RETURN r LIMIT 1"
)
- rels = self.source.cypher(q, {"them": node.element_id})[0]
+ results = self.source.cypher(q, {"them": db.parse_element_id(node.element_id)})
+ rels = results[0]
if not rels:
return
@@ -186,7 +189,8 @@ def all_relationships(self, node):
my_rel = _rel_helper(lhs="us", rhs="them", ident="r", **self.definition)
q = f"MATCH {my_rel} WHERE {db.get_id_method()}(them)=$them and {db.get_id_method()}(us)=$self RETURN r "
- rels = self.source.cypher(q, {"them": node.element_id})[0]
+ results = self.source.cypher(q, {"them": db.parse_element_id(node.element_id)})
+ rels = results[0]
if not rels:
return []
@@ -223,12 +227,14 @@ def reconnect(self, old_node, new_node):
old_rel = _rel_helper(lhs="us", rhs="old", ident="r", **self.definition)
# get list of properties on the existing rel
+ old_node_element_id = db.parse_element_id(old_node.element_id)
+ new_node_element_id = db.parse_element_id(new_node.element_id)
result, _ = self.source.cypher(
f"""
MATCH (us), (old) WHERE {db.get_id_method()}(us)=$self and {db.get_id_method()}(old)=$old
MATCH {old_rel} RETURN r
""",
- {"old": old_node.element_id},
+ {"old": old_node_element_id},
)
if result:
node_properties = _get_node_properties(result[0][0])
@@ -249,7 +255,7 @@ def reconnect(self, old_node, new_node):
q += "".join([f" SET r2.{prop} = r.{prop}" for prop in existing_properties])
q += " WITH r DELETE r"
- self.source.cypher(q, {"old": old_node.element_id, "new": new_node.element_id})
+ self.source.cypher(q, {"old": old_node_element_id, "new": new_node_element_id})
@check_source
def disconnect(self, node):
@@ -264,7 +270,7 @@ def disconnect(self, node):
MATCH (a), (b) WHERE {db.get_id_method()}(a)=$self and {db.get_id_method()}(b)=$them
MATCH {rel} DELETE r
"""
- self.source.cypher(q, {"them": node.element_id})
+ self.source.cypher(q, {"them": db.parse_element_id(node.element_id)})
@check_source
def disconnect_all(self):
@@ -345,7 +351,8 @@ def single(self):
:return: StructuredNode
"""
try:
- return self[0]
+ rels = self
+ return rels[0]
except IndexError:
pass
diff --git a/neomodel/util.py b/neomodel/util.py
index 74e88250..1b88e407 100644
--- a/neomodel/util.py
+++ b/neomodel/util.py
@@ -1,630 +1,6 @@
-import logging
-import os
-import sys
-import time
import warnings
-from threading import local
-from typing import Optional, Sequence
-from urllib.parse import quote, unquote, urlparse
-from neo4j import DEFAULT_DATABASE, Driver, GraphDatabase, basic_auth
-from neo4j.api import Bookmarks
-from neo4j.exceptions import ClientError, ServiceUnavailable, SessionExpired
-from neo4j.graph import Node, Path, Relationship
-
-from neomodel import config, core
-from neomodel.exceptions import (
- ConstraintValidationFailed,
- FeatureNotSupported,
- NodeClassNotDefined,
- RelationshipClassNotDefined,
- UniqueProperty,
-)
-
-logger = logging.getLogger(__name__)
-
-
-# make sure the connection url has been set prior to executing the wrapped function
-def ensure_connection(func):
- def wrapper(self, *args, **kwargs):
- # Sort out where to find url
- if hasattr(self, "db"):
- _db = self.db
- else:
- _db = self
-
- if not _db.driver:
- if hasattr(config, "DRIVER") and config.DRIVER:
- _db.set_connection(driver=config.DRIVER)
- elif config.DATABASE_URL:
- _db.set_connection(url=config.DATABASE_URL)
-
- return func(self, *args, **kwargs)
-
- return wrapper
-
-
-def change_neo4j_password(db, user, new_password):
- db.cypher_query(f"ALTER USER {user} SET PASSWORD '{new_password}'")
-
-
-def clear_neo4j_database(db, clear_constraints=False, clear_indexes=False):
- db.cypher_query(
- """
- MATCH (a)
- CALL { WITH a DETACH DELETE a }
- IN TRANSACTIONS OF 5000 rows
- """
- )
- if clear_constraints:
- core.drop_constraints()
- if clear_indexes:
- core.drop_indexes()
-
-
-class Database(local):
- """
- A singleton object via which all operations from neomodel to the Neo4j backend are handled with.
- """
-
- _NODE_CLASS_REGISTRY = {}
-
- def __init__(self):
- self._active_transaction = None
- self.url = None
- self.driver = None
- self._session = None
- self._pid = None
- self._database_name = DEFAULT_DATABASE
- self.protocol_version = None
- self._database_version = None
- self._database_edition = None
- self.impersonated_user = None
-
- def set_connection(self, url: str = None, driver: Driver = None):
- """
- Sets the connection up and relevant internal. This can be done using a Neo4j URL or a driver instance.
-
- Args:
- url (str): Optionally, Neo4j URL in the form protocol://username:password@hostname:port/dbname.
- When provided, a Neo4j driver instance will be created by neomodel.
-
- driver (neo4j.Driver): Optionally, a pre-created driver instance.
- When provided, neomodel will not create a driver instance but use this one instead.
- """
- if driver:
- self.driver = driver
- if hasattr(config, "DATABASE_NAME") and config.DATABASE_NAME:
- self._database_name = config.DATABASE_NAME
- elif url:
- self._parse_driver_from_url(url=url)
-
- self._pid = os.getpid()
- self._active_transaction = None
- # Set to default database if it hasn't been set before
- if self._database_name is None:
- self._database_name = DEFAULT_DATABASE
-
- # Getting the information about the database version requires a connection to the database
- self._database_version = None
- self._database_edition = None
- self._update_database_version()
-
- def _parse_driver_from_url(self, url: str) -> None:
- """Parse the driver information from the given URL and initialize the driver.
-
- Args:
- url (str): The URL to parse.
-
- Raises:
- ValueError: If the URL format is not as expected.
-
- Returns:
- None - Sets the driver and database_name as class properties
- """
- p_start = url.replace(":", "", 1).find(":") + 2
- p_end = url.rfind("@")
- password = url[p_start:p_end]
- url = url.replace(password, quote(password))
- parsed_url = urlparse(url)
-
- valid_schemas = [
- "bolt",
- "bolt+s",
- "bolt+ssc",
- "bolt+routing",
- "neo4j",
- "neo4j+s",
- "neo4j+ssc",
- ]
-
- if parsed_url.netloc.find("@") > -1 and parsed_url.scheme in valid_schemas:
- credentials, hostname = parsed_url.netloc.rsplit("@", 1)
- username, password = credentials.split(":")
- password = unquote(password)
- database_name = parsed_url.path.strip("/")
- else:
- raise ValueError(
- f"Expecting url format: bolt://user:password@localhost:7687 got {url}"
- )
-
- options = {
- "auth": basic_auth(username, password),
- "connection_acquisition_timeout": config.CONNECTION_ACQUISITION_TIMEOUT,
- "connection_timeout": config.CONNECTION_TIMEOUT,
- "keep_alive": config.KEEP_ALIVE,
- "max_connection_lifetime": config.MAX_CONNECTION_LIFETIME,
- "max_connection_pool_size": config.MAX_CONNECTION_POOL_SIZE,
- "max_transaction_retry_time": config.MAX_TRANSACTION_RETRY_TIME,
- "resolver": config.RESOLVER,
- "user_agent": config.USER_AGENT,
- }
-
- if "+s" not in parsed_url.scheme:
- options["encrypted"] = config.ENCRYPTED
- options["trusted_certificates"] = config.TRUSTED_CERTIFICATES
-
- self.driver = GraphDatabase.driver(
- parsed_url.scheme + "://" + hostname, **options
- )
- self.url = url
- # The database name can be provided through the url or the config
- if database_name == "":
- if hasattr(config, "DATABASE_NAME") and config.DATABASE_NAME:
- self._database_name = config.DATABASE_NAME
- else:
- self._database_name = database_name
-
- def close_connection(self):
- """
- Closes the currently open driver.
- The driver should always be closed at the end of the application's lifecyle.
- """
- self._database_version = None
- self._database_edition = None
- self._database_name = None
- self.driver.close()
- self.driver = None
-
- @property
- def database_version(self):
- if self._database_version is None:
- self._update_database_version()
-
- return self._database_version
-
- @property
- def database_edition(self):
- if self._database_edition is None:
- self._update_database_version()
-
- return self._database_edition
-
- @property
- def transaction(self):
- """
- Returns the current transaction object
- """
- return TransactionProxy(self)
-
- @property
- def write_transaction(self):
- return TransactionProxy(self, access_mode="WRITE")
-
- @property
- def read_transaction(self):
- return TransactionProxy(self, access_mode="READ")
-
- def impersonate(self, user: str) -> "ImpersonationHandler":
- """All queries executed within this context manager will be executed as impersonated user
-
- Args:
- user (str): User to impersonate
-
- Returns:
- ImpersonationHandler: Context manager to set/unset the user to impersonate
- """
- if self.database_edition != "enterprise":
- raise FeatureNotSupported(
- "Impersonation is only available in Neo4j Enterprise edition"
- )
- return ImpersonationHandler(self, impersonated_user=user)
-
- @ensure_connection
- def begin(self, access_mode=None, **parameters):
- """
- Begins a new transaction. Raises SystemError if a transaction is already active.
- """
- if (
- hasattr(self, "_active_transaction")
- and self._active_transaction is not None
- ):
- raise SystemError("Transaction in progress")
- self._session = self.driver.session(
- default_access_mode=access_mode,
- database=self._database_name,
- impersonated_user=self.impersonated_user,
- **parameters,
- )
- self._active_transaction = self._session.begin_transaction()
-
- @ensure_connection
- def commit(self):
- """
- Commits the current transaction and closes its session
-
- :return: last_bookmarks
- """
- try:
- self._active_transaction.commit()
- last_bookmarks = self._session.last_bookmarks()
- finally:
- # In case when something went wrong during
- # committing changes to the database
- # we have to close an active transaction and session.
- self._active_transaction.close()
- self._session.close()
- self._active_transaction = None
- self._session = None
-
- return last_bookmarks
-
- @ensure_connection
- def rollback(self):
- """
- Rolls back the current transaction and closes its session
- """
- try:
- self._active_transaction.rollback()
- finally:
- # In case when something went wrong during changes rollback,
- # we have to close an active transaction and session
- self._active_transaction.close()
- self._session.close()
- self._active_transaction = None
- self._session = None
-
- def _update_database_version(self):
- """
- Updates the database server information when it is required
- """
- try:
- results = self.cypher_query(
- "CALL dbms.components() yield versions, edition return versions[0], edition"
- )
- self._database_version = results[0][0][0]
- self._database_edition = results[0][0][1]
- except ServiceUnavailable:
- # The database server is not running yet
- pass
-
- def _object_resolution(self, object_to_resolve):
- """
- Performs in place automatic object resolution on a result
- returned by cypher_query.
-
- The function operates recursively in order to be able to resolve Nodes
- within nested list structures and Path objects. Not meant to be called
- directly, used primarily by _result_resolution.
-
- :param object_to_resolve: A result as returned by cypher_query.
- :type Any:
-
- :return: An instantiated object.
- """
- # Below is the original comment that came with the code extracted in
- # this method. It is not very clear but I decided to keep it just in
- # case
- #
- #
- # For some reason, while the type of `a_result_attribute[1]`
- # as reported by the neo4j driver is `Node` for Node-type data
- # retrieved from the database.
- # When the retrieved data are Relationship-Type,
- # the returned type is `abc.[REL_LABEL]` which is however
- # a descendant of Relationship.
- # Consequently, the type checking was changed for both
- # Node, Relationship objects
- if isinstance(object_to_resolve, Node):
- return self._NODE_CLASS_REGISTRY[
- frozenset(object_to_resolve.labels)
- ].inflate(object_to_resolve)
-
- if isinstance(object_to_resolve, Relationship):
- rel_type = frozenset([object_to_resolve.type])
- return self._NODE_CLASS_REGISTRY[rel_type].inflate(object_to_resolve)
-
- if isinstance(object_to_resolve, Path):
- from .path import NeomodelPath
-
- return NeomodelPath(object_to_resolve)
-
- if isinstance(object_to_resolve, list):
- return self._result_resolution([object_to_resolve])
-
- return object_to_resolve
-
- def _result_resolution(self, result_list):
- """
- Performs in place automatic object resolution on a set of results
- returned by cypher_query.
-
- The function operates recursively in order to be able to resolve Nodes
- within nested list structures. Not meant to be called directly,
- used primarily by cypher_query.
-
- :param result_list: A list of results as returned by cypher_query.
- :type list:
-
- :return: A list of instantiated objects.
- """
-
- # Object resolution occurs in-place
- for a_result_item in enumerate(result_list):
- for a_result_attribute in enumerate(a_result_item[1]):
- try:
- # Primitive types should remain primitive types,
- # Nodes to be resolved to native objects
- resolved_object = a_result_attribute[1]
-
- resolved_object = self._object_resolution(resolved_object)
-
- result_list[a_result_item[0]][
- a_result_attribute[0]
- ] = resolved_object
-
- except KeyError as exc:
- # Not being able to match the label set of a node with a known object results
- # in a KeyError in the internal dictionary used for resolution. If it is impossible
- # to match, then raise an exception with more details about the error.
- if isinstance(a_result_attribute[1], Node):
- raise NodeClassNotDefined(
- a_result_attribute[1], self._NODE_CLASS_REGISTRY
- ) from exc
-
- if isinstance(a_result_attribute[1], Relationship):
- raise RelationshipClassNotDefined(
- a_result_attribute[1], self._NODE_CLASS_REGISTRY
- ) from exc
-
- return result_list
-
- @ensure_connection
- def cypher_query(
- self,
- query,
- params=None,
- handle_unique=True,
- retry_on_session_expire=False,
- resolve_objects=False,
- ):
- """
- Runs a query on the database and returns a list of results and their headers.
-
- :param query: A CYPHER query
- :type: str
- :param params: Dictionary of parameters
- :type: dict
- :param handle_unique: Whether or not to raise UniqueProperty exception on Cypher's ConstraintValidation errors
- :type: bool
- :param retry_on_session_expire: Whether or not to attempt the same query again if the transaction has expired.
- If you use neomodel with your own driver, you must catch SessionExpired exceptions yourself and retry with a new driver instance.
- :type: bool
- :param resolve_objects: Whether to attempt to resolve the returned nodes to data model objects automatically
- :type: bool
- """
-
- if self._active_transaction:
- # Use current session is a transaction is currently active
- results, meta = self._run_cypher_query(
- self._active_transaction,
- query,
- params,
- handle_unique,
- retry_on_session_expire,
- resolve_objects,
- )
- else:
- # Otherwise create a new session in a with to dispose of it after it has been run
- with self.driver.session(
- database=self._database_name, impersonated_user=self.impersonated_user
- ) as session:
- results, meta = self._run_cypher_query(
- session,
- query,
- params,
- handle_unique,
- retry_on_session_expire,
- resolve_objects,
- )
-
- return results, meta
-
- def _run_cypher_query(
- self,
- session,
- query,
- params,
- handle_unique,
- retry_on_session_expire,
- resolve_objects,
- ):
- try:
- # Retrieve the data
- start = time.time()
- response = session.run(query, params)
- results, meta = [list(r.values()) for r in response], response.keys()
- end = time.time()
-
- if resolve_objects:
- # Do any automatic resolution required
- results = self._result_resolution(results)
-
- except ClientError as e:
- if e.code == "Neo.ClientError.Schema.ConstraintValidationFailed":
- if "already exists with label" in e.message and handle_unique:
- raise UniqueProperty(e.message) from e
-
- raise ConstraintValidationFailed(e.message) from e
- exc_info = sys.exc_info()
- raise exc_info[1].with_traceback(exc_info[2])
- except SessionExpired:
- if retry_on_session_expire:
- self.set_connection(url=self.url)
- return self.cypher_query(
- query=query,
- params=params,
- handle_unique=handle_unique,
- retry_on_session_expire=False,
- )
- raise
-
- tte = end - start
- if os.environ.get("NEOMODEL_CYPHER_DEBUG", False) and tte > float(
- os.environ.get("NEOMODEL_SLOW_QUERIES", 0)
- ):
- logger.debug(
- "query: "
- + query
- + "\nparams: "
- + repr(params)
- + f"\ntook: {tte:.2g}s\n"
- )
-
- return results, meta
-
- def get_id_method(self) -> str:
- if self.database_version.startswith("4"):
- return "id"
- else:
- return "elementId"
-
- def list_indexes(self, exclude_token_lookup=False) -> Sequence[dict]:
- """Returns all indexes existing in the database
-
- Arguments:
- exclude_token_lookup[bool]: Exclude automatically create token lookup indexes
-
- Returns:
- Sequence[dict]: List of dictionaries, each entry being an index definition
- """
- indexes, meta_indexes = self.cypher_query("SHOW INDEXES")
- indexes_as_dict = [dict(zip(meta_indexes, row)) for row in indexes]
-
- if exclude_token_lookup:
- indexes_as_dict = [
- obj for obj in indexes_as_dict if obj["type"] != "LOOKUP"
- ]
-
- return indexes_as_dict
-
- def list_constraints(self) -> Sequence[dict]:
- """Returns all constraints existing in the database
-
- Returns:
- Sequence[dict]: List of dictionaries, each entry being a constraint definition
- """
- constraints, meta_constraints = self.cypher_query("SHOW CONSTRAINTS")
- constraints_as_dict = [dict(zip(meta_constraints, row)) for row in constraints]
-
- return constraints_as_dict
-
- def version_is_higher_than(self, version_tag: str) -> bool:
- """Returns true if the database version is higher or equal to a given tag
-
- Args:
- version_tag (str): The version to compare against
-
- Returns:
- bool: True if the database version is higher or equal to the given version
- """
- return version_tag_to_integer(self.database_version) >= version_tag_to_integer(
- version_tag
- )
-
- def edition_is_enterprise(self) -> bool:
- """Returns true if the database edition is enterprise
-
- Returns:
- bool: True if the database edition is enterprise
- """
- return self.database_edition == "enterprise"
-
-
-class TransactionProxy:
- bookmarks: Optional[Bookmarks] = None
-
- def __init__(self, db, access_mode=None):
- self.db = db
- self.access_mode = access_mode
-
- @ensure_connection
- def __enter__(self):
- self.db.begin(access_mode=self.access_mode, bookmarks=self.bookmarks)
- self.bookmarks = None
- return self
-
- def __exit__(self, exc_type, exc_value, traceback):
- if exc_value:
- self.db.rollback()
-
- if (
- exc_type is ClientError
- and exc_value.code == "Neo.ClientError.Schema.ConstraintValidationFailed"
- ):
- raise UniqueProperty(exc_value.message)
-
- if not exc_value:
- self.last_bookmark = self.db.commit()
-
- def __call__(self, func):
- def wrapper(*args, **kwargs):
- with self:
- return func(*args, **kwargs)
-
- return wrapper
-
- @property
- def with_bookmark(self):
- return BookmarkingTransactionProxy(self.db, self.access_mode)
-
-
-class ImpersonationHandler:
- def __init__(self, db, impersonated_user: str):
- self.db = db
- self.impersonated_user = impersonated_user
-
- def __enter__(self):
- self.db.impersonated_user = self.impersonated_user
- return self
-
- def __exit__(self, exception_type, exception_value, exception_traceback):
- self.db.impersonated_user = None
-
- print("\nException type:", exception_type)
- print("\nException value:", exception_value)
- print("\nTraceback:", exception_traceback)
-
- def __call__(self, func):
- def wrapper(*args, **kwargs):
- with self:
- return func(*args, **kwargs)
-
- return wrapper
-
-
-class BookmarkingTransactionProxy(TransactionProxy):
- def __call__(self, func):
- def wrapper(*args, **kwargs):
- self.bookmarks = kwargs.pop("bookmarks", None)
-
- with self:
- result = func(*args, **kwargs)
- self.last_bookmark = None
-
- return result, self.last_bookmark
-
- return wrapper
+OUTGOING, INCOMING, EITHER = 1, -1, 0
def deprecated(message):
diff --git a/pyproject.toml b/pyproject.toml
index f15c28b8..acc36850 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,7 +6,6 @@ authors = [
maintainers = [
{name = "Marius Conjeaud", email = "marius.conjeaud@outlook.com"},
{name = "Athanasios Anastasiou", email = "athanastasiou@gmail.com"},
- {name = "Cristina Escalante"},
]
description = "An object mapper for the neo4j graph database."
readme = "README.md"
@@ -35,7 +34,9 @@ changelog = "https://github.com/neo4j-contrib/neomodel/releases"
[project.optional-dependencies]
dev = [
+ "unasync",
"pytest>=7.1",
+ "pytest-asyncio",
"pytest-cov>=4.0",
"pre-commit",
"black",
@@ -61,7 +62,7 @@ testpaths = "test"
[tool.isort]
profile = 'black'
-src_paths = ['neomodel']
+src_paths = ['neomodel','test']
[tool.pylint.'MESSAGES CONTROL']
disable = 'missing-module-docstring,redefined-builtin,missing-class-docstring,missing-function-docstring,consider-using-f-string,line-too-long'
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 2e7a31f0..bf3fa116 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,6 +1,7 @@
# neomodel
-e .[pandas,numpy]
+unasync>=0.5.0
pytest>=7.1
pytest-cov>=4.0
pre-commit
diff --git a/test/_async_compat/__init__.py b/test/_async_compat/__init__.py
new file mode 100644
index 00000000..342678c3
--- /dev/null
+++ b/test/_async_compat/__init__.py
@@ -0,0 +1,17 @@
+from .mark_decorator import (
+ AsyncTestDecorators,
+ TestDecorators,
+ mark_async_session_auto_fixture,
+ mark_async_test,
+ mark_sync_session_auto_fixture,
+ mark_sync_test,
+)
+
+__all__ = [
+ "AsyncTestDecorators",
+ "mark_async_test",
+ "mark_sync_test",
+ "TestDecorators",
+ "mark_async_session_auto_fixture",
+ "mark_sync_session_auto_fixture",
+]
diff --git a/test/_async_compat/mark_decorator.py b/test/_async_compat/mark_decorator.py
new file mode 100644
index 00000000..a8c5eead
--- /dev/null
+++ b/test/_async_compat/mark_decorator.py
@@ -0,0 +1,21 @@
+import pytest
+import pytest_asyncio
+
+mark_async_test = pytest.mark.asyncio
+mark_async_session_auto_fixture = pytest_asyncio.fixture(scope="session", autouse=True)
+mark_sync_session_auto_fixture = pytest.fixture(scope="session", autouse=True)
+
+
+def mark_sync_test(f):
+ return f
+
+
+class AsyncTestDecorators:
+ mark_async_only_test = mark_async_test
+
+
+class TestDecorators:
+ @staticmethod
+ def mark_async_only_test(f):
+ skip_decorator = pytest.mark.skip("Async only test")
+ return skip_decorator(f)
diff --git a/test/async_/__init__.py b/test/async_/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/test/async_/conftest.py b/test/async_/conftest.py
new file mode 100644
index 00000000..493ff12c
--- /dev/null
+++ b/test/async_/conftest.py
@@ -0,0 +1,60 @@
+import asyncio
+import os
+import warnings
+from test._async_compat import mark_async_session_auto_fixture
+
+import pytest
+
+from neomodel import adb, config
+
+
+@mark_async_session_auto_fixture
+async def setup_neo4j_session(request, event_loop):
+ """
+ Provides initial connection to the database and sets up the rest of the test suite
+
+    :param request: The pytest request object. Please see the `pytest request fixture documentation <https://docs.pytest.org/en/latest/reference/reference.html#request>`_
+    :type request: FixtureRequest. For more information please see the `pytest fixtures documentation <https://docs.pytest.org/en/latest/how-to/fixtures.html>`_
+ """
+
+ warnings.simplefilter("default")
+
+ config.DATABASE_URL = os.environ.get(
+ "NEO4J_BOLT_URL", "bolt://neo4j:foobarbaz@localhost:7687"
+ )
+
+ # Clear the database if required
+ database_is_populated, _ = await adb.cypher_query(
+ "MATCH (a) return count(a)>0 as database_is_populated"
+ )
+ if database_is_populated[0][0] and not request.config.getoption("resetdb"):
+ raise SystemError(
+ "Please note: The database seems to be populated.\n\tEither delete all nodes and edges manually, or set the --resetdb parameter when calling pytest\n\n\tpytest --resetdb."
+ )
+
+ await adb.clear_neo4j_database(clear_constraints=True, clear_indexes=True)
+
+ await adb.install_all_labels()
+
+ await adb.cypher_query(
+ "CREATE OR REPLACE USER troygreene SET PASSWORD 'foobarbaz' CHANGE NOT REQUIRED"
+ )
+ db_edition = await adb.database_edition
+ if db_edition == "enterprise":
+ await adb.cypher_query("GRANT ROLE publisher TO troygreene")
+ await adb.cypher_query("GRANT IMPERSONATE (troygreene) ON DBMS TO admin")
+
+
+@mark_async_session_auto_fixture
+async def cleanup(event_loop):
+ yield
+ await adb.close_connection()
+
+
+@pytest.fixture(scope="session")
+def event_loop():
+ """Overrides pytest default function scoped event loop"""
+ policy = asyncio.get_event_loop_policy()
+ loop = policy.new_event_loop()
+ yield loop
+ loop.close()
diff --git a/test/async_/test_alias.py b/test/async_/test_alias.py
new file mode 100644
index 00000000..b2b9a3a8
--- /dev/null
+++ b/test/async_/test_alias.py
@@ -0,0 +1,34 @@
+from test._async_compat import mark_async_test
+
+from neomodel import AliasProperty, AsyncStructuredNode, StringProperty
+
+
+class MagicProperty(AliasProperty):
+ def setup(self):
+ self.owner.setup_hook_called = True
+
+
+class AliasTestNode(AsyncStructuredNode):
+ name = StringProperty(unique_index=True)
+ full_name = AliasProperty(to="name")
+ long_name = MagicProperty(to="name")
+
+
+@mark_async_test
+async def test_property_setup_hook():
+ timmy = await AliasTestNode(long_name="timmy").save()
+ assert AliasTestNode.setup_hook_called
+ assert timmy.name == "timmy"
+
+
+@mark_async_test
+async def test_alias():
+ jim = await AliasTestNode(full_name="Jim").save()
+ assert jim.name == "Jim"
+ assert jim.full_name == "Jim"
+ assert "full_name" not in AliasTestNode.deflate(jim.__properties__)
+ jim = await AliasTestNode.nodes.get(full_name="Jim")
+ assert jim
+ assert jim.name == "Jim"
+ assert jim.full_name == "Jim"
+ assert "full_name" not in AliasTestNode.deflate(jim.__properties__)
diff --git a/test/async_/test_batch.py b/test/async_/test_batch.py
new file mode 100644
index 00000000..653dce0d
--- /dev/null
+++ b/test/async_/test_batch.py
@@ -0,0 +1,138 @@
+from test._async_compat import mark_async_test
+
+from pytest import raises
+
+from neomodel import (
+ AsyncRelationshipFrom,
+ AsyncRelationshipTo,
+ AsyncStructuredNode,
+ IntegerProperty,
+ StringProperty,
+ UniqueIdProperty,
+ config,
+)
+from neomodel._async_compat.util import AsyncUtil
+from neomodel.exceptions import DeflateError, UniqueProperty
+
+config.AUTO_INSTALL_LABELS = True
+
+
+class UniqueUser(AsyncStructuredNode):
+ uid = UniqueIdProperty()
+ name = StringProperty()
+ age = IntegerProperty()
+
+
+@mark_async_test
+async def test_unique_id_property_batch():
+ users = await UniqueUser.create(
+ {"name": "bob", "age": 2}, {"name": "ben", "age": 3}
+ )
+
+ assert users[0].uid != users[1].uid
+
+ users = await UniqueUser.get_or_create(
+ {"uid": users[0].uid}, {"name": "bill", "age": 4}
+ )
+
+ assert users[0].name == "bob"
+ assert users[1].uid
+
+
+class Customer(AsyncStructuredNode):
+ email = StringProperty(unique_index=True, required=True)
+ age = IntegerProperty(index=True)
+
+
+@mark_async_test
+async def test_batch_create():
+ users = await Customer.create(
+ {"email": "jim1@aol.com", "age": 11},
+ {"email": "jim2@aol.com", "age": 7},
+ {"email": "jim3@aol.com", "age": 9},
+ {"email": "jim4@aol.com", "age": 7},
+ {"email": "jim5@aol.com", "age": 99},
+ )
+ assert len(users) == 5
+ assert users[0].age == 11
+ assert users[1].age == 7
+ assert users[1].email == "jim2@aol.com"
+ assert await Customer.nodes.get(email="jim1@aol.com")
+
+
+@mark_async_test
+async def test_batch_create_or_update():
+ users = await Customer.create_or_update(
+ {"email": "merge1@aol.com", "age": 11},
+ {"email": "merge2@aol.com"},
+ {"email": "merge3@aol.com", "age": 1},
+ {"email": "merge2@aol.com", "age": 2},
+ )
+ assert len(users) == 4
+ assert users[1] == users[3]
+ merge_1: Customer = await Customer.nodes.get(email="merge1@aol.com")
+ assert merge_1.age == 11
+
+ more_users = await Customer.create_or_update(
+ {"email": "merge1@aol.com", "age": 22},
+ {"email": "merge4@aol.com", "age": None},
+ )
+ assert len(more_users) == 2
+ assert users[0] == more_users[0]
+ merge_1 = await Customer.nodes.get(email="merge1@aol.com")
+ assert merge_1.age == 22
+
+
+@mark_async_test
+async def test_batch_validation():
+ # test validation in batch create
+ with raises(DeflateError):
+ await Customer.create(
+ {"email": "jim1@aol.com", "age": "x"},
+ )
+
+
+@mark_async_test
+async def test_batch_index_violation():
+ for u in await Customer.nodes:
+ await u.delete()
+
+ users = await Customer.create(
+ {"email": "jim6@aol.com", "age": 3},
+ )
+ assert users
+ with raises(UniqueProperty):
+ await Customer.create(
+ {"email": "jim6@aol.com", "age": 3},
+ {"email": "jim7@aol.com", "age": 5},
+ )
+
+ # not found
+ if AsyncUtil.is_async_code:
+ assert not await Customer.nodes.filter(email="jim7@aol.com").check_bool()
+ else:
+ assert not Customer.nodes.filter(email="jim7@aol.com")
+
+
+class Dog(AsyncStructuredNode):
+ name = StringProperty(required=True)
+ owner = AsyncRelationshipTo("Person", "owner")
+
+
+class Person(AsyncStructuredNode):
+ name = StringProperty(unique_index=True)
+ pets = AsyncRelationshipFrom("Dog", "owner")
+
+
+@mark_async_test
+async def test_get_or_create_with_rel():
+ create_bob = await Person.get_or_create({"name": "Bob"})
+ bob = create_bob[0]
+ bobs_gizmo = await Dog.get_or_create({"name": "Gizmo"}, relationship=bob.pets)
+
+ create_tim = await Person.get_or_create({"name": "Tim"})
+ tim = create_tim[0]
+ tims_gizmo = await Dog.get_or_create({"name": "Gizmo"}, relationship=tim.pets)
+
+ # not the same gizmo
+ assert bobs_gizmo[0] != tims_gizmo[0]
diff --git a/test/async_/test_cardinality.py b/test/async_/test_cardinality.py
new file mode 100644
index 00000000..4ce02ad4
--- /dev/null
+++ b/test/async_/test_cardinality.py
@@ -0,0 +1,187 @@
+from test._async_compat import mark_async_test
+
+from pytest import raises
+
+from neomodel import (
+ AsyncOne,
+ AsyncOneOrMore,
+ AsyncRelationshipTo,
+ AsyncStructuredNode,
+ AsyncZeroOrMore,
+ AsyncZeroOrOne,
+ AttemptedCardinalityViolation,
+ CardinalityViolation,
+ IntegerProperty,
+ StringProperty,
+ adb,
+)
+
+
+class HairDryer(AsyncStructuredNode):
+ version = IntegerProperty()
+
+
+class ScrewDriver(AsyncStructuredNode):
+ version = IntegerProperty()
+
+
+class Car(AsyncStructuredNode):
+ version = IntegerProperty()
+
+
+class Monkey(AsyncStructuredNode):
+ name = StringProperty()
+ dryers = AsyncRelationshipTo("HairDryer", "OWNS_DRYER", cardinality=AsyncZeroOrMore)
+ driver = AsyncRelationshipTo(
+ "ScrewDriver", "HAS_SCREWDRIVER", cardinality=AsyncZeroOrOne
+ )
+ car = AsyncRelationshipTo("Car", "HAS_CAR", cardinality=AsyncOneOrMore)
+ toothbrush = AsyncRelationshipTo(
+ "ToothBrush", "HAS_TOOTHBRUSH", cardinality=AsyncOne
+ )
+
+
+class ToothBrush(AsyncStructuredNode):
+ name = StringProperty()
+
+
+@mark_async_test
+async def test_cardinality_zero_or_more():
+ m = await Monkey(name="tim").save()
+ assert await m.dryers.all() == []
+ single_dryer = await m.driver.single()
+ assert single_dryer is None
+ h = await HairDryer(version=1).save()
+
+ await m.dryers.connect(h)
+ assert len(await m.dryers.all()) == 1
+ single_dryer = await m.dryers.single()
+ assert single_dryer.version == 1
+
+ await m.dryers.disconnect(h)
+ assert await m.dryers.all() == []
+ single_dryer = await m.driver.single()
+ assert single_dryer is None
+
+ h2 = await HairDryer(version=2).save()
+ await m.dryers.connect(h)
+ await m.dryers.connect(h2)
+ await m.dryers.disconnect_all()
+ assert await m.dryers.all() == []
+ single_dryer = await m.driver.single()
+ assert single_dryer is None
+
+
+@mark_async_test
+async def test_cardinality_zero_or_one():
+ m = await Monkey(name="bob").save()
+ assert await m.driver.all() == []
+ single_driver = await m.driver.single()
+ assert await m.driver.single() is None
+ h = await ScrewDriver(version=1).save()
+
+ await m.driver.connect(h)
+ assert len(await m.driver.all()) == 1
+ single_driver = await m.driver.single()
+ assert single_driver.version == 1
+
+ j = await ScrewDriver(version=2).save()
+ with raises(AttemptedCardinalityViolation):
+ await m.driver.connect(j)
+
+ await m.driver.reconnect(h, j)
+ single_driver = await m.driver.single()
+ assert single_driver.version == 2
+
+    # Forcing creation of a second ScrewDriver to go around
+    # AttemptedCardinalityViolation
+ await adb.cypher_query(
+ """
+ MATCH (m:Monkey WHERE m.name="bob")
+ CREATE (s:ScrewDriver {version:3})
+ WITH m, s
+ CREATE (m)-[:HAS_SCREWDRIVER]->(s)
+ """
+ )
+ with raises(
+ CardinalityViolation, match=r"CardinalityViolation: Expected: .*, got: 2."
+ ):
+ await m.driver.all()
+
+
+@mark_async_test
+async def test_cardinality_one_or_more():
+ m = await Monkey(name="jerry").save()
+
+ with raises(CardinalityViolation):
+ await m.car.all()
+
+ with raises(CardinalityViolation):
+ await m.car.single()
+
+ c = await Car(version=2).save()
+ await m.car.connect(c)
+ single_car = await m.car.single()
+ assert single_car.version == 2
+
+ cars = await m.car.all()
+ assert len(cars) == 1
+
+ with raises(AttemptedCardinalityViolation):
+ await m.car.disconnect(c)
+
+ d = await Car(version=3).save()
+ await m.car.connect(d)
+ cars = await m.car.all()
+ assert len(cars) == 2
+
+ await m.car.disconnect(d)
+ cars = await m.car.all()
+ assert len(cars) == 1
+
+
+@mark_async_test
+async def test_cardinality_one():
+ m = await Monkey(name="jerry").save()
+
+ with raises(
+ CardinalityViolation, match=r"CardinalityViolation: Expected: .*, got: none."
+ ):
+ await m.toothbrush.all()
+
+ with raises(CardinalityViolation):
+ await m.toothbrush.single()
+
+ b = await ToothBrush(name="Jim").save()
+ await m.toothbrush.connect(b)
+ single_toothbrush = await m.toothbrush.single()
+ assert single_toothbrush.name == "Jim"
+
+ x = await ToothBrush(name="Jim").save()
+ with raises(AttemptedCardinalityViolation):
+ await m.toothbrush.connect(x)
+
+ with raises(AttemptedCardinalityViolation):
+ await m.toothbrush.disconnect(b)
+
+ with raises(AttemptedCardinalityViolation):
+ await m.toothbrush.disconnect_all()
+
+ # Forcing creation of a second ToothBrush to go around
+ # AttemptedCardinalityViolation
+ await adb.cypher_query(
+ """
+ MATCH (m:Monkey WHERE m.name="jerry")
+ CREATE (t:ToothBrush {name:"Jim"})
+ WITH m, t
+ CREATE (m)-[:HAS_TOOTHBRUSH]->(t)
+ """
+ )
+ with raises(
+ CardinalityViolation, match=r"CardinalityViolation: Expected: .*, got: 2."
+ ):
+ await m.toothbrush.all()
+
+ jp = Monkey(name="Jean-Pierre")
+ with raises(ValueError, match="Node has not been saved cannot connect!"):
+ await jp.toothbrush.connect(b)
diff --git a/test/async_/test_connection.py b/test/async_/test_connection.py
new file mode 100644
index 00000000..a2eded7d
--- /dev/null
+++ b/test/async_/test_connection.py
@@ -0,0 +1,148 @@
+import os
+from test._async_compat import mark_async_test
+from test.conftest import NEO4J_PASSWORD, NEO4J_URL, NEO4J_USERNAME
+
+import pytest
+from neo4j import AsyncDriver, AsyncGraphDatabase
+from neo4j.debug import watch
+
+from neomodel import AsyncStructuredNode, StringProperty, adb, config
+
+
+@mark_async_test
+@pytest.fixture(autouse=True)
+async def setup_teardown():
+ yield
+ # Teardown actions after tests have run
+ # Reconnect to initial URL for potential subsequent tests
+ await adb.close_connection()
+ await adb.set_connection(url=config.DATABASE_URL)
+
+
+@pytest.fixture(autouse=True, scope="session")
+def neo4j_logging():
+ with watch("neo4j"):
+ yield
+
+
+@mark_async_test
+async def get_current_database_name() -> str:
+ """
+ Fetches the name of the currently active database from the Neo4j database.
+
+ Returns:
+ - str: The name of the current database.
+ """
+ results, meta = await adb.cypher_query("CALL db.info")
+ results_as_dict = [dict(zip(meta, row)) for row in results]
+
+ return results_as_dict[0]["name"]
+
+
+class Pastry(AsyncStructuredNode):
+ name = StringProperty(unique_index=True)
+
+
+@mark_async_test
+async def test_set_connection_driver_works():
+ # Verify that current connection is up
+ assert await Pastry(name="Chocolatine").save()
+ await adb.close_connection()
+
+ # Test connection using a driver
+ await adb.set_connection(
+ driver=AsyncGraphDatabase().driver(
+ NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)
+ )
+ )
+ assert await Pastry(name="Croissant").save()
+
+
+@mark_async_test
+async def test_config_driver_works():
+ # Verify that current connection is up
+ assert await Pastry(name="Chausson aux pommes").save()
+ await adb.close_connection()
+
+ # Test connection using a driver defined in config
+ driver: AsyncDriver = AsyncGraphDatabase().driver(
+ NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)
+ )
+
+ config.DRIVER = driver
+ assert await Pastry(name="Grignette").save()
+
+ # Clear config
+ # No need to close connection - pytest teardown will do it
+ config.DRIVER = None
+
+
+@mark_async_test
+async def test_connect_to_non_default_database():
+ if not await adb.edition_is_enterprise():
+ pytest.skip("Skipping test for community edition - no multi database in CE")
+ database_name = "pastries"
+ await adb.cypher_query(f"CREATE DATABASE {database_name} IF NOT EXISTS")
+ await adb.close_connection()
+
+ # Set database name in url - for url init only
+ await adb.set_connection(url=f"{config.DATABASE_URL}/{database_name}")
+ assert await get_current_database_name() == "pastries"
+
+ await adb.close_connection()
+
+ # Set database name in config - for both url and driver init
+ config.DATABASE_NAME = database_name
+
+ # url init
+ await adb.set_connection(url=config.DATABASE_URL)
+ assert await get_current_database_name() == "pastries"
+
+ await adb.close_connection()
+
+ # driver init
+ await adb.set_connection(
+ driver=AsyncGraphDatabase().driver(
+ NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)
+ )
+ )
+ assert await get_current_database_name() == "pastries"
+
+ # Clear config
+ # No need to close connection - pytest teardown will do it
+ config.DATABASE_NAME = None
+
+
+@mark_async_test
+@pytest.mark.parametrize(
+ "url", ["bolt://user:password", "http://user:password@localhost:7687"]
+)
+async def test_wrong_url_format(url):
+ with pytest.raises(
+ ValueError,
+ match=rf"Expecting url format: bolt://user:password@localhost:7687 got {url}",
+ ):
+ await adb.set_connection(url=url)
+
+
+@mark_async_test
+@pytest.mark.parametrize("protocol", ["neo4j+s", "neo4j+ssc", "bolt+s", "bolt+ssc"])
+async def test_connect_to_aura(protocol):
+ cypher_return = "hello world"
+ default_cypher_query = f"RETURN '{cypher_return}'"
+ await adb.close_connection()
+
+ await _set_connection(protocol=protocol)
+ result, _ = await adb.cypher_query(default_cypher_query)
+
+ assert len(result) > 0
+ assert result[0][0] == cypher_return
+
+
+async def _set_connection(protocol):
+ AURA_TEST_DB_USER = os.environ["AURA_TEST_DB_USER"]
+ AURA_TEST_DB_PASSWORD = os.environ["AURA_TEST_DB_PASSWORD"]
+ AURA_TEST_DB_HOSTNAME = os.environ["AURA_TEST_DB_HOSTNAME"]
+
+ database_url = f"{protocol}://{AURA_TEST_DB_USER}:{AURA_TEST_DB_PASSWORD}@{AURA_TEST_DB_HOSTNAME}"
+ await adb.set_connection(url=database_url)
diff --git a/test/async_/test_contrib/__init__.py b/test/async_/test_contrib/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/test/async_/test_contrib/test_semi_structured.py b/test/async_/test_contrib/test_semi_structured.py
new file mode 100644
index 00000000..3b88fcad
--- /dev/null
+++ b/test/async_/test_contrib/test_semi_structured.py
@@ -0,0 +1,35 @@
+from test._async_compat import mark_async_test
+
+from neomodel import IntegerProperty, StringProperty
+from neomodel.contrib import AsyncSemiStructuredNode
+
+
+class UserProf(AsyncSemiStructuredNode):
+ email = StringProperty(unique_index=True, required=True)
+ age = IntegerProperty(index=True)
+
+
+class Dummy(AsyncSemiStructuredNode):
+ pass
+
+
+@mark_async_test
+async def test_to_save_to_model_with_required_only():
+ u = UserProf(email="dummy@test.com")
+ assert await u.save()
+
+
+@mark_async_test
+async def test_save_to_model_with_extras():
+ u = UserProf(email="jim@test.com", age=3, bar=99)
+ u.foo = True
+ assert await u.save()
+ u = await UserProf.nodes.get(age=3)
+ assert u.foo is True
+ assert u.bar == 99
+
+
+@mark_async_test
+async def test_save_empty_model():
+ dummy = Dummy()
+ assert await dummy.save()
diff --git a/test/test_contrib/test_spatial_datatypes.py b/test/async_/test_contrib/test_spatial_datatypes.py
similarity index 100%
rename from test/test_contrib/test_spatial_datatypes.py
rename to test/async_/test_contrib/test_spatial_datatypes.py
diff --git a/test/async_/test_contrib/test_spatial_properties.py b/test/async_/test_contrib/test_spatial_properties.py
new file mode 100644
index 00000000..a6103639
--- /dev/null
+++ b/test/async_/test_contrib/test_spatial_properties.py
@@ -0,0 +1,291 @@
+"""
+Provides a test case for issue 374 - "Support for Point property type".
+
+For more information please see: https://github.com/neo4j-contrib/neomodel/issues/374
+"""
+
+import random
+from test._async_compat import mark_async_test
+
+import neo4j.spatial
+import pytest
+
+import neomodel
+import neomodel.contrib.spatial_properties
+
+from .test_spatial_datatypes import (
+ basic_type_assertions,
+ check_and_skip_neo4j_least_version,
+)
+
+
+def test_spatial_point_property():
+ """
+ Tests that specific modes of instantiation fail as expected.
+
+ :return:
+ """
+
+ # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test.
+ check_and_skip_neo4j_least_version(
+ 340, "This version does not support spatial data types."
+ )
+
+ with pytest.raises(ValueError, match=r"Invalid CRS\(None\)"):
+ a_point_property = neomodel.contrib.spatial_properties.PointProperty()
+
+ with pytest.raises(ValueError, match=r"Invalid CRS\(crs_isaak\)"):
+ a_point_property = neomodel.contrib.spatial_properties.PointProperty(
+ crs="crs_isaak"
+ )
+
+ with pytest.raises(TypeError, match="Invalid default value"):
+ a_point_property = neomodel.contrib.spatial_properties.PointProperty(
+ default=(0.0, 0.0), crs="cartesian"
+ )
+
+
+def test_inflate():
+ """
+ Tests that the marshalling from neo4j to neomodel data types works as expected.
+
+ :return:
+ """
+
+ # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test.
+ check_and_skip_neo4j_least_version(
+ 340, "This version does not support spatial data types."
+ )
+
+ # The test is repeatable enough to try and standardise it. The same test is repeated with the assertions in
+ # `basic_type_assertions` and different messages to be able to localise the exception.
+ #
+ # Array of points to inflate and messages when things go wrong
+ values_from_db = [
+ (
+ neo4j.spatial.CartesianPoint((0.0, 0.0)),
+ "Expected Neomodel 2d cartesian point when inflating 2d cartesian neo4j point",
+ ),
+ (
+ neo4j.spatial.CartesianPoint((0.0, 0.0, 0.0)),
+ "Expected Neomodel 3d cartesian point when inflating 3d cartesian neo4j point",
+ ),
+ (
+ neo4j.spatial.WGS84Point((0.0, 0.0)),
+ "Expected Neomodel 2d geographical point when inflating 2d geographical neo4j point",
+ ),
+ (
+ neo4j.spatial.WGS84Point((0.0, 0.0, 0.0)),
+ "Expected Neomodel 3d geographical point inflating 3d geographical neo4j point",
+ ),
+ ]
+
+ # Run the above tests
+ for a_value in values_from_db:
+ expected_point = neomodel.contrib.spatial_properties.NeomodelPoint(
+ tuple(a_value[0]),
+ crs=neomodel.contrib.spatial_properties.SRID_TO_CRS[a_value[0].srid],
+ )
+ inflated_point = neomodel.contrib.spatial_properties.PointProperty(
+ crs=neomodel.contrib.spatial_properties.SRID_TO_CRS[a_value[0].srid]
+ ).inflate(a_value[0])
+ basic_type_assertions(
+ expected_point,
+ inflated_point,
+ "{}, received {}".format(a_value[1], inflated_point),
+ )
+
+
+def test_deflate():
+ """
+ Tests that the marshalling from neomodel to neo4j data types works as expected
+ :return:
+ """
+ # Please see inline comments in `test_inflate`. This test function is 90% to that one with very minor differences.
+ #
+
+ # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test.
+ check_and_skip_neo4j_least_version(
+ 340, "This version does not support spatial data types."
+ )
+
+ CRS_TO_SRID = dict(
+ [
+ (value, key)
+ for key, value in neomodel.contrib.spatial_properties.SRID_TO_CRS.items()
+ ]
+ )
+ # Values to construct and expect during deflation
+ values_from_neomodel = [
+ (
+ neomodel.contrib.spatial_properties.NeomodelPoint(
+ (0.0, 0.0), crs="cartesian"
+ ),
+ "Expected Neo4J 2d cartesian point when deflating Neomodel 2d cartesian point",
+ ),
+ (
+ neomodel.contrib.spatial_properties.NeomodelPoint(
+ (0.0, 0.0, 0.0), crs="cartesian-3d"
+ ),
+ "Expected Neo4J 3d cartesian point when deflating Neomodel 3d cartesian point",
+ ),
+ (
+ neomodel.contrib.spatial_properties.NeomodelPoint((0.0, 0.0), crs="wgs-84"),
+ "Expected Neo4J 2d geographical point when deflating Neomodel 2d geographical point",
+ ),
+ (
+ neomodel.contrib.spatial_properties.NeomodelPoint(
+ (0.0, 0.0, 0.0), crs="wgs-84-3d"
+ ),
+ "Expected Neo4J 3d geographical point when deflating Neomodel 3d geographical point",
+ ),
+ ]
+
+ # Run the above tests.
+ for a_value in values_from_neomodel:
+ expected_point = neo4j.spatial.Point(tuple(a_value[0].coords[0]))
+ expected_point.srid = CRS_TO_SRID[a_value[0].crs]
+ deflated_point = neomodel.contrib.spatial_properties.PointProperty(
+ crs=a_value[0].crs
+ ).deflate(a_value[0])
+ basic_type_assertions(
+ expected_point,
+ deflated_point,
+ "{}, received {}".format(a_value[1], deflated_point),
+ check_neo4j_points=True,
+ )
+
+
+@mark_async_test
+async def test_default_value():
+ """
+ Tests that the default value passing mechanism works as expected with NeomodelPoint values.
+ :return:
+ """
+
+ def get_some_point():
+ return neomodel.contrib.spatial_properties.NeomodelPoint(
+ (random.random(), random.random())
+ )
+
+ class LocalisableEntity(neomodel.AsyncStructuredNode):
+ """
+ A very simple entity to try out the default value assignment.
+ """
+
+ identifier = neomodel.UniqueIdProperty()
+ location = neomodel.contrib.spatial_properties.PointProperty(
+ crs="cartesian", default=get_some_point
+ )
+
+ # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test.
+ check_and_skip_neo4j_least_version(
+ 340, "This version does not support spatial data types."
+ )
+
+ # Save an object
+ an_object = await LocalisableEntity().save()
+ coords = an_object.location.coords[0]
+ # Retrieve it
+ retrieved_object = await LocalisableEntity.nodes.get(
+ identifier=an_object.identifier
+ )
+ # Check against an independently created value
+ assert (
+ retrieved_object.location
+ == neomodel.contrib.spatial_properties.NeomodelPoint(coords)
+ ), ("Default value assignment failed.")
+
+
+@mark_async_test
+async def test_array_of_points():
+ """
+ Tests that Arrays of Points work as expected.
+
+ :return:
+ """
+
+ class AnotherLocalisableEntity(neomodel.AsyncStructuredNode):
+ """
+ A very simple entity with an array of locations
+ """
+
+ identifier = neomodel.UniqueIdProperty()
+ locations = neomodel.ArrayProperty(
+ neomodel.contrib.spatial_properties.PointProperty(crs="cartesian")
+ )
+
+ # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test.
+ check_and_skip_neo4j_least_version(
+ 340, "This version does not support spatial data types."
+ )
+
+ an_object = await AnotherLocalisableEntity(
+ locations=[
+ neomodel.contrib.spatial_properties.NeomodelPoint((0.0, 0.0)),
+ neomodel.contrib.spatial_properties.NeomodelPoint((1.0, 0.0)),
+ ]
+ ).save()
+
+ retrieved_object = await AnotherLocalisableEntity.nodes.get(
+ identifier=an_object.identifier
+ )
+
+ assert (
+ type(retrieved_object.locations) is list
+ ), "Array of Points definition failed."
+ assert retrieved_object.locations == [
+ neomodel.contrib.spatial_properties.NeomodelPoint((0.0, 0.0)),
+ neomodel.contrib.spatial_properties.NeomodelPoint((1.0, 0.0)),
+ ], "Array of Points incorrect values."
+
+
+@mark_async_test
+async def test_simple_storage_retrieval():
+ """
+ Performs a simple Create, Retrieve via .save(), .get() which, due to the way Q objects operate, tests the
+ __copy__, __deepcopy__ operations of NeomodelPoint.
+ :return:
+ """
+
+ class TestStorageRetrievalProperty(neomodel.AsyncStructuredNode):
+ uid = neomodel.UniqueIdProperty()
+ description = neomodel.StringProperty()
+ location = neomodel.contrib.spatial_properties.PointProperty(crs="cartesian")
+
+ # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test.
+ check_and_skip_neo4j_least_version(
+ 340, "This version does not support spatial data types."
+ )
+
+ a_restaurant = await TestStorageRetrievalProperty(
+ description="Milliways",
+ location=neomodel.contrib.spatial_properties.NeomodelPoint((0, 0)),
+ ).save()
+
+ a_property = await TestStorageRetrievalProperty.nodes.get(
+ location=neomodel.contrib.spatial_properties.NeomodelPoint((0, 0))
+ )
+
+ assert a_restaurant.description == a_property.description
+
+
+def test_equality_with_other_objects():
+ """
+    Performs equality tests and ensures that ``NeomodelPoint`` can be compared with ShapelyPoint and NeomodelPoint only.
+ """
+ try:
+ import shapely.geometry
+ from shapely import __version__
+ except ImportError:
+ pytest.skip("Shapely module not present")
+
+ if int("".join(__version__.split(".")[0:3])) < 200:
+ pytest.skip(f"Shapely 2.0 not present (Current version is {__version__}")
+
+ assert neomodel.contrib.spatial_properties.NeomodelPoint(
+ (0, 0)
+ ) == neomodel.contrib.spatial_properties.NeomodelPoint(x=0, y=0)
+ assert neomodel.contrib.spatial_properties.NeomodelPoint(
+ (0, 0)
+ ) == shapely.geometry.Point((0, 0))
diff --git a/test/async_/test_cypher.py b/test/async_/test_cypher.py
new file mode 100644
index 00000000..c078c8d5
--- /dev/null
+++ b/test/async_/test_cypher.py
@@ -0,0 +1,165 @@
+import builtins
+from test._async_compat import mark_async_test
+
+import pytest
+from neo4j.exceptions import ClientError as CypherError
+from numpy import ndarray
+from pandas import DataFrame, Series
+
+from neomodel import AsyncStructuredNode, StringProperty, adb
+from neomodel._async_compat.util import AsyncUtil
+
+
+class User2(AsyncStructuredNode):
+ name = StringProperty()
+ email = StringProperty()
+
+
+class UserPandas(AsyncStructuredNode):
+ name = StringProperty()
+ email = StringProperty()
+
+
+class UserNP(AsyncStructuredNode):
+ name = StringProperty()
+ email = StringProperty()
+
+
+@pytest.fixture
+def hide_available_pkg(monkeypatch, request):
+ import_orig = builtins.__import__
+
+ def mocked_import(name, *args, **kwargs):
+ if name == request.param:
+ raise ImportError()
+ return import_orig(name, *args, **kwargs)
+
+ monkeypatch.setattr(builtins, "__import__", mocked_import)
+
+
+@mark_async_test
+async def test_cypher():
+ """
+ test result format is backward compatible with earlier versions of neomodel
+ """
+
+ jim = await User2(email="jim1@test.com").save()
+ data, meta = await jim.cypher(
+ f"MATCH (a) WHERE {await adb.get_id_method()}(a)=$self RETURN a.email"
+ )
+ assert data[0][0] == "jim1@test.com"
+ assert "a.email" in meta
+
+ data, meta = await jim.cypher(
+ f"""
+ MATCH (a) WHERE {await adb.get_id_method()}(a)=$self
+ MATCH (a)<-[:USER2]-(b)
+ RETURN a, b, 3
+ """
+ )
+ assert "a" in meta and "b" in meta
+
+
+@mark_async_test
+async def test_cypher_syntax_error():
+ jim = await User2(email="jim1@test.com").save()
+ try:
+ await jim.cypher(
+ f"MATCH a WHERE {await adb.get_id_method()}(a)={{self}} RETURN xx"
+ )
+ except CypherError as e:
+ assert hasattr(e, "message")
+ assert hasattr(e, "code")
+ else:
+ assert False, "CypherError not raised."
+
+
+@mark_async_test
+@pytest.mark.parametrize("hide_available_pkg", ["pandas"], indirect=True)
+async def test_pandas_not_installed(hide_available_pkg):
+ # We run only the async version, because this fails on second run
+ # because import error is thrown only when pandas.py is imported
+ if not AsyncUtil.is_async_code:
+ pytest.skip("This test is async only")
+ with pytest.raises(ImportError):
+ with pytest.warns(
+ UserWarning,
+ match="The neomodel.integration.pandas module expects pandas to be installed",
+ ):
+ from neomodel.integration.pandas import to_dataframe
+
+ _ = to_dataframe(await adb.cypher_query("MATCH (a) RETURN a.name AS name"))
+
+
+@mark_async_test
+async def test_pandas_integration():
+ from neomodel.integration.pandas import to_dataframe, to_series
+
+ jimla = await UserPandas(email="jimla@test.com", name="jimla").save()
+ jimlo = await UserPandas(email="jimlo@test.com", name="jimlo").save()
+
+ # Test to_dataframe
+ df = to_dataframe(
+ await adb.cypher_query(
+ "MATCH (a:UserPandas) RETURN a.name AS name, a.email AS email"
+ )
+ )
+
+ assert isinstance(df, DataFrame)
+ assert df.shape == (2, 2)
+ assert df["name"].tolist() == ["jimla", "jimlo"]
+
+ # Also test passing an index and dtype to to_dataframe
+ df = to_dataframe(
+ await adb.cypher_query(
+ "MATCH (a:UserPandas) RETURN a.name AS name, a.email AS email"
+ ),
+ index=df["email"],
+ dtype=str,
+ )
+
+ assert df.index.inferred_type == "string"
+
+ # Next test to_series
+ series = to_series(
+ await adb.cypher_query("MATCH (a:UserPandas) RETURN a.name AS name")
+ )
+
+ assert isinstance(series, Series)
+ assert series.shape == (2,)
+ assert df["name"].tolist() == ["jimla", "jimlo"]
+
+
+@mark_async_test
+@pytest.mark.parametrize("hide_available_pkg", ["numpy"], indirect=True)
+async def test_numpy_not_installed(hide_available_pkg):
+ # We run only the async version, because this fails on second run
+ # because import error is thrown only when numpy.py is imported
+ if not AsyncUtil.is_async_code:
+ pytest.skip("This test is async only")
+ with pytest.raises(ImportError):
+ with pytest.warns(
+ UserWarning,
+ match="The neomodel.integration.numpy module expects numpy to be installed",
+ ):
+ from neomodel.integration.numpy import to_ndarray
+
+ _ = to_ndarray(await adb.cypher_query("MATCH (a) RETURN a.name AS name"))
+
+
+@mark_async_test
+async def test_numpy_integration():
+ from neomodel.integration.numpy import to_ndarray
+
+ jimly = await UserNP(email="jimly@test.com", name="jimly").save()
+ jimlu = await UserNP(email="jimlu@test.com", name="jimlu").save()
+
+ array = to_ndarray(
+ await adb.cypher_query(
+ "MATCH (a:UserNP) RETURN a.name AS name, a.email AS email ORDER BY name"
+ )
+ )
+
+ assert isinstance(array, ndarray)
+ assert array.shape == (2, 2)
+ assert array[0][0] == "jimlu"
diff --git a/test/async_/test_database_management.py b/test/async_/test_database_management.py
new file mode 100644
index 00000000..5159642a
--- /dev/null
+++ b/test/async_/test_database_management.py
@@ -0,0 +1,81 @@
+from test._async_compat import mark_async_test
+
+import pytest
+from neo4j.exceptions import AuthError
+
+from neomodel import (
+ AsyncRelationshipTo,
+ AsyncStructuredNode,
+ AsyncStructuredRel,
+ IntegerProperty,
+ StringProperty,
+ adb,
+)
+
+
+class City(AsyncStructuredNode):
+ name = StringProperty()
+
+
+class InCity(AsyncStructuredRel):
+ creation_year = IntegerProperty(index=True)
+
+
+class Venue(AsyncStructuredNode):
+ name = StringProperty(unique_index=True)
+ creator = StringProperty(index=True)
+ in_city = AsyncRelationshipTo(City, relation_type="IN", model=InCity)
+
+
+@mark_async_test
+async def test_clear_database():
+ venue = await Venue(name="Royal Albert Hall", creator="Queen Victoria").save()
+ city = await City(name="London").save()
+ await venue.in_city.connect(city)
+
+ # Clear only the data
+ await adb.clear_neo4j_database()
+ database_is_populated, _ = await adb.cypher_query(
+ "MATCH (a) return count(a)>0 as database_is_populated"
+ )
+
+ assert database_is_populated[0][0] is False
+
+ await adb.install_all_labels()
+ indexes = await adb.list_indexes(exclude_token_lookup=True)
+ constraints = await adb.list_constraints()
+ assert len(indexes) > 0
+ assert len(constraints) > 0
+
+ # Clear constraints and indexes too
+ await adb.clear_neo4j_database(clear_constraints=True, clear_indexes=True)
+
+ indexes = await adb.list_indexes(exclude_token_lookup=True)
+ constraints = await adb.list_constraints()
+ assert len(indexes) == 0
+ assert len(constraints) == 0
+
+
+@mark_async_test
+async def test_change_password():
+ prev_password = "foobarbaz"
+ new_password = "newpassword"
+ prev_url = f"bolt://neo4j:{prev_password}@localhost:7687"
+ new_url = f"bolt://neo4j:{new_password}@localhost:7687"
+
+ await adb.change_neo4j_password("neo4j", new_password)
+ await adb.close_connection()
+
+ await adb.set_connection(url=new_url)
+ await adb.close_connection()
+
+ with pytest.raises(AuthError):
+ await adb.set_connection(url=prev_url)
+
+ await adb.close_connection()
+
+ await adb.set_connection(url=new_url)
+ await adb.change_neo4j_password("neo4j", prev_password)
+ await adb.close_connection()
+
+ await adb.set_connection(url=prev_url)
diff --git a/test/async_/test_dbms_awareness.py b/test/async_/test_dbms_awareness.py
new file mode 100644
index 00000000..66a8fc5f
--- /dev/null
+++ b/test/async_/test_dbms_awareness.py
@@ -0,0 +1,37 @@
+from test._async_compat import mark_async_test
+
+import pytest
+
+from neomodel import adb
+from neomodel.util import version_tag_to_integer
+
+
+@mark_async_test
+async def test_version_awareness():
+ db_version = await adb.database_version
+ if db_version != "5.7.0":
+ pytest.skip("Testing a specific database version")
+ assert db_version == "5.7.0"
+ assert await adb.version_is_higher_than("5.7")
+ assert await adb.version_is_higher_than("5.6.0")
+ assert await adb.version_is_higher_than("5")
+ assert await adb.version_is_higher_than("4")
+
+ assert not await adb.version_is_higher_than("5.8")
+
+
+@mark_async_test
+async def test_edition_awareness():
+ db_edition = await adb.database_edition
+ if db_edition == "enterprise":
+ assert await adb.edition_is_enterprise()
+ else:
+ assert not await adb.edition_is_enterprise()
+
+
+def test_version_tag_to_integer():
+ assert version_tag_to_integer("5.7.1") == 50701
+ assert version_tag_to_integer("5.1") == 50100
+ assert version_tag_to_integer("5") == 50000
+ assert version_tag_to_integer("5.14.1") == 51401
+ assert version_tag_to_integer("5.14-aura") == 51400
diff --git a/test/async_/test_driver_options.py b/test/async_/test_driver_options.py
new file mode 100644
index 00000000..df378f98
--- /dev/null
+++ b/test/async_/test_driver_options.py
@@ -0,0 +1,52 @@
+from test._async_compat import mark_async_test
+
+import pytest
+from neo4j.exceptions import ClientError
+from pytest import raises
+
+from neomodel import adb
+from neomodel.exceptions import FeatureNotSupported
+
+
+@mark_async_test
+async def test_impersonate():
+ if not await adb.edition_is_enterprise():
+ pytest.skip("Skipping test for community edition")
+ with await adb.impersonate(user="troygreene"):
+ results, _ = await adb.cypher_query("RETURN 'Doo Wacko !'")
+ assert results[0][0] == "Doo Wacko !"
+
+
+@mark_async_test
+async def test_impersonate_unauthorized():
+ if not await adb.edition_is_enterprise():
+ pytest.skip("Skipping test for community edition")
+ with await adb.impersonate(user="unknownuser"):
+ with raises(ClientError):
+ _ = await adb.cypher_query("RETURN 'Gabagool'")
+
+
+@mark_async_test
+async def test_impersonate_multiple_transactions():
+ if not await adb.edition_is_enterprise():
+ pytest.skip("Skipping test for community edition")
+ with await adb.impersonate(user="troygreene"):
+ async with adb.transaction:
+ results, _ = await adb.cypher_query("RETURN 'Doo Wacko !'")
+ assert results[0][0] == "Doo Wacko !"
+
+ async with adb.transaction:
+ results, _ = await adb.cypher_query("SHOW CURRENT USER")
+ assert results[0][0] == "troygreene"
+
+ results, _ = await adb.cypher_query("SHOW CURRENT USER")
+ assert results[0][0] == "neo4j"
+
+
+@mark_async_test
+async def test_impersonate_community():
+ if await adb.edition_is_enterprise():
+ pytest.skip("Skipping test for enterprise edition")
+ with raises(FeatureNotSupported):
+ with await adb.impersonate(user="troygreene"):
+ _ = await adb.cypher_query("RETURN 'Gabagoogoo'")
diff --git a/test/async_/test_exceptions.py b/test/async_/test_exceptions.py
new file mode 100644
index 00000000..e948f76c
--- /dev/null
+++ b/test/async_/test_exceptions.py
@@ -0,0 +1,31 @@
+import pickle
+from test._async_compat import mark_async_test
+
+from neomodel import AsyncStructuredNode, DoesNotExist, StringProperty
+
+
+class EPerson(AsyncStructuredNode):
+ name = StringProperty(unique_index=True)
+
+
+@mark_async_test
+async def test_object_does_not_exist():
+ try:
+ await EPerson.nodes.get(name="johnny")
+ except EPerson.DoesNotExist as e:
+ pickle_instance = pickle.dumps(e)
+ assert pickle_instance
+ assert pickle.loads(pickle_instance)
+ assert isinstance(pickle.loads(pickle_instance), DoesNotExist)
+ else:
+ assert False, "Person.DoesNotExist not raised."
+
+
+def test_pickle_does_not_exist():
+ try:
+ raise EPerson.DoesNotExist("My Test Message")
+ except EPerson.DoesNotExist as e:
+ pickle_instance = pickle.dumps(e)
+ assert pickle_instance
+ assert pickle.loads(pickle_instance)
+ assert isinstance(pickle.loads(pickle_instance), DoesNotExist)
diff --git a/test/async_/test_hooks.py b/test/async_/test_hooks.py
new file mode 100644
index 00000000..09a87403
--- /dev/null
+++ b/test/async_/test_hooks.py
@@ -0,0 +1,35 @@
+from test._async_compat import mark_async_test
+
+from neomodel import AsyncStructuredNode, StringProperty
+
+HOOKS_CALLED = {}
+
+
+class HookTest(AsyncStructuredNode):
+ name = StringProperty()
+
+ def post_create(self):
+ HOOKS_CALLED["post_create"] = 1
+
+ def pre_save(self):
+ HOOKS_CALLED["pre_save"] = 1
+
+ def post_save(self):
+ HOOKS_CALLED["post_save"] = 1
+
+ def pre_delete(self):
+ HOOKS_CALLED["pre_delete"] = 1
+
+ def post_delete(self):
+ HOOKS_CALLED["post_delete"] = 1
+
+
+@mark_async_test
+async def test_hooks():
+ ht = await HookTest(name="k").save()
+ await ht.delete()
+ assert "pre_save" in HOOKS_CALLED
+ assert "post_save" in HOOKS_CALLED
+ assert "post_create" in HOOKS_CALLED
+ assert "pre_delete" in HOOKS_CALLED
+ assert "post_delete" in HOOKS_CALLED
diff --git a/test/async_/test_indexing.py b/test/async_/test_indexing.py
new file mode 100644
index 00000000..9e0d8f37
--- /dev/null
+++ b/test/async_/test_indexing.py
@@ -0,0 +1,88 @@
+from test._async_compat import mark_async_test
+
+import pytest
+from pytest import raises
+
+from neomodel import (
+ AsyncStructuredNode,
+ IntegerProperty,
+ StringProperty,
+ UniqueProperty,
+ adb,
+)
+from neomodel.exceptions import ConstraintValidationFailed
+
+
+class Human(AsyncStructuredNode):
+ name = StringProperty(unique_index=True)
+ age = IntegerProperty(index=True)
+
+
+@mark_async_test
+async def test_unique_error():
+ await adb.install_labels(Human)
+ await Human(name="j1m", age=13).save()
+ try:
+ await Human(name="j1m", age=14).save()
+ except UniqueProperty as e:
+ assert str(e).find("j1m")
+ assert str(e).find("name")
+ else:
+ assert False, "UniqueProperty not raised."
+
+
+@mark_async_test
+async def test_existence_constraint_error():
+ if not await adb.edition_is_enterprise():
+ pytest.skip("Skipping test for community edition")
+ await adb.cypher_query(
+ "CREATE CONSTRAINT test_existence_constraint FOR (n:Human) REQUIRE n.age IS NOT NULL"
+ )
+ with raises(ConstraintValidationFailed, match=r"must have the property"):
+ await Human(name="Scarlett").save()
+
+ await adb.cypher_query("DROP CONSTRAINT test_existence_constraint")
+
+
+@mark_async_test
+async def test_optional_properties_dont_get_indexed():
+ await Human(name="99", age=99).save()
+ h = await Human.nodes.get(age=99)
+ assert h
+ assert h.name == "99"
+
+ await Human(name="98", age=98).save()
+ h = await Human.nodes.get(age=98)
+ assert h
+ assert h.name == "98"
+
+
+@mark_async_test
+async def test_escaped_chars():
+ _name = "sarah:test"
+ await Human(name=_name, age=3).save()
+ r = await Human.nodes.filter(name=_name)
+ assert r[0].name == _name
+
+
+@mark_async_test
+async def test_does_not_exist():
+ with raises(Human.DoesNotExist):
+ await Human.nodes.get(name="XXXX")
+
+
+@mark_async_test
+async def test_custom_label_name():
+ class Giraffe(AsyncStructuredNode):
+ __label__ = "Giraffes"
+ name = StringProperty(unique_index=True)
+
+ jim = await Giraffe(name="timothy").save()
+ node = await Giraffe.nodes.get(name="timothy")
+ assert node.name == jim.name
+
+ class SpecialGiraffe(Giraffe):
+ power = StringProperty()
+
+ # custom labels aren't inherited
+ assert SpecialGiraffe.__label__ == "SpecialGiraffe"
diff --git a/test/async_/test_issue112.py b/test/async_/test_issue112.py
new file mode 100644
index 00000000..8ba1ce03
--- /dev/null
+++ b/test/async_/test_issue112.py
@@ -0,0 +1,19 @@
+from test._async_compat import mark_async_test
+
+from neomodel import AsyncRelationshipTo, AsyncStructuredNode
+
+
+class SomeModel(AsyncStructuredNode):
+ test = AsyncRelationshipTo("SomeModel", "SELF")
+
+
+@mark_async_test
+async def test_len_relationship():
+ t1 = await SomeModel().save()
+ t2 = await SomeModel().save()
+
+ await t1.test.connect(t2)
+ l = len(await t1.test.all())
+
+ assert l
+ assert l == 1
diff --git a/test/async_/test_issue283.py b/test/async_/test_issue283.py
new file mode 100644
index 00000000..8106a796
--- /dev/null
+++ b/test/async_/test_issue283.py
@@ -0,0 +1,524 @@
+"""
+Provides a test case for issue 283 - "Inheritance breaks".
+
+The issue is outlined here: https://github.com/neo4j-contrib/neomodel/issues/283
+More information about the same issue at:
+https://github.com/aanastasiou/neomodelInheritanceTest
+
+The following example uses a recursive relationship for economy, but the
+idea remains the same: "Instantiate the correct type of node at the end of
+a relationship as specified by the model"
+"""
+import random
+from test._async_compat import mark_async_test
+
+import pytest
+
+from neomodel import (
+ AsyncRelationshipTo,
+ AsyncStructuredNode,
+ AsyncStructuredRel,
+ DateTimeProperty,
+ FloatProperty,
+ RelationshipClassNotDefined,
+ RelationshipClassRedefined,
+ StringProperty,
+ UniqueIdProperty,
+ adb,
+)
+from neomodel.exceptions import NodeClassAlreadyDefined, NodeClassNotDefined
+
+try:
+ basestring
+except NameError:
+ basestring = str
+
+
+# Set up a very simple model for the tests
+class PersonalRelationship(AsyncStructuredRel):
+ """
+ A very simple relationship between two basePersons that simply records
+ the date at which an acquaintance was established.
+ This relationship should be carried over to anything that inherits from
+ basePerson without any further effort.
+ """
+
+ on_date = DateTimeProperty(default_now=True)
+
+
+class BasePerson(AsyncStructuredNode):
+ """
+ Base class for defining some basic sort of an actor.
+ """
+
+ name = StringProperty(required=True, unique_index=True)
+ friends_with = AsyncRelationshipTo(
+ "BasePerson", "FRIENDS_WITH", model=PersonalRelationship
+ )
+
+
+class TechnicalPerson(BasePerson):
+ """
+ A Technical person specialises BasePerson by adding their expertise.
+ """
+
+ expertise = StringProperty(required=True)
+
+
+class PilotPerson(BasePerson):
+ """
+ A pilot person specialises BasePerson by adding the type of airplane they
+ can operate.
+ """
+
+ airplane = StringProperty(required=True)
+
+
+class BaseOtherPerson(AsyncStructuredNode):
+ """
+ An obviously "wrong" class of actor to befriend BasePersons with.
+ """
+
+ car_color = StringProperty(required=True)
+
+
+class SomePerson(BaseOtherPerson):
+ """
+ Concrete class that simply derives from BaseOtherPerson.
+ """
+
+ pass
+
+
+# Test cases
+@mark_async_test
+async def test_automatic_result_resolution():
+ """
+ Node objects at the end of relationships are instantiated to their
+ corresponding Python object.
+ """
+
+ # Create a few entities
+ A = (
+ await TechnicalPerson.get_or_create(
+ {"name": "Grumpy", "expertise": "Grumpiness"}
+ )
+ )[0]
+ B = (
+ await TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"})
+ )[0]
+ C = (
+ await TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"})
+ )[0]
+
+ # Add connections
+ await A.friends_with.connect(B)
+ await B.friends_with.connect(C)
+ await C.friends_with.connect(A)
+
+ test = await A.friends_with
+
+ # If A is friends with B, then A's friends_with objects should be
+ # TechnicalPerson (!NOT basePerson!)
+ assert type((await A.friends_with)[0]) is TechnicalPerson
+
+ await A.delete()
+ await B.delete()
+ await C.delete()
+
+
+@mark_async_test
+async def test_recursive_automatic_result_resolution():
+ """
+ Node objects are instantiated to native Python objects, both at the top
+ level of returned results and in the case where they are returned within
+ lists.
+ """
+
+ # Create a few entities
+ A = (
+ await TechnicalPerson.get_or_create(
+ {"name": "Grumpier", "expertise": "Grumpiness"}
+ )
+ )[0]
+ B = (
+ await TechnicalPerson.get_or_create(
+ {"name": "Happier", "expertise": "Grumpiness"}
+ )
+ )[0]
+ C = (
+ await TechnicalPerson.get_or_create(
+ {"name": "Sleepier", "expertise": "Pillows"}
+ )
+ )[0]
+ D = (
+ await TechnicalPerson.get_or_create(
+ {"name": "Sneezier", "expertise": "Pillows"}
+ )
+ )[0]
+
+ # Retrieve mixed results, both at the top level and nested
+ L, _ = await adb.cypher_query(
+ "MATCH (a:TechnicalPerson) "
+ "WHERE a.expertise='Grumpiness' "
+ "WITH collect(a) as Alpha "
+ "MATCH (b:TechnicalPerson) "
+ "WHERE b.expertise='Pillows' "
+ "WITH Alpha, collect(b) as Beta "
+ "RETURN [Alpha, [Beta, [Beta, ['Banana', "
+ "Alpha]]]]",
+ resolve_objects=True,
+ )
+
+ # Assert that a Node returned deep in a nested list structure is of the
+ # correct type
+ assert type(L[0][0][0][1][0][0][0][0]) is TechnicalPerson
+ # Assert that primitive data types remain primitive data types
+ assert issubclass(type(L[0][0][0][1][0][1][0][1][0][0]), basestring)
+
+ await A.delete()
+ await B.delete()
+ await C.delete()
+ await D.delete()
+
+
+@mark_async_test
+async def test_validation_with_inheritance_from_db():
+ """
+ Objects descending from the specified class of a relationship's end-node are
+ also perfectly valid to appear as end-node values too
+ """
+
+ # Create a few entities
+ # Technical Persons
+ A = (
+ await TechnicalPerson.get_or_create(
+ {"name": "Grumpy", "expertise": "Grumpiness"}
+ )
+ )[0]
+ B = (
+ await TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"})
+ )[0]
+ C = (
+ await TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"})
+ )[0]
+
+ # Pilot Persons
+ D = (
+ await PilotPerson.get_or_create(
+ {"name": "Porco Rosso", "airplane": "Savoia-Marchetti"}
+ )
+ )[0]
+ E = (
+ await PilotPerson.get_or_create(
+ {"name": "Jack Dalton", "airplane": "Beechcraft Model 18"}
+ )
+ )[0]
+
+ # TechnicalPersons can befriend PilotPersons and vice-versa and that's fine
+
+ # TechnicalPersons befriend Technical Persons
+ await A.friends_with.connect(B)
+ await B.friends_with.connect(C)
+ await C.friends_with.connect(A)
+
+ # Pilot Persons befriend Pilot Persons
+ await D.friends_with.connect(E)
+
+ # Technical Persons befriend Pilot Persons
+ await A.friends_with.connect(D)
+ await E.friends_with.connect(C)
+
+ # This now means that friends_with of a TechnicalPerson can
+ # either be TechnicalPerson or Pilot Person (!NOT basePerson!)
+
+ assert (type((await A.friends_with)[0]) is TechnicalPerson) or (
+ type((await A.friends_with)[0]) is PilotPerson
+ )
+ assert (type((await A.friends_with)[1]) is TechnicalPerson) or (
+ type((await A.friends_with)[1]) is PilotPerson
+ )
+ assert type((await D.friends_with)[0]) is PilotPerson
+
+ await A.delete()
+ await B.delete()
+ await C.delete()
+ await D.delete()
+ await E.delete()
+
+
+@mark_async_test
+async def test_validation_enforcement_to_db():
+ """
+ If a connection between wrong types is attempted, raise an exception
+ """
+
+ # Create a few entities
+ # Technical Persons
+ A = (
+ await TechnicalPerson.get_or_create(
+ {"name": "Grumpy", "expertise": "Grumpiness"}
+ )
+ )[0]
+ B = (
+ await TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"})
+ )[0]
+ C = (
+ await TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"})
+ )[0]
+
+ # Pilot Persons
+ D = (
+ await PilotPerson.get_or_create(
+ {"name": "Porco Rosso", "airplane": "Savoia-Marchetti"}
+ )
+ )[0]
+ E = (
+ await PilotPerson.get_or_create(
+ {"name": "Jack Dalton", "airplane": "Beechcraft Model 18"}
+ )
+ )[0]
+
+ # Some Person
+ F = await SomePerson(car_color="Blue").save()
+
+ # TechnicalPersons can befriend PilotPersons and vice-versa and that's fine
+ await A.friends_with.connect(B)
+ await B.friends_with.connect(C)
+ await C.friends_with.connect(A)
+ await D.friends_with.connect(E)
+ await A.friends_with.connect(D)
+ await E.friends_with.connect(C)
+
+ # Trying to befriend a Technical Person with Some Person should raise an
+ # exception
+ with pytest.raises(ValueError):
+ await A.friends_with.connect(F)
+
+ await A.delete()
+ await B.delete()
+ await C.delete()
+ await D.delete()
+ await E.delete()
+ await F.delete()
+
+
+@mark_async_test
+async def test_failed_result_resolution():
+ """
+ A Neo4j driver node FROM the database contains labels that are unknown to
+ neomodel's Database class. This condition raises NodeClassNotDefined
+ exception.
+ """
+
+ class RandomPerson(BasePerson):
+ randomness = FloatProperty(default=random.random)
+
+ # A Technical Person...
+ A = (
+ await TechnicalPerson.get_or_create(
+ {"name": "Grumpy", "expertise": "Grumpiness"}
+ )
+ )[0]
+
+ # A Random Person...
+ B = (await RandomPerson.get_or_create({"name": "Mad Hatter"}))[0]
+
+ await A.friends_with.connect(B)
+
+ # Simulate the condition where the definition of class RandomPerson is not
+ # known yet.
+ del adb._NODE_CLASS_REGISTRY[frozenset(["RandomPerson", "BasePerson"])]
+
+ # Now try to instantiate a RandomPerson
+ A = (
+ await TechnicalPerson.get_or_create(
+ {"name": "Grumpy", "expertise": "Grumpiness"}
+ )
+ )[0]
+ with pytest.raises(
+ NodeClassNotDefined,
+ match=r"Node with labels .* does not resolve to any of the known objects.*",
+ ):
+ friends = await A.friends_with.all()
+ for some_friend in friends:
+ print(some_friend.name)
+
+ await A.delete()
+ await B.delete()
+
+
+@mark_async_test
+async def test_node_label_mismatch():
+ """
+ A Neo4j driver node FROM the database contains a superset of the known
+ labels.
+ """
+
+ class SuperTechnicalPerson(TechnicalPerson):
+ superness = FloatProperty(default=1.0)
+
+ class UltraTechnicalPerson(SuperTechnicalPerson):
+ ultraness = FloatProperty(default=3.1415928)
+
+ # Create a TechnicalPerson...
+ A = (
+ await TechnicalPerson.get_or_create(
+ {"name": "Grumpy", "expertise": "Grumpiness"}
+ )
+ )[0]
+ # ...that is connected to an UltraTechnicalPerson
+ F = await UltraTechnicalPerson(
+ name="Chewbaka", expertise="Aarrr wgh ggwaaah"
+ ).save()
+ await A.friends_with.connect(F)
+
+ # Forget about the UltraTechnicalPerson
+ del adb._NODE_CLASS_REGISTRY[
+ frozenset(
+ [
+ "UltraTechnicalPerson",
+ "SuperTechnicalPerson",
+ "TechnicalPerson",
+ "BasePerson",
+ ]
+ )
+ ]
+
+ # Recall a TechnicalPerson and enumerate its friends.
+ # One of them is UltraTechnicalPerson which would be returned as a valid
+ # node to a friends_with query but is currently unknown to the node class registry.
+ A = (
+ await TechnicalPerson.get_or_create(
+ {"name": "Grumpy", "expertise": "Grumpiness"}
+ )
+ )[0]
+ with pytest.raises(NodeClassNotDefined):
+ friends = await A.friends_with.all()
+ for some_friend in friends:
+ print(some_friend.name)
+
+
+def test_attempted_class_redefinition():
+ """
+ A StructuredNode class is attempted to be redefined.
+ """
+
+ def redefine_class_locally():
+ # Since this test has already set up a class hierarchy in its global scope, we will try to redefine
+ # SomePerson here.
+ # The internal structure of the SomePerson entity does not matter at all here.
+ class SomePerson(BaseOtherPerson):
+ uid = UniqueIdProperty()
+
+ with pytest.raises(
+ NodeClassAlreadyDefined,
+ match=r"Class .* with labels .* already defined:.*",
+ ):
+ redefine_class_locally()
+
+
+@mark_async_test
+async def test_relationship_result_resolution():
+ """
+ A query returning a "Relationship" object can now instantiate it to a data model class
+ """
+ # Test specific data
+ A = await PilotPerson(name="Zantford Granville", airplane="Gee Bee Model R").save()
+ B = await PilotPerson(name="Thomas Granville", airplane="Gee Bee Model R").save()
+ C = await PilotPerson(name="Robert Granville", airplane="Gee Bee Model R").save()
+ D = await PilotPerson(name="Mark Granville", airplane="Gee Bee Model R").save()
+ E = await PilotPerson(name="Edward Granville", airplane="Gee Bee Model R").save()
+
+ await A.friends_with.connect(B)
+ await B.friends_with.connect(C)
+ await C.friends_with.connect(D)
+ await D.friends_with.connect(E)
+
+ query_data = await adb.cypher_query(
+ "MATCH (a:PilotPerson)-[r:FRIENDS_WITH]->(b:PilotPerson) "
+ "WHERE a.airplane='Gee Bee Model R' and b.airplane='Gee Bee Model R' "
+ "RETURN DISTINCT r",
+ resolve_objects=True,
+ )
+
+ # The relationship here should be properly instantiated to a `PersonalRelationship` object.
+ assert isinstance(query_data[0][0][0], PersonalRelationship)
+
+
+@mark_async_test
+async def test_properly_inherited_relationship():
+ """
+ A relationship class extends an existing relationship model that must extend the same previously associated
+ relationship label.
+ """
+
+ # Extends an existing relationship by adding the "relationship_strength" attribute.
+ # `ExtendedPersonalRelationship` will now substitute `PersonalRelationship` EVERYWHERE in the system.
+ class ExtendedPersonalRelationship(PersonalRelationship):
+ relationship_strength = FloatProperty(default=random.random)
+
+ # Extends SomePerson, establishes "enriched" relationships with any BaseOtherPerson
+ class ExtendedSomePerson(SomePerson):
+ friends_with = AsyncRelationshipTo(
+ "BaseOtherPerson",
+ "FRIENDS_WITH",
+ model=ExtendedPersonalRelationship,
+ )
+
+ # Test specific data
+ A = await ExtendedSomePerson(name="Michael Knight", car_color="Black").save()
+ B = await ExtendedSomePerson(name="Luke Duke", car_color="Orange").save()
+ C = await ExtendedSomePerson(name="Michael Schumacher", car_color="Red").save()
+
+ await A.friends_with.connect(B)
+ await A.friends_with.connect(C)
+
+ query_data = await adb.cypher_query(
+ "MATCH (:ExtendedSomePerson)-[r:FRIENDS_WITH]->(:ExtendedSomePerson) "
+ "RETURN DISTINCT r",
+ resolve_objects=True,
+ )
+
+ assert isinstance(query_data[0][0][0], ExtendedPersonalRelationship)
+
+
+def test_improperly_inherited_relationship():
+ """
+ Attempting to re-define an existing relationship with a completely unrelated class.
+ :return:
+ """
+
+ class NewRelationship(AsyncStructuredRel):
+ profile_match_factor = FloatProperty()
+
+ with pytest.raises(
+ RelationshipClassRedefined,
+ match=r"Relationship of type .* redefined as .*",
+ ):
+
+ class NewSomePerson(SomePerson):
+ friends_with = AsyncRelationshipTo(
+ "BaseOtherPerson", "FRIENDS_WITH", model=NewRelationship
+ )
+
+
+@mark_async_test
+async def test_resolve_inexistent_relationship():
+ """
+ Attempting to resolve an inexistent relationship should raise an exception
+ :return:
+ """
+
+ # Forget about the FRIENDS_WITH Relationship.
+ del adb._NODE_CLASS_REGISTRY[frozenset(["FRIENDS_WITH"])]
+
+ with pytest.raises(
+ RelationshipClassNotDefined,
+ match=r"Relationship of type .* does not resolve to any of the known objects.*",
+ ):
+ query_data = await adb.cypher_query(
+ "MATCH (:ExtendedSomePerson)-[r:FRIENDS_WITH]->(:ExtendedSomePerson) "
+ "RETURN DISTINCT r",
+ resolve_objects=True,
+ )
diff --git a/test/async_/test_issue600.py b/test/async_/test_issue600.py
new file mode 100644
index 00000000..5f66f39e
--- /dev/null
+++ b/test/async_/test_issue600.py
@@ -0,0 +1,87 @@
+"""
+Provides a test case for issue 600 - "Pull request #592 cause an error in case of relationship inharitance".
+
+The issue is outlined here: https://github.com/neo4j-contrib/neomodel/issues/600
+"""
+
+from test._async_compat import mark_async_test
+
+from neomodel import AsyncRelationship, AsyncStructuredNode, AsyncStructuredRel
+
+try:
+ basestring
+except NameError:
+ basestring = str
+
+
+class Class1(AsyncStructuredRel):
+ pass
+
+
+class SubClass1(Class1):
+ pass
+
+
+class SubClass2(Class1):
+ pass
+
+
+class RelationshipDefinerSecondSibling(AsyncStructuredNode):
+ rel_1 = AsyncRelationship(
+ "RelationshipDefinerSecondSibling", "SOME_REL_LABEL", model=Class1
+ )
+ rel_2 = AsyncRelationship(
+ "RelationshipDefinerSecondSibling", "SOME_REL_LABEL", model=SubClass1
+ )
+ rel_3 = AsyncRelationship(
+ "RelationshipDefinerSecondSibling", "SOME_REL_LABEL", model=SubClass2
+ )
+
+
+class RelationshipDefinerParentLast(AsyncStructuredNode):
+ rel_2 = AsyncRelationship(
+ "RelationshipDefinerParentLast", "SOME_REL_LABEL", model=SubClass1
+ )
+ rel_3 = AsyncRelationship(
+ "RelationshipDefinerParentLast", "SOME_REL_LABEL", model=SubClass2
+ )
+ rel_1 = AsyncRelationship(
+ "RelationshipDefinerParentLast", "SOME_REL_LABEL", model=Class1
+ )
+
+
+# Test cases
+@mark_async_test
+async def test_relationship_definer_second_sibling():
+ # Create a few entities
+ A = (await RelationshipDefinerSecondSibling.get_or_create({}))[0]
+ B = (await RelationshipDefinerSecondSibling.get_or_create({}))[0]
+ C = (await RelationshipDefinerSecondSibling.get_or_create({}))[0]
+
+ # Add connections
+ await A.rel_1.connect(B)
+ await B.rel_2.connect(C)
+ await C.rel_3.connect(A)
+
+ # Clean up
+ await A.delete()
+ await B.delete()
+ await C.delete()
+
+
+@mark_async_test
+async def test_relationship_definer_parent_last():
+ # Create a few entities
+ A = (await RelationshipDefinerParentLast.get_or_create({}))[0]
+ B = (await RelationshipDefinerParentLast.get_or_create({}))[0]
+ C = (await RelationshipDefinerParentLast.get_or_create({}))[0]
+
+ # Add connections
+ await A.rel_1.connect(B)
+ await B.rel_2.connect(C)
+ await C.rel_3.connect(A)
+
+ # Clean up
+ await A.delete()
+ await B.delete()
+ await C.delete()
diff --git a/test/async_/test_label_drop.py b/test/async_/test_label_drop.py
new file mode 100644
index 00000000..3d64050b
--- /dev/null
+++ b/test/async_/test_label_drop.py
@@ -0,0 +1,47 @@
+from test._async_compat import mark_async_test
+
+from neo4j.exceptions import ClientError
+
+from neomodel import AsyncStructuredNode, StringProperty, adb
+
+
+class ConstraintAndIndex(AsyncStructuredNode):
+ name = StringProperty(unique_index=True)
+ last_name = StringProperty(index=True)
+
+
+@mark_async_test
+async def test_drop_labels():
+ await adb.install_labels(ConstraintAndIndex)
+ constraints_before = await adb.list_constraints()
+ indexes_before = await adb.list_indexes(exclude_token_lookup=True)
+
+ assert len(constraints_before) > 0
+ assert len(indexes_before) > 0
+
+ await adb.remove_all_labels()
+
+ constraints = await adb.list_constraints()
+ indexes = await adb.list_indexes(exclude_token_lookup=True)
+
+ assert len(constraints) == 0
+ assert len(indexes) == 0
+
+ # Recreating all old constraints and indexes
+ for constraint in constraints_before:
+ constraint_type_clause = "UNIQUE"
+ if constraint["type"] == "NODE_PROPERTY_EXISTENCE":
+ constraint_type_clause = "NOT NULL"
+ elif constraint["type"] == "NODE_KEY":
+ constraint_type_clause = "NODE KEY"
+
+ await adb.cypher_query(
+ f'CREATE CONSTRAINT {constraint["name"]} FOR (n:{constraint["labelsOrTypes"][0]}) REQUIRE n.{constraint["properties"][0]} IS {constraint_type_clause}'
+ )
+ for index in indexes_before:
+ try:
+ await adb.cypher_query(
+ f'CREATE INDEX {index["name"]} FOR (n:{index["labelsOrTypes"][0]}) ON (n.{index["properties"][0]})'
+ )
+ except ClientError:
+ pass
diff --git a/test/async_/test_label_install.py b/test/async_/test_label_install.py
new file mode 100644
index 00000000..2d710a19
--- /dev/null
+++ b/test/async_/test_label_install.py
@@ -0,0 +1,193 @@
+from test._async_compat import mark_async_test
+
+import pytest
+
+from neomodel import (
+ AsyncRelationshipTo,
+ AsyncStructuredNode,
+ AsyncStructuredRel,
+ StringProperty,
+ UniqueIdProperty,
+ adb,
+)
+from neomodel.exceptions import ConstraintValidationFailed, FeatureNotSupported
+
+
+class NodeWithIndex(AsyncStructuredNode):
+ name = StringProperty(index=True)
+
+
+class NodeWithConstraint(AsyncStructuredNode):
+ name = StringProperty(unique_index=True)
+
+
+class NodeWithRelationship(AsyncStructuredNode):
+ ...
+
+
+class IndexedRelationship(AsyncStructuredRel):
+ indexed_rel_prop = StringProperty(index=True)
+
+
+class OtherNodeWithRelationship(AsyncStructuredNode):
+ has_rel = AsyncRelationshipTo(
+ NodeWithRelationship, "INDEXED_REL", model=IndexedRelationship
+ )
+
+
+class AbstractNode(AsyncStructuredNode):
+ __abstract_node__ = True
+ name = StringProperty(unique_index=True)
+
+
+class SomeNotUniqueNode(AsyncStructuredNode):
+ id_ = UniqueIdProperty(db_property="id")
+
+
+@mark_async_test
+async def test_install_all():
+ await adb.drop_constraints()
+ await adb.drop_indexes()
+ await adb.install_labels(AbstractNode)
+ # run install all labels
+ await adb.install_all_labels()
+
+ indexes = await adb.list_indexes()
+ index_names = [index["name"] for index in indexes]
+ assert "index_INDEXED_REL_indexed_rel_prop" in index_names
+
+ constraints = await adb.list_constraints()
+ constraint_names = [constraint["name"] for constraint in constraints]
+ assert "constraint_unique_NodeWithConstraint_name" in constraint_names
+ assert "constraint_unique_SomeNotUniqueNode_id" in constraint_names
+
+ # remove constraint for above test
+ await _drop_constraints_for_label_and_property("NoConstraintsSetup", "name")
+
+
+@mark_async_test
+async def test_install_label_twice(capsys):
+ await adb.drop_constraints()
+ await adb.drop_indexes()
+ expected_std_out = (
+ "{code: Neo.ClientError.Schema.EquivalentSchemaRuleAlreadyExists}"
+ )
+ await adb.install_labels(AbstractNode)
+ await adb.install_labels(AbstractNode)
+
+ await adb.install_labels(NodeWithIndex)
+ await adb.install_labels(NodeWithIndex, quiet=False)
+ captured = capsys.readouterr()
+ assert expected_std_out in captured.out
+
+ await adb.install_labels(NodeWithConstraint)
+ await adb.install_labels(NodeWithConstraint, quiet=False)
+ captured = capsys.readouterr()
+ assert expected_std_out in captured.out
+
+ await adb.install_labels(OtherNodeWithRelationship)
+ await adb.install_labels(OtherNodeWithRelationship, quiet=False)
+ captured = capsys.readouterr()
+ assert expected_std_out in captured.out
+
+ if await adb.version_is_higher_than("5.7"):
+
+ class UniqueIndexRelationship(AsyncStructuredRel):
+ unique_index_rel_prop = StringProperty(unique_index=True)
+
+ class OtherNodeWithUniqueIndexRelationship(AsyncStructuredNode):
+ has_rel = AsyncRelationshipTo(
+ NodeWithRelationship, "UNIQUE_INDEX_REL", model=UniqueIndexRelationship
+ )
+
+ await adb.install_labels(OtherNodeWithUniqueIndexRelationship)
+ await adb.install_labels(OtherNodeWithUniqueIndexRelationship, quiet=False)
+ captured = capsys.readouterr()
+ assert expected_std_out in captured.out
+
+
+@mark_async_test
+async def test_install_labels_db_property(capsys):
+ await adb.drop_constraints()
+ await adb.install_labels(SomeNotUniqueNode, quiet=False)
+ captured = capsys.readouterr()
+ assert "id" in captured.out
+ # make sure that the id_ constraint doesn't exist
+ constraint_names = await _drop_constraints_for_label_and_property(
+ "SomeNotUniqueNode", "id_"
+ )
+ assert constraint_names == []
+ # make sure the id constraint exists and can be removed
+ await _drop_constraints_for_label_and_property("SomeNotUniqueNode", "id")
+
+
+@mark_async_test
+async def test_relationship_unique_index_not_supported():
+ if await adb.version_is_higher_than("5.7"):
+ pytest.skip("Not supported before 5.7")
+
+ class UniqueIndexRelationship(AsyncStructuredRel):
+ name = StringProperty(unique_index=True)
+
+ class TargetNodeForUniqueIndexRelationship(AsyncStructuredNode):
+ pass
+
+ with pytest.raises(
+ FeatureNotSupported, match=r".*Please upgrade to Neo4j 5.7 or higher"
+ ):
+
+ class NodeWithUniqueIndexRelationship(AsyncStructuredNode):
+ has_rel = AsyncRelationshipTo(
+ TargetNodeForUniqueIndexRelationship,
+ "UNIQUE_INDEX_REL",
+ model=UniqueIndexRelationship,
+ )
+
+ await adb.install_labels(NodeWithUniqueIndexRelationship)
+
+
+@mark_async_test
+async def test_relationship_unique_index():
+ if not await adb.version_is_higher_than("5.7"):
+ pytest.skip("Not supported before 5.7")
+
+ class UniqueIndexRelationshipBis(AsyncStructuredRel):
+ name = StringProperty(unique_index=True)
+
+ class TargetNodeForUniqueIndexRelationship(AsyncStructuredNode):
+ pass
+
+ class NodeWithUniqueIndexRelationship(AsyncStructuredNode):
+ has_rel = AsyncRelationshipTo(
+ TargetNodeForUniqueIndexRelationship,
+ "UNIQUE_INDEX_REL_BIS",
+ model=UniqueIndexRelationshipBis,
+ )
+
+ await adb.install_labels(NodeWithUniqueIndexRelationship)
+ node1 = await NodeWithUniqueIndexRelationship().save()
+ node2 = await TargetNodeForUniqueIndexRelationship().save()
+ node3 = await TargetNodeForUniqueIndexRelationship().save()
+ rel1 = await node1.has_rel.connect(node2, {"name": "rel1"})
+
+ with pytest.raises(
+ ConstraintValidationFailed,
+ match=r".*already exists with type `UNIQUE_INDEX_REL_BIS` and property `name`.*",
+ ):
+ rel2 = await node1.has_rel.connect(node3, {"name": "rel1"})
+
+
+async def _drop_constraints_for_label_and_property(
+ label: str = None, property: str = None
+):
+ results, meta = await adb.cypher_query("SHOW CONSTRAINTS")
+ results_as_dict = [dict(zip(meta, row)) for row in results]
+ constraint_names = [
+ constraint
+ for constraint in results_as_dict
+ if constraint["labelsOrTypes"] == label and constraint["properties"] == property
+ ]
+ for constraint_name in constraint_names:
+ await adb.cypher_query(f"DROP CONSTRAINT {constraint_name}")
+
+ return constraint_names
diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py
new file mode 100644
index 00000000..87c1ca8e
--- /dev/null
+++ b/test/async_/test_match_api.py
@@ -0,0 +1,569 @@
+from datetime import datetime
+from test._async_compat import mark_async_test
+
+from pytest import raises
+
+from neomodel import (
+ INCOMING,
+ AsyncRelationshipFrom,
+ AsyncRelationshipTo,
+ AsyncStructuredNode,
+ AsyncStructuredRel,
+ DateTimeProperty,
+ IntegerProperty,
+ Q,
+ StringProperty,
+)
+from neomodel._async_compat.util import AsyncUtil
+from neomodel.async_.match import (
+ AsyncNodeSet,
+ AsyncQueryBuilder,
+ AsyncTraversal,
+ Optional,
+)
+from neomodel.exceptions import MultipleNodesReturned
+
+
+class SupplierRel(AsyncStructuredRel):
+ since = DateTimeProperty(default=datetime.now)
+ courier = StringProperty()
+
+
+class Supplier(AsyncStructuredNode):
+ name = StringProperty()
+ delivery_cost = IntegerProperty()
+ coffees = AsyncRelationshipTo("Coffee", "COFFEE SUPPLIERS")
+
+
+class Species(AsyncStructuredNode):
+ name = StringProperty()
+ coffees = AsyncRelationshipFrom(
+ "Coffee", "COFFEE SPECIES", model=AsyncStructuredRel
+ )
+
+
+class Coffee(AsyncStructuredNode):
+ name = StringProperty(unique_index=True)
+ price = IntegerProperty()
+ suppliers = AsyncRelationshipFrom(Supplier, "COFFEE SUPPLIERS", model=SupplierRel)
+ species = AsyncRelationshipTo(Species, "COFFEE SPECIES", model=AsyncStructuredRel)
+ id_ = IntegerProperty()
+
+
+class Extension(AsyncStructuredNode):
+ extension = AsyncRelationshipTo("Extension", "extension")
+
+
+@mark_async_test
+async def test_filter_exclude_via_labels():
+ await Coffee(name="Java", price=99).save()
+
+ node_set = AsyncNodeSet(Coffee)
+ qb = await AsyncQueryBuilder(node_set).build_ast()
+
+ results = await qb._execute()
+
+ assert "(coffee:Coffee)" in qb._ast.match
+ assert qb._ast.result_class
+ assert len(results) == 1
+ assert isinstance(results[0], Coffee)
+ assert results[0].name == "Java"
+
+ # with filter and exclude
+ await Coffee(name="Kenco", price=3).save()
+ node_set = node_set.filter(price__gt=2).exclude(price__gt=6, name="Java")
+ qb = await AsyncQueryBuilder(node_set).build_ast()
+
+ results = await qb._execute()
+ assert "(coffee:Coffee)" in qb._ast.match
+ assert "NOT" in qb._ast.where[0]
+ assert len(results) == 1
+ assert results[0].name == "Kenco"
+
+
+@mark_async_test
+async def test_simple_has_via_label():
+ nescafe = await Coffee(name="Nescafe", price=99).save()
+ tesco = await Supplier(name="Tesco", delivery_cost=2).save()
+ await nescafe.suppliers.connect(tesco)
+
+ ns = AsyncNodeSet(Coffee).has(suppliers=True)
+ qb = await AsyncQueryBuilder(ns).build_ast()
+ results = await qb._execute()
+ assert "COFFEE SUPPLIERS" in qb._ast.where[0]
+ assert len(results) == 1
+ assert results[0].name == "Nescafe"
+
+ await Coffee(name="nespresso", price=99).save()
+ ns = AsyncNodeSet(Coffee).has(suppliers=False)
+ qb = await AsyncQueryBuilder(ns).build_ast()
+ results = await qb._execute()
+ assert len(results) > 0
+ assert "NOT" in qb._ast.where[0]
+
+
+@mark_async_test
+async def test_get():
+ await Coffee(name="1", price=3).save()
+ assert await Coffee.nodes.get(name="1")
+
+ with raises(Coffee.DoesNotExist):
+ await Coffee.nodes.get(name="2")
+
+ await Coffee(name="2", price=3).save()
+
+ with raises(MultipleNodesReturned):
+ await Coffee.nodes.get(price=3)
+
+
+@mark_async_test
+async def test_simple_traverse_with_filter():
+ nescafe = await Coffee(name="Nescafe2", price=99).save()
+ tesco = await Supplier(name="Sainsburys", delivery_cost=2).save()
+ await nescafe.suppliers.connect(tesco)
+
+ qb = AsyncQueryBuilder(
+ AsyncNodeSet(source=nescafe).suppliers.match(since__lt=datetime.now())
+ )
+
+ _ast = await qb.build_ast()
+ results = await _ast._execute()
+
+ assert qb._ast.lookup
+ assert qb._ast.match
+ assert qb._ast.return_clause.startswith("suppliers")
+ assert len(results) == 1
+ assert results[0].name == "Sainsburys"
+
+
+@mark_async_test
+async def test_double_traverse():
+ nescafe = await Coffee(name="Nescafe plus", price=99).save()
+ tesco = await Supplier(name="Asda", delivery_cost=2).save()
+ await nescafe.suppliers.connect(tesco)
+ await tesco.coffees.connect(await Coffee(name="Decafe", price=2).save())
+
+ ns = AsyncNodeSet(AsyncNodeSet(source=nescafe).suppliers.match()).coffees.match()
+ qb = await AsyncQueryBuilder(ns).build_ast()
+
+ results = await qb._execute()
+ assert len(results) == 2
+ assert results[0].name == "Decafe"
+ assert results[1].name == "Nescafe plus"
+
+
+@mark_async_test
+async def test_count():
+ await Coffee(name="Nescafe Gold", price=99).save()
+ ast = await AsyncQueryBuilder(AsyncNodeSet(source=Coffee)).build_ast()
+ count = await ast._count()
+ assert count > 0
+
+ await Coffee(name="Kawa", price=27).save()
+ node_set = AsyncNodeSet(source=Coffee)
+ node_set.skip = 1
+ node_set.limit = 1
+ ast = await AsyncQueryBuilder(node_set).build_ast()
+ count = await ast._count()
+ assert count == 1
+
+
+@mark_async_test
+async def test_len_and_iter_and_bool():
+ iterations = 0
+
+ await Coffee(name="Icelands finest").save()
+
+ for c in await Coffee.nodes:
+ iterations += 1
+ await c.delete()
+
+ assert iterations > 0
+
+ assert len(await Coffee.nodes) == 0
+
+
+@mark_async_test
+async def test_slice():
+ for c in await Coffee.nodes:
+ await c.delete()
+
+ await Coffee(name="Icelands finest").save()
+ await Coffee(name="Britains finest").save()
+ await Coffee(name="Japans finest").save()
+
+ # Branching tests because async needs extra brackets
+ if AsyncUtil.is_async_code:
+ assert len(list((await Coffee.nodes)[1:])) == 2
+ assert len(list((await Coffee.nodes)[:1])) == 1
+ assert isinstance((await Coffee.nodes)[1], Coffee)
+ assert isinstance((await Coffee.nodes)[0], Coffee)
+ assert len(list((await Coffee.nodes)[1:2])) == 1
+ else:
+ assert len(list(Coffee.nodes[1:])) == 2
+ assert len(list(Coffee.nodes[:1])) == 1
+ assert isinstance(Coffee.nodes[1], Coffee)
+ assert isinstance(Coffee.nodes[0], Coffee)
+ assert len(list(Coffee.nodes[1:2])) == 1
+
+
+@mark_async_test
+async def test_issue_208():
+ # calls to match persist across queries.
+
+ b = await Coffee(name="basics").save()
+ l = await Supplier(name="lidl").save()
+ a = await Supplier(name="aldi").save()
+
+ await b.suppliers.connect(l, {"courier": "fedex"})
+ await b.suppliers.connect(a, {"courier": "dhl"})
+
+ assert len(await b.suppliers.match(courier="fedex"))
+ assert len(await b.suppliers.match(courier="dhl"))
+
+
+@mark_async_test
+async def test_issue_589():
+ node1 = await Extension().save()
+ node2 = await Extension().save()
+ assert node2 not in await node1.extension
+ await node1.extension.connect(node2)
+ assert node2 in await node1.extension
+
+
+@mark_async_test
+async def test_contains():
+ expensive = await Coffee(price=1000, name="Pricey").save()
+ asda = await Coffee(name="Asda", price=1).save()
+
+ assert expensive in await Coffee.nodes.filter(price__gt=999)
+ assert asda not in await Coffee.nodes.filter(price__gt=999)
+
+ # bad value raises
+ with raises(ValueError, match=r"Expecting StructuredNode instance"):
+ if AsyncUtil.is_async_code:
+ assert await Coffee.nodes.check_contains(2)
+ else:
+ assert 2 in Coffee.nodes
+
+ # unsaved
+ with raises(ValueError, match=r"Unsaved node"):
+ if AsyncUtil.is_async_code:
+ assert await Coffee.nodes.check_contains(Coffee())
+ else:
+ assert Coffee() in Coffee.nodes
+
+
+@mark_async_test
+async def test_order_by():
+ for c in await Coffee.nodes:
+ await c.delete()
+
+ c1 = await Coffee(name="Icelands finest", price=5).save()
+ c2 = await Coffee(name="Britains finest", price=10).save()
+ c3 = await Coffee(name="Japans finest", price=35).save()
+
+ if AsyncUtil.is_async_code:
+ assert ((await Coffee.nodes.order_by("price"))[0]).price == 5
+ assert ((await Coffee.nodes.order_by("-price"))[0]).price == 35
+ else:
+ assert (Coffee.nodes.order_by("price")[0]).price == 5
+ assert (Coffee.nodes.order_by("-price")[0]).price == 35
+
+ ns = Coffee.nodes.order_by("-price")
+ qb = await AsyncQueryBuilder(ns).build_ast()
+ assert qb._ast.order_by
+ ns = ns.order_by(None)
+ qb = await AsyncQueryBuilder(ns).build_ast()
+ assert not qb._ast.order_by
+ ns = ns.order_by("?")
+ qb = await AsyncQueryBuilder(ns).build_ast()
+ assert qb._ast.with_clause == "coffee, rand() as r"
+ assert qb._ast.order_by == "r"
+
+ with raises(
+ ValueError,
+ match=r".*Neo4j internals like id or element_id are not allowed for use in this operation.",
+ ):
+ await Coffee.nodes.order_by("id")
+
+ # Test order by on a relationship
+ l = await Supplier(name="lidl2").save()
+ await l.coffees.connect(c1)
+ await l.coffees.connect(c2)
+ await l.coffees.connect(c3)
+
+ ordered_n = [n for n in await l.coffees.order_by("name")]
+ assert ordered_n[0] == c2
+ assert ordered_n[1] == c1
+ assert ordered_n[2] == c3
+
+
+@mark_async_test
+async def test_extra_filters():
+ for c in await Coffee.nodes:
+ await c.delete()
+
+ c1 = await Coffee(name="Icelands finest", price=5, id_=1).save()
+ c2 = await Coffee(name="Britains finest", price=10, id_=2).save()
+ c3 = await Coffee(name="Japans finest", price=35, id_=3).save()
+ c4 = await Coffee(name="US extra-fine", price=None, id_=4).save()
+
+ coffees_5_10 = await Coffee.nodes.filter(price__in=[10, 5])
+ assert len(coffees_5_10) == 2, "unexpected number of results"
+ assert c1 in coffees_5_10, "doesnt contain 5 price coffee"
+ assert c2 in coffees_5_10, "doesnt contain 10 price coffee"
+
+ finest_coffees = await Coffee.nodes.filter(name__iendswith=" Finest")
+ assert len(finest_coffees) == 3, "unexpected number of results"
+ assert c1 in finest_coffees, "doesnt contain 1st finest coffee"
+ assert c2 in finest_coffees, "doesnt contain 2nd finest coffee"
+ assert c3 in finest_coffees, "doesnt contain 3rd finest coffee"
+
+ unpriced_coffees = await Coffee.nodes.filter(price__isnull=True)
+ assert len(unpriced_coffees) == 1, "unexpected number of results"
+ assert c4 in unpriced_coffees, "doesnt contain unpriced coffee"
+
+ coffees_with_id_gte_3 = await Coffee.nodes.filter(id___gte=3)
+ assert len(coffees_with_id_gte_3) == 2, "unexpected number of results"
+ assert c3 in coffees_with_id_gte_3
+ assert c4 in coffees_with_id_gte_3
+
+ with raises(
+ ValueError,
+ match=r".*Neo4j internals like id or element_id are not allowed for use in this operation.",
+ ):
+ await Coffee.nodes.filter(elementId="4:xxx:111").all()
+
+
+def test_traversal_definition_keys_are_valid():
+ muckefuck = Coffee(name="Mukkefuck", price=1)
+
+ with raises(ValueError):
+ AsyncTraversal(
+ muckefuck,
+ "a_name",
+ {
+ "node_class": Supplier,
+ "direction": INCOMING,
+ "relationship_type": "KNOWS",
+ "model": None,
+ },
+ )
+
+ AsyncTraversal(
+ muckefuck,
+ "a_name",
+ {
+ "node_class": Supplier,
+ "direction": INCOMING,
+ "relation_type": "KNOWS",
+ "model": None,
+ },
+ )
+
+
+@mark_async_test
+async def test_empty_filters():
+ """Test this case:
+ ```
+ SomeModel.nodes.filter().filter(Q(arg1=val1)).all()
+ SomeModel.nodes.exclude().exclude(Q(arg1=val1)).all()
+ SomeModel.nodes.filter().filter(arg1=val1).all()
+ ```
+    In django_rest_framework, filter is used as a lazy function, and the
+    ``get_queryset`` function in ``GenericAPIView`` should return a
+    ``NodeSet`` object.
+ """
+
+ for c in await Coffee.nodes:
+ await c.delete()
+
+ c1 = await Coffee(name="Super", price=5, id_=1).save()
+ c2 = await Coffee(name="Puper", price=10, id_=2).save()
+
+ empty_filter = Coffee.nodes.filter()
+
+ all_coffees = await empty_filter.all()
+ assert len(all_coffees) == 2, "unexpected number of results"
+
+ filter_empty_filter = empty_filter.filter(price=5)
+ assert len(await filter_empty_filter.all()) == 1, "unexpected number of results"
+ assert (
+ c1 in await filter_empty_filter.all()
+ ), "doesnt contain c1 in ``filter_empty_filter``"
+
+ filter_q_empty_filter = empty_filter.filter(Q(price=5))
+ assert len(await filter_empty_filter.all()) == 1, "unexpected number of results"
+ assert (
+ c1 in await filter_empty_filter.all()
+ ), "doesnt contain c1 in ``filter_empty_filter``"
+
+
+@mark_async_test
+async def test_q_filters():
+    # TODO: also cover the case where Q has no children and self.connector != conn
+ for c in await Coffee.nodes:
+ await c.delete()
+
+ c1 = await Coffee(name="Icelands finest", price=5, id_=1).save()
+ c2 = await Coffee(name="Britains finest", price=10, id_=2).save()
+ c3 = await Coffee(name="Japans finest", price=35, id_=3).save()
+ c4 = await Coffee(name="US extra-fine", price=None, id_=4).save()
+ c5 = await Coffee(name="Latte", price=35, id_=5).save()
+ c6 = await Coffee(name="Cappuccino", price=35, id_=6).save()
+
+ coffees_5_10 = await Coffee.nodes.filter(Q(price=10) | Q(price=5)).all()
+ assert len(coffees_5_10) == 2, "unexpected number of results"
+ assert c1 in coffees_5_10, "doesnt contain 5 price coffee"
+ assert c2 in coffees_5_10, "doesnt contain 10 price coffee"
+
+ coffees_5_6 = (
+ await Coffee.nodes.filter(Q(name="Latte") | Q(name="Cappuccino"))
+ .filter(price=35)
+ .all()
+ )
+ assert len(coffees_5_6) == 2, "unexpected number of results"
+ assert c5 in coffees_5_6, "doesnt contain 5 coffee"
+ assert c6 in coffees_5_6, "doesnt contain 6 coffee"
+
+ coffees_5_6 = (
+ await Coffee.nodes.filter(price=35)
+ .filter(Q(name="Latte") | Q(name="Cappuccino"))
+ .all()
+ )
+ assert len(coffees_5_6) == 2, "unexpected number of results"
+ assert c5 in coffees_5_6, "doesnt contain 5 coffee"
+ assert c6 in coffees_5_6, "doesnt contain 6 coffee"
+
+ finest_coffees = await Coffee.nodes.filter(name__iendswith=" Finest").all()
+ assert len(finest_coffees) == 3, "unexpected number of results"
+ assert c1 in finest_coffees, "doesnt contain 1st finest coffee"
+ assert c2 in finest_coffees, "doesnt contain 2nd finest coffee"
+ assert c3 in finest_coffees, "doesnt contain 3rd finest coffee"
+
+ unpriced_coffees = await Coffee.nodes.filter(Q(price__isnull=True)).all()
+ assert len(unpriced_coffees) == 1, "unexpected number of results"
+ assert c4 in unpriced_coffees, "doesnt contain unpriced coffee"
+
+ coffees_with_id_gte_3 = await Coffee.nodes.filter(Q(id___gte=3)).all()
+ assert len(coffees_with_id_gte_3) == 4, "unexpected number of results"
+ assert c3 in coffees_with_id_gte_3
+ assert c4 in coffees_with_id_gte_3
+ assert c5 in coffees_with_id_gte_3
+ assert c6 in coffees_with_id_gte_3
+
+ coffees_5_not_japans = await Coffee.nodes.filter(
+ Q(price__gt=5) & ~Q(name="Japans finest")
+ ).all()
+ assert c3 not in coffees_5_not_japans
+
+ empty_Q_condition = await Coffee.nodes.filter(Q(price=5) | Q()).all()
+ assert (
+ len(empty_Q_condition) == 1
+ ), "undefined Q leading to unexpected number of results"
+ assert c1 in empty_Q_condition
+
+ combined_coffees = await Coffee.nodes.filter(
+ Q(price=35), Q(name="Latte") | Q(name="Cappuccino")
+ ).all()
+ assert len(combined_coffees) == 2
+ assert c5 in combined_coffees
+ assert c6 in combined_coffees
+ assert c3 not in combined_coffees
+
+ class QQ:
+ pass
+
+ with raises(TypeError):
+ wrong_Q = await Coffee.nodes.filter(Q(price=5) | QQ()).all()
+
+
+def test_qbase():
+ test_print_out = str(Q(price=5) | Q(price=10))
+ test_repr = repr(Q(price=5) | Q(price=10))
+ assert test_print_out == "(OR: ('price', 5), ('price', 10))"
+ assert test_repr == ""
+
+ assert ("price", 5) in (Q(price=5) | Q(price=10))
+
+ test_hash = set([Q(price_lt=30) | ~Q(price=5), Q(price_lt=30) | ~Q(price=5)])
+ assert len(test_hash) == 1
+
+
+@mark_async_test
+async def test_traversal_filter_left_hand_statement():
+ nescafe = await Coffee(name="Nescafe2", price=99).save()
+ nescafe_gold = await Coffee(name="Nescafe gold", price=11).save()
+
+ tesco = await Supplier(name="Sainsburys", delivery_cost=3).save()
+ biedronka = await Supplier(name="Biedronka", delivery_cost=5).save()
+ lidl = await Supplier(name="Lidl", delivery_cost=3).save()
+
+ await nescafe.suppliers.connect(tesco)
+ await nescafe_gold.suppliers.connect(biedronka)
+ await nescafe_gold.suppliers.connect(lidl)
+
+ lidl_supplier = (
+ await AsyncNodeSet(Coffee.nodes.filter(price=11).suppliers)
+ .filter(delivery_cost=3)
+ .all()
+ )
+
+ assert lidl in lidl_supplier
+
+
+@mark_async_test
+async def test_fetch_relations():
+ arabica = await Species(name="Arabica").save()
+ robusta = await Species(name="Robusta").save()
+ nescafe = await Coffee(name="Nescafe 1000", price=99).save()
+ nescafe_gold = await Coffee(name="Nescafe 1001", price=11).save()
+
+ tesco = await Supplier(name="Sainsburys", delivery_cost=3).save()
+ await nescafe.suppliers.connect(tesco)
+ await nescafe_gold.suppliers.connect(tesco)
+ await nescafe.species.connect(arabica)
+
+ result = (
+ await Supplier.nodes.filter(name="Sainsburys")
+ .fetch_relations("coffees__species")
+ .all()
+ )
+ assert arabica in result[0]
+ assert robusta not in result[0]
+ assert tesco in result[0]
+ assert nescafe in result[0]
+ assert nescafe_gold not in result[0]
+
+ result = (
+ await Species.nodes.filter(name="Robusta")
+ .fetch_relations(Optional("coffees__suppliers"))
+ .all()
+ )
+ assert result[0][0] is None
+
+ if AsyncUtil.is_async_code:
+ count = (
+ await Supplier.nodes.filter(name="Sainsburys")
+ .fetch_relations("coffees__species")
+ .get_len()
+ )
+ assert count == 1
+
+ assert (
+ await Supplier.nodes.fetch_relations("coffees__species")
+ .filter(name="Sainsburys")
+ .check_contains(tesco)
+ )
+ else:
+ count = len(
+ Supplier.nodes.filter(name="Sainsburys")
+ .fetch_relations("coffees__species")
+ .all()
+ )
+ assert count == 1
+
+ assert tesco in Supplier.nodes.fetch_relations("coffees__species").filter(
+ name="Sainsburys"
+ )
diff --git a/test/async_/test_migration_neo4j_5.py b/test/async_/test_migration_neo4j_5.py
new file mode 100644
index 00000000..48c0e8c4
--- /dev/null
+++ b/test/async_/test_migration_neo4j_5.py
@@ -0,0 +1,80 @@
+from test._async_compat import mark_async_test
+
+import pytest
+
+from neomodel import (
+ AsyncRelationshipTo,
+ AsyncStructuredNode,
+ AsyncStructuredRel,
+ IntegerProperty,
+ StringProperty,
+ adb,
+)
+
+
+class Album(AsyncStructuredNode):
+ name = StringProperty()
+
+
+class Released(AsyncStructuredRel):
+ year = IntegerProperty()
+
+
+class Band(AsyncStructuredNode):
+ name = StringProperty()
+ released = AsyncRelationshipTo(Album, relation_type="RELEASED", model=Released)
+
+
+@mark_async_test
+async def test_read_elements_id():
+ the_hives = await Band(name="The Hives").save()
+ lex_hives = await Album(name="Lex Hives").save()
+ released_rel = await the_hives.released.connect(lex_hives)
+
+ # Validate element_id properties
+ assert lex_hives.element_id == (await the_hives.released.single()).element_id
+ assert released_rel._start_node_element_id == the_hives.element_id
+ assert released_rel._end_node_element_id == lex_hives.element_id
+
+ # Validate id properties
+ # Behaviour is dependent on Neo4j version
+ db_version = await adb.database_version
+ if db_version.startswith("4"):
+ # Nodes' ids
+ assert lex_hives.id == int(lex_hives.element_id)
+ assert lex_hives.id == (await the_hives.released.single()).id
+ # Relationships' ids
+ assert isinstance(released_rel.element_id, str)
+ assert int(released_rel.element_id) == released_rel.id
+ assert released_rel._start_node_id == int(the_hives.element_id)
+ assert released_rel._end_node_id == int(lex_hives.element_id)
+ else:
+ # Nodes' ids
+ expected_error_type = ValueError
+ expected_error_message = "id is deprecated in Neo4j version 5, please migrate to element_id\. If you use the id in a Cypher query, replace id\(\) by elementId\(\)\."
+ assert isinstance(lex_hives.element_id, str)
+ with pytest.raises(
+ expected_error_type,
+ match=expected_error_message,
+ ):
+ lex_hives.id
+
+ # Relationships' ids
+ assert isinstance(released_rel.element_id, str)
+ assert isinstance(released_rel._start_node_element_id, str)
+ assert isinstance(released_rel._end_node_element_id, str)
+ with pytest.raises(
+ expected_error_type,
+ match=expected_error_message,
+ ):
+ released_rel.id
+ with pytest.raises(
+ expected_error_type,
+ match=expected_error_message,
+ ):
+ released_rel._start_node_id
+ with pytest.raises(
+ expected_error_type,
+ match=expected_error_message,
+ ):
+ released_rel._end_node_id
diff --git a/test/async_/test_models.py b/test/async_/test_models.py
new file mode 100644
index 00000000..b9bb2e44
--- /dev/null
+++ b/test/async_/test_models.py
@@ -0,0 +1,360 @@
+from __future__ import print_function
+
+from datetime import datetime
+from test._async_compat import mark_async_test
+
+from pytest import raises
+
+from neomodel import (
+ AsyncStructuredNode,
+ AsyncStructuredRel,
+ DateProperty,
+ IntegerProperty,
+ StringProperty,
+ adb,
+)
+from neomodel.exceptions import RequiredProperty, UniqueProperty
+
+
+# Test model: node with a unique, required email and an indexed age.
+# "email_alias" is a plain Python property proxying "email", used to
+# verify that magic (non-neomodel) properties survive save/load.
+class User(AsyncStructuredNode):
+    email = StringProperty(unique_index=True, required=True)
+    age = IntegerProperty(index=True)
+
+    @property
+    def email_alias(self):
+        return self.email
+
+    @email_alias.setter  # noqa
+    def email_alias(self, value):
+        self.email = value
+
+
+# Node with no declared properties; used to check that saving still works.
+class NodeWithoutProperty(AsyncStructuredNode):
+    pass
+
+
+@mark_async_test
+async def test_issue_233():
+    # Regression test: __getitem__ defined on an abstract base class must
+    # keep working on saved instances of a concrete subclass.
+    class BaseIssue233(AsyncStructuredNode):
+        __abstract_node__ = True
+
+        def __getitem__(self, item):
+            return self.__dict__[item]
+
+    class Issue233(BaseIssue233):
+        uid = StringProperty(unique_index=True, required=True)
+
+    i = await Issue233(uid="testgetitem").save()
+    assert i["uid"] == "testgetitem"
+
+
+def test_issue_72():
+    # Regression test: a declared-but-unset property defaults to None.
+    user = User(email="foo@bar.com")
+    assert user.age is None
+
+
+@mark_async_test
+async def test_required():
+    # Saving without a required property must raise RequiredProperty.
+    with raises(RequiredProperty):
+        await User(age=3).save()
+
+
+def test_repr_and_str():
+    # repr() of a StructuredNode is "<ClassName: {properties}>" and str()
+    # is the property dict alone. The empty-string expectation here was
+    # garbled (the angle-bracketed "<User: ...>" text was stripped) —
+    # restored to the real expected repr.
+    u = User(email="robin@test.com", age=3)
+    assert repr(u) == "<User: {'email': 'robin@test.com', 'age': 3}>"
+    assert str(u) == "{'email': 'robin@test.com', 'age': 3}"
+
+
+@mark_async_test
+async def test_get_and_get_or_none():
+    u = User(email="robin@test.com", age=3)
+    assert await u.save()
+    rob = await User.nodes.get(email="robin@test.com")
+    assert rob.email == "robin@test.com"
+    assert rob.age == 3
+
+    rob = await User.nodes.get_or_none(email="robin@test.com")
+    assert rob.email == "robin@test.com"
+
+    # get_or_none returns None instead of raising when there is no match
+    n = await User.nodes.get_or_none(email="robin@nothere.com")
+    assert n is None
+
+
+@mark_async_test
+async def test_first_and_first_or_none():
+    u = User(email="matt@test.com", age=24)
+    assert await u.save()
+    u2 = User(email="tbrady@test.com", age=40)
+    assert await u2.save()
+    # "-age" orders descending, so the oldest user comes first
+    tbrady = await User.nodes.order_by("-age").first()
+    assert tbrady.email == "tbrady@test.com"
+    assert tbrady.age == 40
+
+    tbrady = await User.nodes.order_by("-age").first_or_none()
+    assert tbrady.email == "tbrady@test.com"
+
+    n = await User.nodes.first_or_none(email="matt@nothere.com")
+    assert n is None
+
+
+def test_bare_init_without_save():
+    """
+    If a node model is initialised without being saved, accessing its `element_id` should
+    return None.
+    """
+    assert User().element_id is None
+
+
+@mark_async_test
+async def test_save_to_model():
+    u = User(email="jim@test.com", age=3)
+    assert await u.save()
+    # a saved node gets a database element_id assigned
+    assert u.element_id is not None
+    assert u.email == "jim@test.com"
+    assert u.age == 3
+
+
+@mark_async_test
+async def test_save_node_without_properties():
+    n = NodeWithoutProperty()
+    assert await n.save()
+    assert n.element_id is not None
+
+
+@mark_async_test
+async def test_unique():
+    # install_labels creates the uniqueness constraint before testing it
+    await adb.install_labels(User)
+    await User(email="jim1@test.com", age=3).save()
+    with raises(UniqueProperty):
+        await User(email="jim1@test.com", age=3).save()
+
+
+@mark_async_test
+async def test_update_unique():
+    u = await User(email="jimxx@test.com", age=3).save()
+    await u.save()  # this shouldn't fail
+
+
+@mark_async_test
+async def test_update():
+    user = await User(email="jim2@test.com", age=3).save()
+    assert user
+    user.email = "jim2000@test.com"
+    await user.save()
+    jim = await User.nodes.get(email="jim2000@test.com")
+    assert jim
+    assert jim.email == "jim2000@test.com"
+
+
+@mark_async_test
+async def test_save_through_magic_property():
+    # setting via the email_alias property must persist to "email"
+    user = await User(email_alias="blah@test.com", age=8).save()
+    assert user.email_alias == "blah@test.com"
+    user = await User.nodes.get(email="blah@test.com")
+    assert user.email == "blah@test.com"
+    assert user.email_alias == "blah@test.com"
+
+    user1 = await User(email="blah1@test.com", age=8).save()
+    assert user1.email_alias == "blah1@test.com"
+    user1.email_alias = "blah2@test.com"
+    assert await user1.save()
+    user2 = await User.nodes.get(email="blah2@test.com")
+    assert user2
+
+
+# Node with an explicit custom label ("customers" instead of the class name).
+class Customer2(AsyncStructuredNode):
+    __label__ = "customers"
+    email = StringProperty(unique_index=True, required=True)
+    age = IntegerProperty(index=True)
+
+
+@mark_async_test
+async def test_not_updated_on_unique_error():
+    await adb.install_labels(Customer2)
+    await Customer2(email="jim@bob.com", age=7).save()
+    test = await Customer2(email="jim1@bob.com", age=2).save()
+    test.email = "jim@bob.com"
+    with raises(UniqueProperty):
+        await test.save()
+    # the failed save must not have modified either stored node
+    customers = await Customer2.nodes
+    assert customers[0].email != customers[1].email
+    assert (await Customer2.nodes.get(email="jim@bob.com")).age == 7
+    assert (await Customer2.nodes.get(email="jim1@bob.com")).age == 2
+
+
+@mark_async_test
+async def test_label_not_inherited():
+    class Customer3(Customer2):
+        address = StringProperty()
+
+    # subclass gets its own label, but instances carry both labels in the DB
+    assert Customer3.__label__ == "Customer3"
+    c = await Customer3(email="test@test.com").save()
+    assert "customers" in await c.labels()
+    assert "Customer3" in await c.labels()
+
+    c = await Customer2.nodes.get(email="test@test.com")
+    assert isinstance(c, Customer2)
+    assert "customers" in await c.labels()
+    assert "Customer3" in await c.labels()
+
+
+@mark_async_test
+async def test_refresh():
+    c = await Customer2(email="my@email.com", age=16).save()
+    c.my_custom_prop = "value"
+    copy = await Customer2.nodes.get(email="my@email.com")
+    copy.age = 20
+    await copy.save()
+
+    assert c.age == 16
+
+    # refresh() reloads declared properties but keeps ad-hoc attributes
+    await c.refresh()
+    assert c.age == 20
+    assert c.my_custom_prop == "value"
+
+    c = Customer2.inflate(c.element_id)
+    c.age = 30
+    await c.refresh()
+
+    assert c.age == 20
+
+    # element_id format differs between Neo4j 4 (int) and 5 (string)
+    _db_version = await adb.database_version
+    if _db_version.startswith("4"):
+        c = Customer2.inflate(999)
+    else:
+        c = Customer2.inflate("4:xxxxxx:999")
+    with raises(Customer2.DoesNotExist):
+        await c.refresh()
+
+
+@mark_async_test
+async def test_setting_value_to_none():
+    c = await Customer2(email="alice@bob.com", age=42).save()
+    assert c.age is not None
+
+    # setting a property to None and saving removes it in the database
+    c.age = None
+    await c.save()
+
+    copy = await Customer2.nodes.get(email="alice@bob.com")
+    assert copy.age is None
+
+
+@mark_async_test
+async def test_inheritance():
+    class User(AsyncStructuredNode):
+        __abstract_node__ = True
+        name = StringProperty(unique_index=True)
+
+    class Shopper(User):
+        balance = IntegerProperty(index=True)
+
+        async def credit_account(self, amount):
+            self.balance = self.balance + int(amount)
+            await self.save()
+
+    jim = await Shopper(name="jimmy", balance=300).save()
+    await jim.credit_account(50)
+
+    # abstract parent contributes no label of its own
+    assert Shopper.__label__ == "Shopper"
+    assert jim.balance == 350
+    assert len(jim.inherited_labels()) == 1
+    assert len(await jim.labels()) == 1
+    assert (await jim.labels())[0] == "Shopper"
+
+
+@mark_async_test
+async def test_inherited_optional_labels():
+    class BaseOptional(AsyncStructuredNode):
+        __optional_labels__ = ["Alive"]
+        name = StringProperty(unique_index=True)
+
+    class ExtendedOptional(BaseOptional):
+        __optional_labels__ = ["RewardsMember"]
+        balance = IntegerProperty(index=True)
+
+        async def credit_account(self, amount):
+            self.balance = self.balance + int(amount)
+            await self.save()
+
+    henry = await ExtendedOptional(name="henry", balance=300).save()
+    await henry.credit_account(50)
+
+    assert ExtendedOptional.__label__ == "ExtendedOptional"
+    assert henry.balance == 350
+    assert len(henry.inherited_labels()) == 2
+    assert len(await henry.labels()) == 2
+
+    # optional labels accumulate down the inheritance chain
+    assert set(henry.inherited_optional_labels()) == {"Alive", "RewardsMember"}
+
+
+@mark_async_test
+async def test_mixins():
+    # plain (non-StructuredNode) mixins contribute properties and methods
+    class UserMixin:
+        name = StringProperty(unique_index=True)
+        password = StringProperty()
+
+    class CreditMixin:
+        balance = IntegerProperty(index=True)
+
+        async def credit_account(self, amount):
+            self.balance = self.balance + int(amount)
+            await self.save()
+
+    class Shopper2(AsyncStructuredNode, UserMixin, CreditMixin):
+        pass
+
+    jim = await Shopper2(name="jimmy", balance=300).save()
+    await jim.credit_account(50)
+
+    assert Shopper2.__label__ == "Shopper2"
+    assert jim.balance == 350
+    assert len(jim.inherited_labels()) == 1
+    assert len(await jim.labels()) == 1
+    assert (await jim.labels())[0] == "Shopper2"
+
+
+@mark_async_test
+async def test_date_property():
+    class DateTest(AsyncStructuredNode):
+        birthdate = DateProperty()
+
+    user = await DateTest(birthdate=datetime.now()).save()
+
+
+def test_reserved_property_keys():
+    # declaring properties that clash with neomodel internals must raise
+    # at class-definition time
+    error_match = r".*is not allowed as it conflicts with neomodel internals.*"
+    with raises(ValueError, match=error_match):
+
+        class ReservedPropertiesDeletedNode(AsyncStructuredNode):
+            deleted = StringProperty()
+
+    with raises(ValueError, match=error_match):
+
+        class ReservedPropertiesIdNode(AsyncStructuredNode):
+            id = StringProperty()
+
+    with raises(ValueError, match=error_match):
+
+        class ReservedPropertiesElementIdNode(AsyncStructuredNode):
+            element_id = StringProperty()
+
+    with raises(ValueError, match=error_match):
+
+        class ReservedPropertiesIdRel(AsyncStructuredRel):
+            id = StringProperty()
+
+    with raises(ValueError, match=error_match):
+
+        class ReservedPropertiesElementIdRel(AsyncStructuredRel):
+            element_id = StringProperty()
+
+    error_match = r"Property names 'source' and 'target' are not allowed as they conflict with neomodel internals."
+    with raises(ValueError, match=error_match):
+
+        class ReservedPropertiesSourceRel(AsyncStructuredRel):
+            source = StringProperty()
+
+    with raises(ValueError, match=error_match):
+
+        class ReservedPropertiesTargetRel(AsyncStructuredRel):
+            target = StringProperty()
diff --git a/test/async_/test_multiprocessing.py b/test/async_/test_multiprocessing.py
new file mode 100644
index 00000000..9bf46598
--- /dev/null
+++ b/test/async_/test_multiprocessing.py
@@ -0,0 +1,24 @@
+from multiprocessing.pool import ThreadPool as Pool
+from test._async_compat import mark_async_test
+
+from neomodel import AsyncStructuredNode, StringProperty, adb
+
+
+class ThingyMaBob(AsyncStructuredNode):
+    name = StringProperty(unique_index=True, required=True)
+
+
+async def thing_create(name):
+    name = str(name)
+    (thing,) = await ThingyMaBob.get_or_create({"name": name})
+    return thing.name, name
+
+
+@mark_async_test
+async def test_concurrency():
+    # Pool.map over an async callable yields coroutine objects; each is
+    # awaited below, so the thread pool only fans out coroutine creation.
+    with Pool(5) as p:
+        results = p.map(thing_create, range(50))
+        for to_unpack in results:
+            returned, sent = await to_unpack
+            assert returned == sent
+        await adb.close_connection()
diff --git a/test/async_/test_paths.py b/test/async_/test_paths.py
new file mode 100644
index 00000000..59a5e385
--- /dev/null
+++ b/test/async_/test_paths.py
@@ -0,0 +1,96 @@
+from test._async_compat import mark_async_test
+
+from neomodel import (
+ AsyncNeomodelPath,
+ AsyncRelationshipTo,
+ AsyncStructuredNode,
+ AsyncStructuredRel,
+ IntegerProperty,
+ StringProperty,
+ UniqueIdProperty,
+ adb,
+)
+
+
+class PersonLivesInCity(AsyncStructuredRel):
+    """
+    Relationship with data that will be instantiated as "stand-alone"
+    """
+
+    some_num = IntegerProperty(index=True, default=12)
+
+
+class CountryOfOrigin(AsyncStructuredNode):
+    code = StringProperty(unique_index=True, required=True)
+
+
+class CityOfResidence(AsyncStructuredNode):
+    name = StringProperty(required=True)
+    country = AsyncRelationshipTo(CountryOfOrigin, "FROM_COUNTRY")
+
+
+class PersonOfInterest(AsyncStructuredNode):
+    uid = UniqueIdProperty()
+    name = StringProperty(unique_index=True)
+    age = IntegerProperty(index=True, default=0)
+
+    country = AsyncRelationshipTo(CountryOfOrigin, "IS_FROM")
+    city = AsyncRelationshipTo(CityOfResidence, "LIVES_IN", model=PersonLivesInCity)
+
+
+@mark_async_test
+async def test_path_instantiation():
+    """
+    Neo4j driver paths should be instantiated as neomodel paths, with all of
+    their nodes and relationships resolved to their Python objects wherever
+    such a mapping is available.
+    """
+
+    c1 = await CountryOfOrigin(code="GR").save()
+    c2 = await CountryOfOrigin(code="FR").save()
+
+    ct1 = await CityOfResidence(name="Athens", country=c1).save()
+    ct2 = await CityOfResidence(name="Paris", country=c2).save()
+
+    p1 = await PersonOfInterest(name="Bill", age=22).save()
+    await p1.country.connect(c1)
+    await p1.city.connect(ct1)
+
+    p2 = await PersonOfInterest(name="Jean", age=28).save()
+    await p2.country.connect(c2)
+    await p2.city.connect(ct2)
+
+    p3 = await PersonOfInterest(name="Bo", age=32).save()
+    await p3.country.connect(c1)
+    await p3.city.connect(ct2)
+
+    p4 = await PersonOfInterest(name="Drop", age=16).save()
+    await p4.country.connect(c1)
+    await p4.city.connect(ct2)
+
+    # Retrieve a single path
+    q = await adb.cypher_query(
+        "MATCH p=(:CityOfResidence)<-[:LIVES_IN]-(:PersonOfInterest)-[:IS_FROM]->(:CountryOfOrigin) RETURN p LIMIT 1",
+        resolve_objects=True,
+    )
+
+    # q is (results, meta); results[0][0] is the single returned path
+    path_object = q[0][0][0]
+    path_nodes = path_object.nodes
+    path_rels = path_object.relationships
+
+    assert type(path_object) is AsyncNeomodelPath
+    assert type(path_nodes[0]) is CityOfResidence
+    assert type(path_nodes[1]) is PersonOfInterest
+    assert type(path_nodes[2]) is CountryOfOrigin
+
+    # LIVES_IN has a model so it resolves to PersonLivesInCity;
+    # IS_FROM has none so it falls back to the generic AsyncStructuredRel
+    assert type(path_rels[0]) is PersonLivesInCity
+    assert type(path_rels[1]) is AsyncStructuredRel
+
+    await c1.delete()
+    await c2.delete()
+    await ct1.delete()
+    await ct2.delete()
+    await p1.delete()
+    await p2.delete()
+    await p3.delete()
+    await p4.delete()
diff --git a/test/async_/test_properties.py b/test/async_/test_properties.py
new file mode 100644
index 00000000..0679cf89
--- /dev/null
+++ b/test/async_/test_properties.py
@@ -0,0 +1,453 @@
+from datetime import date, datetime
+from test._async_compat import mark_async_test
+
+from pytest import mark, raises
+from pytz import timezone
+
+from neomodel import AsyncStructuredNode, adb
+from neomodel.exceptions import (
+ DeflateError,
+ InflateError,
+ RequiredProperty,
+ UniqueProperty,
+)
+from neomodel.properties import (
+ ArrayProperty,
+ DateProperty,
+ DateTimeFormatProperty,
+ DateTimeProperty,
+ EmailProperty,
+ IntegerProperty,
+ JSONProperty,
+ NormalizedProperty,
+ RegexProperty,
+ StringProperty,
+ UniqueIdProperty,
+)
+from neomodel.util import _get_node_properties
+
+
+# Dummy owner class assigned to stand-alone Property instances in the
+# tests below (properties normally belong to a StructuredNode subclass).
+class FooBar:
+    pass
+
+
+def test_string_property_exceeds_max_length():
+    """
+    StringProperty is defined by two properties: `max_length` and `choices` that are mutually exclusive. Furthermore,
+    max_length must be a positive non-zero number.
+    """
+    # Try to define a property that has both choices and max_length
+    with raises(ValueError):
+        some_string_property = StringProperty(
+            choices={"One": "1", "Two": "2"}, max_length=22
+        )
+
+    # Try to define a string property that has a negative zero length
+    with raises(ValueError):
+        another_string_property = StringProperty(max_length=-35)
+
+    # Try to validate a long string
+    a_string_property = StringProperty(required=True, max_length=5)
+    with raises(ValueError):
+        a_string_property.normalize("The quick brown fox jumps over the lazy dog")
+
+    # Try to validate a "valid" string, as per the max_length setting.
+    valid_string = "Owen"
+    normalised_string = a_string_property.normalize(valid_string)
+    assert (
+        valid_string == normalised_string
+    ), "StringProperty max_length test passed but values do not match."
+
+
+@mark_async_test
+async def test_string_property_w_choice():
+    class TestChoices(AsyncStructuredNode):
+        SEXES = {"F": "Female", "M": "Male", "O": "Other"}
+        sex = StringProperty(required=True, choices=SEXES)
+
+    # a value outside the declared choices fails at deflate time
+    try:
+        await TestChoices(sex="Z").save()
+    except DeflateError as e:
+        assert "choice" in str(e)
+    else:
+        assert False, "DeflateError not raised."
+
+    node = await TestChoices(sex="M").save()
+    assert node.get_sex_display() == "Male"
+
+
+def test_deflate_inflate():
+    prop = IntegerProperty(required=True)
+    prop.name = "age"
+    prop.owner = FooBar
+
+    try:
+        prop.inflate("six")
+    except InflateError as e:
+        assert "inflate property" in str(e)
+    else:
+        assert False, "DeflateError not raised."
+
+    try:
+        prop.deflate("six")
+    except DeflateError as e:
+        assert "deflate property" in str(e)
+    else:
+        assert False, "DeflateError not raised."
+
+
+def test_datetimes_timezones():
+    # a deflate/inflate round trip must normalise aware datetimes to UTC
+    prop = DateTimeProperty()
+    prop.name = "foo"
+    prop.owner = FooBar
+    t = datetime.utcnow()
+    gr = timezone("Europe/Athens")
+    gb = timezone("Europe/London")
+    dt1 = gr.localize(t)
+    dt2 = gb.localize(t)
+    time1 = prop.inflate(prop.deflate(dt1))
+    time2 = prop.inflate(prop.deflate(dt2))
+    assert time1.utctimetuple() == dt1.utctimetuple()
+    assert time1.utctimetuple() < time2.utctimetuple()
+    assert time1.tzname() == "UTC"
+
+
+def test_date():
+    prop = DateProperty()
+    prop.name = "foo"
+    prop.owner = FooBar
+    somedate = date(2012, 12, 15)
+    assert prop.deflate(somedate) == "2012-12-15"
+    assert prop.inflate("2012-12-15") == somedate
+
+
+def test_datetime_format():
+    some_format = "%Y-%m-%d %H:%M:%S"
+    prop = DateTimeFormatProperty(format=some_format)
+    prop.name = "foo"
+    prop.owner = FooBar
+    some_datetime = datetime(2019, 3, 19, 15, 36, 25)
+    assert prop.deflate(some_datetime) == "2019-03-19 15:36:25"
+    assert prop.inflate("2019-03-19 15:36:25") == some_datetime
+
+
+def test_datetime_exceptions():
+    prop = DateTimeProperty()
+    prop.name = "created"
+    prop.owner = FooBar
+    faulty = "dgdsg"
+
+    try:
+        prop.inflate(faulty)
+    except InflateError as e:
+        assert "inflate property" in str(e)
+    else:
+        assert False, "InflateError not raised."
+
+    try:
+        prop.deflate(faulty)
+    except DeflateError as e:
+        assert "deflate property" in str(e)
+    else:
+        assert False, "DeflateError not raised."
+
+
+def test_date_exceptions():
+    prop = DateProperty()
+    prop.name = "date"
+    prop.owner = FooBar
+    faulty = "2012-14-13"
+
+    try:
+        prop.inflate(faulty)
+    except InflateError as e:
+        assert "inflate property" in str(e)
+    else:
+        assert False, "InflateError not raised."
+
+    try:
+        prop.deflate(faulty)
+    except DeflateError as e:
+        assert "deflate property" in str(e)
+    else:
+        assert False, "DeflateError not raised."
+
+
+def test_json():
+    prop = JSONProperty()
+    prop.name = "json"
+    prop.owner = FooBar
+
+    value = {"test": [1, 2, 3]}
+
+    assert prop.deflate(value) == '{"test": [1, 2, 3]}'
+    assert prop.inflate('{"test": [1, 2, 3]}') == value
+
+
+@mark_async_test
+async def test_default_value():
+    class DefaultTestValue(AsyncStructuredNode):
+        name_xx = StringProperty(default="jim", index=True)
+
+    a = DefaultTestValue()
+    assert a.name_xx == "jim"
+    await a.save()
+
+
+@mark_async_test
+async def test_default_value_callable():
+    def uid_generator():
+        return "xx"
+
+    class DefaultTestValueTwo(AsyncStructuredNode):
+        uid = StringProperty(default=uid_generator, index=True)
+
+    a = await DefaultTestValueTwo().save()
+    assert a.uid == "xx"
+
+
+@mark_async_test
+async def test_default_value_callable_type():
+    # check our object gets converted to str without serializing and reload
+    def factory():
+        class Foo:
+            def __str__(self):
+                return "123"
+
+        return Foo()
+
+    class DefaultTestValueThree(AsyncStructuredNode):
+        uid = StringProperty(default=factory, index=True)
+
+    x = DefaultTestValueThree()
+    assert x.uid == "123"
+    await x.save()
+    assert x.uid == "123"
+    await x.refresh()
+    assert x.uid == "123"
+
+
+@mark_async_test
+async def test_independent_property_name():
+    # db_property maps the Python attribute "name_" to DB property "name"
+    class TestDBNamePropertyNode(AsyncStructuredNode):
+        name_ = StringProperty(db_property="name")
+
+    x = TestDBNamePropertyNode()
+    x.name_ = "jim"
+    await x.save()
+
+    # check database property name on low level
+    results, meta = await adb.cypher_query("MATCH (n:TestDBNamePropertyNode) RETURN n")
+    node_properties = _get_node_properties(results[0][0])
+    assert node_properties["name"] == "jim"
+
+    node_properties = _get_node_properties(results[0][0])
+    assert not "name_" in node_properties
+    assert not hasattr(x, "name")
+    assert hasattr(x, "name_")
+    assert (await TestDBNamePropertyNode.nodes.filter(name_="jim").all())[
+        0
+    ].name_ == x.name_
+    assert (await TestDBNamePropertyNode.nodes.get(name_="jim")).name_ == x.name_
+
+    await x.delete()
+
+
+@mark_async_test
+async def test_independent_property_name_get_or_create():
+    class TestNode(AsyncStructuredNode):
+        uid = UniqueIdProperty()
+        name_ = StringProperty(db_property="name", required=True)
+
+    # create the node
+    await TestNode.get_or_create({"uid": 123, "name_": "jim"})
+    # test that the node is retrieved correctly
+    x = (await TestNode.get_or_create({"uid": 123, "name_": "jim"}))[0]
+
+    # check database property name on low level
+    results, _ = await adb.cypher_query("MATCH (n:TestNode) RETURN n")
+    node_properties = _get_node_properties(results[0][0])
+    assert node_properties["name"] == "jim"
+    assert "name_" not in node_properties
+
+    # delete node afterwards
+    await x.delete()
+
+
+@mark.parametrize("normalized_class", (NormalizedProperty,))
+def test_normalized_property(normalized_class):
+    # normalize() must be applied on inflate, deflate and default_value
+    class TestProperty(normalized_class):
+        def normalize(self, value):
+            self._called_with = value
+            self._called = True
+            return value + "bar"
+
+    inflate = TestProperty()
+    inflate_res = inflate.inflate("foo")
+    assert getattr(inflate, "_called", False)
+    assert getattr(inflate, "_called_with", None) == "foo"
+    assert inflate_res == "foobar"
+
+    deflate = TestProperty()
+    deflate_res = deflate.deflate("bar")
+    assert getattr(deflate, "_called", False)
+    assert getattr(deflate, "_called_with", None) == "bar"
+    assert deflate_res == "barbar"
+
+    default = TestProperty(default="qux")
+    default_res = default.default_value()
+    assert getattr(default, "_called", False)
+    assert getattr(default, "_called_with", None) == "qux"
+    assert default_res == "quxbar"
+
+
+def test_regex_property():
+    # a RegexProperty subclass must declare "expression"
+    class MissingExpression(RegexProperty):
+        pass
+
+    with raises(ValueError):
+        MissingExpression()
+
+    class TestProperty(RegexProperty):
+        name = "test"
+        owner = object()
+        expression = r"\w+ \w+$"
+
+        def normalize(self, value):
+            self._called = True
+            return super().normalize(value)
+
+    prop = TestProperty()
+    result = prop.inflate("foo bar")
+    assert getattr(prop, "_called", False)
+    assert result == "foo bar"
+
+    with raises(DeflateError):
+        prop.deflate("qux")
+
+
+def test_email_property():
+    prop = EmailProperty()
+    prop.name = "email"
+    prop.owner = object()
+    result = prop.inflate("foo@example.com")
+    assert result == "foo@example.com"
+
+    with raises(DeflateError):
+        prop.deflate("foo@example")
+
+
+@mark_async_test
+async def test_uid_property():
+    prop = UniqueIdProperty()
+    prop.name = "uid"
+    prop.owner = object()
+    myuid = prop.default_value()
+    assert len(myuid)
+
+    class CheckMyId(AsyncStructuredNode):
+        uid = UniqueIdProperty()
+
+    cmid = await CheckMyId().save()
+    assert len(cmid.uid)
+
+
+class ArrayProps(AsyncStructuredNode):
+    uid = StringProperty(unique_index=True)
+    # untyped: elements are stored as-is; typed: elements deflated as ints
+    untyped_arr = ArrayProperty()
+    typed_arr = ArrayProperty(IntegerProperty())
+
+
+@mark_async_test
+async def test_array_properties():
+    # untyped
+    ap1 = await ArrayProps(uid="1", untyped_arr=["Tim", "Bob"]).save()
+    assert "Tim" in ap1.untyped_arr
+    ap1 = await ArrayProps.nodes.get(uid="1")
+    assert "Tim" in ap1.untyped_arr
+
+    # typed
+    try:
+        await ArrayProps(uid="2", typed_arr=["a", "b"]).save()
+    except DeflateError as e:
+        assert "unsaved node" in str(e)
+    else:
+        assert False, "DeflateError not raised."
+
+    ap2 = await ArrayProps(uid="2", typed_arr=[1, 2]).save()
+    assert 1 in ap2.typed_arr
+    ap2 = await ArrayProps.nodes.get(uid="2")
+    assert 2 in ap2.typed_arr
+
+
+def test_illegal_array_base_prop_raises():
+    # the base property of an ArrayProperty may not itself be indexed
+    with raises(ValueError):
+        ArrayProperty(StringProperty(index=True))
+
+
+@mark_async_test
+async def test_indexed_array():
+    class IndexArray(AsyncStructuredNode):
+        ai = ArrayProperty(unique_index=True)
+
+    b = await IndexArray(ai=[1, 2]).save()
+    c = await IndexArray.nodes.get(ai=[1, 2])
+    assert b.element_id == c.element_id
+
+
+@mark_async_test
+async def test_unique_index_prop_not_required():
+    class ConstrainedTestNode(AsyncStructuredNode):
+        required_property = StringProperty(required=True)
+        unique_property = StringProperty(unique_index=True)
+        unique_required_property = StringProperty(unique_index=True, required=True)
+        unconstrained_property = StringProperty()
+
+    # Create a node with a missing required property
+    with raises(RequiredProperty):
+        x = ConstrainedTestNode(required_property="required", unique_property="unique")
+        await x.save()
+
+    # Create a node with a missing unique (but not required) property.
+    x = ConstrainedTestNode()
+    x.required_property = "required"
+    x.unique_required_property = "unique and required"
+    x.unconstrained_property = "no contraints"
+    await x.save()
+
+    # check database property name on low level
+    results, meta = await adb.cypher_query("MATCH (n:ConstrainedTestNode) RETURN n")
+    node_properties = _get_node_properties(results[0][0])
+    assert node_properties["unique_required_property"] == "unique and required"
+
+    # delete node afterwards
+    await x.delete()
+
+
+@mark_async_test
+async def test_unique_index_prop_enforced():
+    class UniqueNullableNameNode(AsyncStructuredNode):
+        name = StringProperty(unique_index=True)
+
+    await adb.install_labels(UniqueNullableNameNode)
+    # Nameless
+    x = UniqueNullableNameNode()
+    await x.save()
+    y = UniqueNullableNameNode()
+    await y.save()
+
+    # Named
+    z = UniqueNullableNameNode(name="named")
+    await z.save()
+    with raises(UniqueProperty):
+        a = UniqueNullableNameNode(name="named")
+        await a.save()
+
+    # Check nodes are in database
+    results, _ = await adb.cypher_query("MATCH (n:UniqueNullableNameNode) RETURN n")
+    assert len(results) == 3
+
+    # Delete nodes afterwards
+    await x.delete()
+    await y.delete()
+    await z.delete()
diff --git a/test/async_/test_relationship_models.py b/test/async_/test_relationship_models.py
new file mode 100644
index 00000000..95fe4714
--- /dev/null
+++ b/test/async_/test_relationship_models.py
@@ -0,0 +1,171 @@
+from datetime import datetime
+from test._async_compat import mark_async_test
+
+import pytz
+from pytest import raises
+
+from neomodel import (
+ AsyncRelationship,
+ AsyncRelationshipTo,
+ AsyncStructuredNode,
+ AsyncStructuredRel,
+ DateTimeProperty,
+ DeflateError,
+ StringProperty,
+)
+from neomodel._async_compat.util import AsyncUtil
+
+# Module-level counters incremented by HatesRel's save hooks; reset and
+# checked in test_save_hook_on_rel_model below.
+HOOKS_CALLED = {"pre_save": 0, "post_save": 0}
+
+
+class FriendRel(AsyncStructuredRel):
+    since = DateTimeProperty(default=lambda: datetime.now(pytz.utc))
+
+
+class HatesRel(FriendRel):
+    reason = StringProperty()
+
+    def pre_save(self):
+        HOOKS_CALLED["pre_save"] += 1
+
+    def post_save(self):
+        HOOKS_CALLED["post_save"] += 1
+
+
+class Badger(AsyncStructuredNode):
+    name = StringProperty(unique_index=True)
+    friend = AsyncRelationship("Badger", "FRIEND", model=FriendRel)
+    hates = AsyncRelationshipTo("Stoat", "HATES", model=HatesRel)
+
+
+class Stoat(AsyncStructuredNode):
+    name = StringProperty(unique_index=True)
+    hates = AsyncRelationshipTo("Badger", "HATES", model=HatesRel)
+
+
+@mark_async_test
+async def test_either_connect_with_rel_model():
+    paul = await Badger(name="Paul").save()
+    tom = await Badger(name="Tom").save()
+
+    # creating rels
+    new_rel = await tom.friend.disconnect(paul)
+    new_rel = await tom.friend.connect(paul)
+    assert isinstance(new_rel, FriendRel)
+    assert isinstance(new_rel.since, datetime)
+
+    # updating properties
+    new_rel.since = datetime.now(pytz.utc)
+    assert isinstance(await new_rel.save(), FriendRel)
+
+    # start and end nodes are the opposite of what you'd expect when using either..
+    # I've tried everything possible to correct this to no avail
+    paul = await new_rel.start_node()
+    tom = await new_rel.end_node()
+    assert paul.name == "Tom"
+    assert tom.name == "Paul"
+
+
+@mark_async_test
+async def test_direction_connect_with_rel_model():
+    paul = await Badger(name="Paul the badger").save()
+    ian = await Stoat(name="Ian the stoat").save()
+
+    rel = await ian.hates.connect(
+        paul, {"reason": "thinks paul should bath more often"}
+    )
+    assert isinstance(rel.since, datetime)
+    assert isinstance(rel, FriendRel)
+    assert rel.reason.startswith("thinks")
+    rel.reason = "he smells"
+    await rel.save()
+
+    # with a directed relationship, start/end nodes match the direction
+    ian = await rel.start_node()
+    assert isinstance(ian, Stoat)
+    paul = await rel.end_node()
+    assert isinstance(paul, Badger)
+
+    assert ian.name.startswith("Ian")
+    assert paul.name.startswith("Paul")
+
+    rel = await ian.hates.relationship(paul)
+    assert isinstance(rel, HatesRel)
+    assert isinstance(rel.since, datetime)
+    await rel.save()
+
+    # test deflate checking
+    rel.since = "2:30pm"
+    with raises(DeflateError):
+        await rel.save()
+
+    # check deflate check via connect
+    with raises(DeflateError):
+        await paul.hates.connect(
+            ian,
+            {
+                "reason": "thinks paul should bath more often",
+                "since": "2:30pm",
+            },
+        )
+
+
+@mark_async_test
+async def test_traversal_where_clause():
+    phill = await Badger(name="Phill the badger").save()
+    tim = await Badger(name="Tim the badger").save()
+    bob = await Badger(name="Bob the badger").save()
+    rel = await tim.friend.connect(bob)
+    now = datetime.now(pytz.utc)
+    assert rel.since < now
+    rel2 = await tim.friend.connect(phill)
+    assert rel2.since > now
+    # match() filters traversals on relationship properties
+    friends = tim.friend.match(since__gt=now)
+    assert len(await friends.all()) == 1
+
+
+@mark_async_test
+async def test_multiple_rels_exist_issue_223():
+    # check a badger can dislike a stoat for multiple reasons
+    phill = await Badger(name="Phill").save()
+    ian = await Stoat(name="Stoat").save()
+
+    rel_a = await phill.hates.connect(ian, {"reason": "a"})
+    rel_b = await phill.hates.connect(ian, {"reason": "b"})
+    assert rel_a.element_id != rel_b.element_id
+
+    if AsyncUtil.is_async_code:
+        ian_a = (await phill.hates.match(reason="a"))[0]
+        ian_b = (await phill.hates.match(reason="b"))[0]
+    else:
+        ian_a = phill.hates.match(reason="a")[0]
+        ian_b = phill.hates.match(reason="b")[0]
+    assert ian_a.element_id == ian_b.element_id
+
+
+@mark_async_test
+async def test_retrieve_all_rels():
+    tom = await Badger(name="tom").save()
+    ian = await Stoat(name="ian").save()
+
+    rel_a = await tom.hates.connect(ian, {"reason": "a"})
+    rel_b = await tom.hates.connect(ian, {"reason": "b"})
+
+    rels = await tom.hates.all_relationships(ian)
+    assert len(rels) == 2
+    assert rels[0].element_id in [rel_a.element_id, rel_b.element_id]
+    assert rels[1].element_id in [rel_a.element_id, rel_b.element_id]
+
+
+@mark_async_test
+async def test_save_hook_on_rel_model():
+    HOOKS_CALLED["pre_save"] = 0
+    HOOKS_CALLED["post_save"] = 0
+
+    paul = await Badger(name="PaulB").save()
+    ian = await Stoat(name="IanS").save()
+
+    # hooks fire once on connect and once on the explicit save = 2 each
+    rel = await ian.hates.connect(paul, {"reason": "yadda yadda"})
+    await rel.save()
+
+    assert HOOKS_CALLED["pre_save"] == 2
+    assert HOOKS_CALLED["post_save"] == 2
diff --git a/test/async_/test_relationships.py b/test/async_/test_relationships.py
new file mode 100644
index 00000000..9835e8b7
--- /dev/null
+++ b/test/async_/test_relationships.py
@@ -0,0 +1,209 @@
+from test._async_compat import mark_async_test
+
+from pytest import raises
+
+from neomodel import (
+ AsyncOne,
+ AsyncRelationship,
+ AsyncRelationshipFrom,
+ AsyncRelationshipTo,
+ AsyncStructuredNode,
+ AsyncStructuredRel,
+ IntegerProperty,
+ Q,
+ StringProperty,
+ adb,
+)
+
+
+class PersonWithRels(AsyncStructuredNode):
+ name = StringProperty(unique_index=True)
+ age = IntegerProperty(index=True)
+ is_from = AsyncRelationshipTo("Country", "IS_FROM")
+ knows = AsyncRelationship("PersonWithRels", "KNOWS")
+
+ @property
+ def special_name(self):
+ return self.name
+
+ def special_power(self):
+ return "I have no powers"
+
+
+class Country(AsyncStructuredNode):
+ code = StringProperty(unique_index=True)
+ inhabitant = AsyncRelationshipFrom(PersonWithRels, "IS_FROM")
+ president = AsyncRelationshipTo(PersonWithRels, "PRESIDENT", cardinality=AsyncOne)
+
+
+class SuperHero(PersonWithRels):
+ power = StringProperty(index=True)
+
+ def special_power(self):
+ return "I have powers"
+
+
+@mark_async_test
+async def test_actions_on_deleted_node():
+ u = await PersonWithRels(name="Jim2", age=3).save()
+ await u.delete()
+ with raises(ValueError):
+ await u.is_from.connect(None)
+
+ with raises(ValueError):
+ await u.is_from.get()
+
+ with raises(ValueError):
+ await u.save()
+
+
+@mark_async_test
+async def test_bidirectional_relationships():
+ u = await PersonWithRels(name="Jim", age=3).save()
+ assert u
+
+ de = await Country(code="DE").save()
+ assert de
+
+ assert not await u.is_from.all()
+
+ assert u.is_from.__class__.__name__ == "AsyncZeroOrMore"
+ await u.is_from.connect(de)
+
+ assert len(await u.is_from.all()) == 1
+
+ assert await u.is_from.is_connected(de)
+
+ b = (await u.is_from.all())[0]
+ assert b.__class__.__name__ == "Country"
+ assert b.code == "DE"
+
+ s = (await b.inhabitant.all())[0]
+ assert s.name == "Jim"
+
+ await u.is_from.disconnect(b)
+ assert not await u.is_from.is_connected(b)
+
+
+@mark_async_test
+async def test_either_direction_connect():
+ rey = await PersonWithRels(name="Rey", age=3).save()
+ sakis = await PersonWithRels(name="Sakis", age=3).save()
+
+ await rey.knows.connect(sakis)
+ assert await rey.knows.is_connected(sakis)
+ assert await sakis.knows.is_connected(rey)
+ await sakis.knows.connect(rey)
+
+ result, _ = await sakis.cypher(
+ f"""MATCH (us), (them)
+ WHERE {await adb.get_id_method()}(us)=$self and {await adb.get_id_method()}(them)=$them
+ MATCH (us)-[r:KNOWS]-(them) RETURN COUNT(r)""",
+ {"them": await adb.parse_element_id(rey.element_id)},
+ )
+ assert int(result[0][0]) == 1
+
+ rel = await rey.knows.relationship(sakis)
+ assert isinstance(rel, AsyncStructuredRel)
+
+ rels = await rey.knows.all_relationships(sakis)
+ assert isinstance(rels[0], AsyncStructuredRel)
+
+
+@mark_async_test
+async def test_search_and_filter_and_exclude():
+ fred = await PersonWithRels(name="Fred", age=13).save()
+ zz = await Country(code="ZZ").save()
+ zx = await Country(code="ZX").save()
+ zt = await Country(code="ZY").save()
+ await fred.is_from.connect(zz)
+ await fred.is_from.connect(zx)
+ await fred.is_from.connect(zt)
+ result = await fred.is_from.filter(code="ZX")
+ assert result[0].code == "ZX"
+
+ result = await fred.is_from.filter(code="ZY")
+ assert result[0].code == "ZY"
+
+ result = await fred.is_from.exclude(code="ZZ").exclude(code="ZY")
+ assert result[0].code == "ZX" and len(result) == 1
+
+ result = await fred.is_from.exclude(Q(code__contains="Y"))
+ assert len(result) == 2
+
+ result = await fred.is_from.filter(Q(code__contains="Z"))
+ assert len(result) == 3
+
+
+@mark_async_test
+async def test_custom_methods():
+ u = await PersonWithRels(name="Joe90", age=13).save()
+ assert u.special_power() == "I have no powers"
+ u = await SuperHero(name="Joe91", age=13, power="xxx").save()
+ assert u.special_power() == "I have powers"
+ assert u.special_name == "Joe91"
+
+
+@mark_async_test
+async def test_valid_reconnection():
+ p = await PersonWithRels(name="ElPresidente", age=93).save()
+ assert p
+
+ pp = await PersonWithRels(name="TheAdversary", age=33).save()
+ assert pp
+
+ c = await Country(code="CU").save()
+ assert c
+
+ await c.president.connect(p)
+ assert await c.president.is_connected(p)
+
+ # the coup d'etat
+ await c.president.reconnect(p, pp)
+ assert await c.president.is_connected(pp)
+
+ # reelection time
+ await c.president.reconnect(pp, pp)
+ assert await c.president.is_connected(pp)
+
+
+@mark_async_test
+async def test_valid_replace():
+ brady = await PersonWithRels(name="Tom Brady", age=40).save()
+ assert brady
+
+ gronk = await PersonWithRels(name="Rob Gronkowski", age=28).save()
+ assert gronk
+
+ colbert = await PersonWithRels(name="Stephen Colbert", age=53).save()
+ assert colbert
+
+ hanks = await PersonWithRels(name="Tom Hanks", age=61).save()
+ assert hanks
+
+ await brady.knows.connect(gronk)
+ await brady.knows.connect(colbert)
+ assert len(await brady.knows.all()) == 2
+ assert await brady.knows.is_connected(gronk)
+ assert await brady.knows.is_connected(colbert)
+
+ await brady.knows.replace(hanks)
+ assert len(await brady.knows.all()) == 1
+ assert await brady.knows.is_connected(hanks)
+ assert not await brady.knows.is_connected(gronk)
+ assert not await brady.knows.is_connected(colbert)
+
+
+@mark_async_test
+async def test_props_relationship():
+ u = await PersonWithRels(name="Mar", age=20).save()
+ assert u
+
+ c = await Country(code="AT").save()
+ assert c
+
+ c2 = await Country(code="LA").save()
+ assert c2
+
+ with raises(NotImplementedError):
+ await c.inhabitant.connect(u, properties={"city": "Thessaloniki"})
diff --git a/test/async_/test_relative_relationships.py b/test/async_/test_relative_relationships.py
new file mode 100644
index 00000000..7b283f84
--- /dev/null
+++ b/test/async_/test_relative_relationships.py
@@ -0,0 +1,24 @@
+from test._async_compat import mark_async_test
+from test.async_.test_relationships import Country
+
+from neomodel import AsyncRelationshipTo, AsyncStructuredNode, StringProperty
+
+
+class Cat(AsyncStructuredNode):
+ name = StringProperty()
+ # Relationship is defined using a relative class path
+ is_from = AsyncRelationshipTo(".test_relationships.Country", "IS_FROM")
+
+
+@mark_async_test
+async def test_relative_relationship():
+ a = await Cat(name="snufkin").save()
+ assert a
+
+ c = await Country(code="MG").save()
+ assert c
+
+ # connecting an instance of the class defined above
+ # the next statement will fail if there's a type mismatch
+ await a.is_from.connect(c)
+ assert await a.is_from.is_connected(c)
diff --git a/test/async_/test_transactions.py b/test/async_/test_transactions.py
new file mode 100644
index 00000000..59d523c5
--- /dev/null
+++ b/test/async_/test_transactions.py
@@ -0,0 +1,193 @@
+from test._async_compat import mark_async_test
+
+import pytest
+from neo4j.api import Bookmarks
+from neo4j.exceptions import ClientError, TransactionError
+from pytest import raises
+
+from neomodel import AsyncStructuredNode, StringProperty, UniqueProperty, adb
+
+
+class APerson(AsyncStructuredNode):
+ name = StringProperty(unique_index=True)
+
+
+@mark_async_test
+async def test_rollback_and_commit_transaction():
+ for p in await APerson.nodes:
+ await p.delete()
+
+ await APerson(name="Roger").save()
+
+ await adb.begin()
+ await APerson(name="Terry S").save()
+ await adb.rollback()
+
+ assert len(await APerson.nodes) == 1
+
+ await adb.begin()
+ await APerson(name="Terry S").save()
+ await adb.commit()
+
+ assert len(await APerson.nodes) == 2
+
+
+@adb.transaction
+async def in_a_tx(*names):
+ for n in names:
+ await APerson(name=n).save()
+
+
+@mark_async_test
+async def test_transaction_decorator():
+ await adb.install_labels(APerson)
+ for p in await APerson.nodes:
+ await p.delete()
+
+ # should work
+ await in_a_tx("Roger")
+
+ # should bail but raise correct error
+ with raises(UniqueProperty):
+ await in_a_tx("Jim", "Roger")
+
+ assert "Jim" not in [p.name for p in await APerson.nodes]
+
+
+@mark_async_test
+async def test_transaction_as_a_context():
+ async with adb.transaction:
+ await APerson(name="Tim").save()
+
+ assert await APerson.nodes.filter(name="Tim")
+
+ with raises(UniqueProperty):
+ async with adb.transaction:
+ await APerson(name="Tim").save()
+
+
+@mark_async_test
+async def test_query_inside_transaction():
+ for p in await APerson.nodes:
+ await p.delete()
+
+ async with adb.transaction:
+ await APerson(name="Alice").save()
+ await APerson(name="Bob").save()
+
+ assert len([p.name for p in await APerson.nodes]) == 2
+
+
+@mark_async_test
+async def test_read_transaction():
+ await APerson(name="Johnny").save()
+
+ async with adb.read_transaction:
+ people = await APerson.nodes
+ assert people
+
+ with raises(TransactionError):
+ async with adb.read_transaction:
+ with raises(ClientError) as e:
+ await APerson(name="Gina").save()
+ assert e.value.code == "Neo.ClientError.Statement.AccessMode"
+
+
+@mark_async_test
+async def test_write_transaction():
+ async with adb.write_transaction:
+ await APerson(name="Amelia").save()
+
+ amelia = await APerson.nodes.get(name="Amelia")
+ assert amelia
+
+
+@mark_async_test
+async def test_double_transaction():
+ await adb.begin()
+ with raises(SystemError, match=r"Transaction in progress"):
+ await adb.begin()
+
+ await adb.rollback()
+
+
+@adb.transaction.with_bookmark
+async def in_a_tx_with_bookmark(*names):
+ for n in names:
+ await APerson(name=n).save()
+
+
+@mark_async_test
+async def test_bookmark_transaction_decorator():
+ for p in await APerson.nodes:
+ await p.delete()
+
+ # should work
+ result, bookmarks = await in_a_tx_with_bookmark("Ruth", bookmarks=None)
+ assert result is None
+ assert isinstance(bookmarks, Bookmarks)
+
+ # should bail but raise correct error
+ with raises(UniqueProperty):
+ await in_a_tx_with_bookmark("Jane", "Ruth")
+
+ assert "Jane" not in [p.name for p in await APerson.nodes]
+
+
+@mark_async_test
+async def test_bookmark_transaction_as_a_context():
+ async with adb.transaction as transaction:
+ await APerson(name="Tanya").save()
+ assert isinstance(transaction.last_bookmark, Bookmarks)
+
+ assert await APerson.nodes.filter(name="Tanya")
+
+ with raises(UniqueProperty):
+ async with adb.transaction as transaction:
+ await APerson(name="Tanya").save()
+ assert not hasattr(transaction, "last_bookmark")
+
+
+@pytest.fixture
+def spy_on_db_begin(monkeypatch):
+ spy_calls = []
+ original_begin = adb.begin
+
+ def begin_spy(*args, **kwargs):
+ spy_calls.append((args, kwargs))
+ return original_begin(*args, **kwargs)
+
+ monkeypatch.setattr(adb, "begin", begin_spy)
+ return spy_calls
+
+
+@mark_async_test
+async def test_bookmark_passed_in_to_context(spy_on_db_begin):
+ transaction = adb.transaction
+ async with transaction:
+ pass
+
+ assert (spy_on_db_begin)[-1] == ((), {"access_mode": None, "bookmarks": None})
+ last_bookmark = transaction.last_bookmark
+
+ transaction.bookmarks = last_bookmark
+ async with transaction:
+ pass
+ assert spy_on_db_begin[-1] == (
+ (),
+ {"access_mode": None, "bookmarks": last_bookmark},
+ )
+
+
+@mark_async_test
+async def test_query_inside_bookmark_transaction():
+ for p in await APerson.nodes:
+ await p.delete()
+
+ async with adb.transaction as transaction:
+ await APerson(name="Alice").save()
+ await APerson(name="Bob").save()
+
+ assert len([p.name for p in await APerson.nodes]) == 2
+
+ assert isinstance(transaction.last_bookmark, Bookmarks)
diff --git a/test/conftest.py b/test/conftest.py
index 7ec261d6..291eedf5 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -1,11 +1,9 @@
from __future__ import print_function
import os
-import warnings
import pytest
-from neomodel import clear_neo4j_database, config, db
from neomodel.util import version_tag_to_integer
NEO4J_URL = os.environ.get("NEO4J_URL", "bolt://localhost:7687")
@@ -30,64 +28,33 @@ def pytest_addoption(parser):
@pytest.hookimpl
def pytest_collection_modifyitems(items):
- connect_to_aura_items = []
- normal_items = []
+ async_items = []
+ sync_items = []
+ async_connect_to_aura_items = []
+ sync_connect_to_aura_items = []
- # Separate all tests into two groups: those with "connect_to_aura" in their name, and all others
for item in items:
+ # Check the directory of the item
+ directory = item.fspath.dirname.split("/")[-1]
+
if "connect_to_aura" in item.name:
- connect_to_aura_items.append(item)
+ if directory == "async_":
+ async_connect_to_aura_items.append(item)
+ elif directory == "sync_":
+ sync_connect_to_aura_items.append(item)
else:
- normal_items.append(item)
-
- # Add all normal tests back to the front of the list
- new_order = normal_items
-
- # Add all connect_to_aura tests to the end of the list
- new_order.extend(connect_to_aura_items)
-
- # Replace the original items list with the new order
- items[:] = new_order
-
-
-@pytest.hookimpl
-def pytest_sessionstart(session):
- """
- Provides initial connection to the database and sets up the rest of the test suite
-
- :param session: The session object. Please see `_
- :type Session object: For more information please see `_
- """
-
- warnings.simplefilter("default")
-
- config.DATABASE_URL = os.environ.get(
- "NEO4J_BOLT_URL", "bolt://neo4j:foobarbaz@localhost:7687"
- )
- config.AUTO_INSTALL_LABELS = True
-
- # Clear the database if required
- database_is_populated, _ = db.cypher_query(
- "MATCH (a) return count(a)>0 as database_is_populated"
- )
- if database_is_populated[0][0] and not session.config.getoption("resetdb"):
- raise SystemError(
- "Please note: The database seems to be populated.\n\tEither delete all nodes and edges manually, or set the --resetdb parameter when calling pytest\n\n\tpytest --resetdb."
- )
- else:
- clear_neo4j_database(db, clear_constraints=True, clear_indexes=True)
-
- db.cypher_query(
- "CREATE OR REPLACE USER troygreene SET PASSWORD 'foobarbaz' CHANGE NOT REQUIRED"
+ if directory == "async_":
+ async_items.append(item)
+ elif directory == "sync_":
+ sync_items.append(item)
+
+ new_order = (
+ async_items
+ + async_connect_to_aura_items
+ + sync_items
+ + sync_connect_to_aura_items
)
- if db.database_edition == "enterprise":
- db.cypher_query("GRANT ROLE publisher TO troygreene")
- db.cypher_query("GRANT IMPERSONATE (troygreene) ON DBMS TO admin")
-
-
-@pytest.hookimpl
-def pytest_unconfigure():
- db.close_connection()
+ items[:] = new_order
def check_and_skip_neo4j_least_version(required_least_neo4j_version, message):
@@ -112,8 +79,3 @@ def check_and_skip_neo4j_least_version(required_least_neo4j_version, message):
"Neo4j version: {}. {}."
"Skipping test.".format(os.environ["NEO4J_VERSION"], message)
)
-
-
-@pytest.fixture
-def skip_neo4j_before_330():
- check_and_skip_neo4j_least_version(330, "Neo4J version does not support this test")
diff --git a/test/sync_/__init__.py b/test/sync_/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/test/sync_/conftest.py b/test/sync_/conftest.py
new file mode 100644
index 00000000..d2cd787e
--- /dev/null
+++ b/test/sync_/conftest.py
@@ -0,0 +1,48 @@
+import os
+import warnings
+from test._async_compat import mark_sync_session_auto_fixture
+
+from neomodel import config, db
+
+
+@mark_sync_session_auto_fixture
+def setup_neo4j_session(request):
+ """
+ Provides initial connection to the database and sets up the rest of the test suite
+
+ :param request: The request object. Please see `_
+ :type Request object: For more information please see `_
+ """
+
+ warnings.simplefilter("default")
+
+ config.DATABASE_URL = os.environ.get(
+ "NEO4J_BOLT_URL", "bolt://neo4j:foobarbaz@localhost:7687"
+ )
+
+ # Clear the database if required
+ database_is_populated, _ = db.cypher_query(
+ "MATCH (a) return count(a)>0 as database_is_populated"
+ )
+ if database_is_populated[0][0] and not request.config.getoption("resetdb"):
+ raise SystemError(
+ "Please note: The database seems to be populated.\n\tEither delete all nodes and edges manually, or set the --resetdb parameter when calling pytest\n\n\tpytest --resetdb."
+ )
+
+ db.clear_neo4j_database(clear_constraints=True, clear_indexes=True)
+
+ db.install_all_labels()
+
+ db.cypher_query(
+ "CREATE OR REPLACE USER troygreene SET PASSWORD 'foobarbaz' CHANGE NOT REQUIRED"
+ )
+ db_edition = db.database_edition
+ if db_edition == "enterprise":
+ db.cypher_query("GRANT ROLE publisher TO troygreene")
+ db.cypher_query("GRANT IMPERSONATE (troygreene) ON DBMS TO admin")
+
+
+@mark_sync_session_auto_fixture
+def cleanup():
+ yield
+ db.close_connection()
diff --git a/test/test_alias.py b/test/sync_/test_alias.py
similarity index 83%
rename from test/test_alias.py
rename to test/sync_/test_alias.py
index c63119aa..420e62a0 100644
--- a/test/test_alias.py
+++ b/test/sync_/test_alias.py
@@ -1,3 +1,5 @@
+from test._async_compat import mark_sync_test
+
from neomodel import AliasProperty, StringProperty, StructuredNode
@@ -12,12 +14,14 @@ class AliasTestNode(StructuredNode):
long_name = MagicProperty(to="name")
+@mark_sync_test
def test_property_setup_hook():
- tim = AliasTestNode(long_name="tim").save()
+ timmy = AliasTestNode(long_name="timmy").save()
assert AliasTestNode.setup_hook_called
- assert tim.name == "tim"
+ assert timmy.name == "timmy"
+@mark_sync_test
def test_alias():
jim = AliasTestNode(full_name="Jim").save()
assert jim.name == "Jim"
diff --git a/test/test_batch.py b/test/sync_/test_batch.py
similarity index 79%
rename from test/test_batch.py
rename to test/sync_/test_batch.py
index c2d1ec86..80812d31 100644
--- a/test/test_batch.py
+++ b/test/sync_/test_batch.py
@@ -1,3 +1,5 @@
+from test._async_compat import mark_sync_test
+
from pytest import raises
from neomodel import (
@@ -9,6 +11,7 @@
UniqueIdProperty,
config,
)
+from neomodel._async_compat.util import Util
from neomodel.exceptions import DeflateError, UniqueProperty
config.AUTO_INSTALL_LABELS = True
@@ -20,6 +23,7 @@ class UniqueUser(StructuredNode):
age = IntegerProperty()
+@mark_sync_test
def test_unique_id_property_batch():
users = UniqueUser.create({"name": "bob", "age": 2}, {"name": "ben", "age": 3})
@@ -36,6 +40,7 @@ class Customer(StructuredNode):
age = IntegerProperty(index=True)
+@mark_sync_test
def test_batch_create():
users = Customer.create(
{"email": "jim1@aol.com", "age": 11},
@@ -51,6 +56,7 @@ def test_batch_create():
assert Customer.nodes.get(email="jim1@aol.com")
+@mark_sync_test
def test_batch_create_or_update():
users = Customer.create_or_update(
{"email": "merge1@aol.com", "age": 11},
@@ -60,7 +66,8 @@ def test_batch_create_or_update():
)
assert len(users) == 4
assert users[1] == users[3]
- assert Customer.nodes.get(email="merge1@aol.com").age == 11
+ merge_1: Customer = Customer.nodes.get(email="merge1@aol.com")
+ assert merge_1.age == 11
more_users = Customer.create_or_update(
{"email": "merge1@aol.com", "age": 22},
@@ -68,9 +75,11 @@ def test_batch_create_or_update():
)
assert len(more_users) == 2
assert users[0] == more_users[0]
- assert Customer.nodes.get(email="merge1@aol.com").age == 22
+ merge_1 = Customer.nodes.get(email="merge1@aol.com")
+ assert merge_1.age == 22
+@mark_sync_test
def test_batch_validation():
# test validation in batch create
with raises(DeflateError):
@@ -79,8 +88,9 @@ def test_batch_validation():
)
+@mark_sync_test
def test_batch_index_violation():
- for u in Customer.nodes.all():
+ for u in Customer.nodes:
u.delete()
users = Customer.create(
@@ -94,7 +104,10 @@ def test_batch_index_violation():
)
# not found
- assert not Customer.nodes.filter(email="jim7@aol.com")
+ if Util.is_async_code:
+ assert not Customer.nodes.filter(email="jim7@aol.com").__bool__()
+ else:
+ assert not Customer.nodes.filter(email="jim7@aol.com")
class Dog(StructuredNode):
@@ -107,11 +120,14 @@ class Person(StructuredNode):
pets = RelationshipFrom("Dog", "owner")
+@mark_sync_test
def test_get_or_create_with_rel():
- bob = Person.get_or_create({"name": "Bob"})[0]
+ create_bob = Person.get_or_create({"name": "Bob"})
+ bob = create_bob[0]
bobs_gizmo = Dog.get_or_create({"name": "Gizmo"}, relationship=bob.pets)
- tim = Person.get_or_create({"name": "Tim"})[0]
+ create_tim = Person.get_or_create({"name": "Tim"})
+ tim = create_tim[0]
tims_gizmo = Dog.get_or_create({"name": "Gizmo"}, relationship=tim.pets)
# not the same gizmo
diff --git a/test/test_cardinality.py b/test/sync_/test_cardinality.py
similarity index 83%
rename from test/test_cardinality.py
rename to test/sync_/test_cardinality.py
index 3c850db0..9e83762c 100644
--- a/test/test_cardinality.py
+++ b/test/sync_/test_cardinality.py
@@ -1,3 +1,5 @@
+from test._async_compat import mark_sync_test
+
from pytest import raises
from neomodel import (
@@ -39,44 +41,53 @@ class ToothBrush(StructuredNode):
name = StringProperty()
+@mark_sync_test
def test_cardinality_zero_or_more():
m = Monkey(name="tim").save()
assert m.dryers.all() == []
- assert m.dryers.single() is None
+ single_dryer = m.dryers.single()
+ assert single_dryer is None
h = HairDryer(version=1).save()
m.dryers.connect(h)
assert len(m.dryers.all()) == 1
- assert m.dryers.single().version == 1
+ single_dryer = m.dryers.single()
+ assert single_dryer.version == 1
m.dryers.disconnect(h)
assert m.dryers.all() == []
- assert m.dryers.single() is None
+ single_dryer = m.dryers.single()
+ assert single_dryer is None
h2 = HairDryer(version=2).save()
m.dryers.connect(h)
m.dryers.connect(h2)
m.dryers.disconnect_all()
assert m.dryers.all() == []
- assert m.dryers.single() is None
+ single_dryer = m.dryers.single()
+ assert single_dryer is None
+@mark_sync_test
def test_cardinality_zero_or_one():
m = Monkey(name="bob").save()
assert m.driver.all() == []
+ single_driver = m.driver.single()
assert m.driver.single() is None
h = ScrewDriver(version=1).save()
m.driver.connect(h)
assert len(m.driver.all()) == 1
- assert m.driver.single().version == 1
+ single_driver = m.driver.single()
+ assert single_driver.version == 1
j = ScrewDriver(version=2).save()
with raises(AttemptedCardinalityViolation):
m.driver.connect(j)
m.driver.reconnect(h, j)
- assert m.driver.single().version == 2
+ single_driver = m.driver.single()
+ assert single_driver.version == 2
# Forcing creation of a second ToothBrush to go around
# AttemptedCardinalityViolation
@@ -94,6 +105,7 @@ def test_cardinality_zero_or_one():
m.driver.all()
+@mark_sync_test
def test_cardinality_one_or_more():
m = Monkey(name="jerry").save()
@@ -105,7 +117,8 @@ def test_cardinality_one_or_more():
c = Car(version=2).save()
m.car.connect(c)
- assert m.car.single().version == 2
+ single_car = m.car.single()
+ assert single_car.version == 2
cars = m.car.all()
assert len(cars) == 1
@@ -123,6 +136,7 @@ def test_cardinality_one_or_more():
assert len(cars) == 1
+@mark_sync_test
def test_cardinality_one():
m = Monkey(name="jerry").save()
@@ -136,9 +150,10 @@ def test_cardinality_one():
b = ToothBrush(name="Jim").save()
m.toothbrush.connect(b)
- assert m.toothbrush.single().name == "Jim"
+ single_toothbrush = m.toothbrush.single()
+ assert single_toothbrush.name == "Jim"
- x = ToothBrush(name="Jim").save
+ x = ToothBrush(name="Jim").save()
with raises(AttemptedCardinalityViolation):
m.toothbrush.connect(x)
diff --git a/test/test_connection.py b/test/sync_/test_connection.py
similarity index 87%
rename from test/test_connection.py
rename to test/sync_/test_connection.py
index fb3524bb..e7c0d7ce 100644
--- a/test/test_connection.py
+++ b/test/sync_/test_connection.py
@@ -1,15 +1,15 @@
import os
-import time
+from test._async_compat import mark_sync_test
+from test.conftest import NEO4J_PASSWORD, NEO4J_URL, NEO4J_USERNAME
import pytest
-from neo4j import GraphDatabase
+from neo4j import Driver, GraphDatabase
from neo4j.debug import watch
from neomodel import StringProperty, StructuredNode, config, db
-from .conftest import NEO4J_PASSWORD, NEO4J_URL, NEO4J_USERNAME
-
+@mark_sync_test
@pytest.fixture(autouse=True)
def setup_teardown():
yield
@@ -25,6 +25,7 @@ def neo4j_logging():
yield
+@mark_sync_test
def get_current_database_name() -> str:
"""
Fetches the name of the currently active database from the Neo4j database.
@@ -42,6 +43,7 @@ class Pastry(StructuredNode):
name = StringProperty(unique_index=True)
+@mark_sync_test
def test_set_connection_driver_works():
# Verify that current connection is up
assert Pastry(name="Chocolatine").save()
@@ -54,13 +56,16 @@ def test_set_connection_driver_works():
assert Pastry(name="Croissant").save()
+@mark_sync_test
def test_config_driver_works():
# Verify that current connection is up
assert Pastry(name="Chausson aux pommes").save()
db.close_connection()
# Test connection using a driver defined in config
- driver = GraphDatabase().driver(NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
+ driver: Driver = GraphDatabase().driver(
+ NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)
+ )
config.DRIVER = driver
assert Pastry(name="Grignette").save()
@@ -70,11 +75,10 @@ def test_config_driver_works():
config.DRIVER = None
-@pytest.mark.skipif(
- db.database_edition != "enterprise",
- reason="Skipping test for community edition - no multi database in CE",
-)
+@mark_sync_test
def test_connect_to_non_default_database():
+ if not db.edition_is_enterprise():
+ pytest.skip("Skipping test for community edition - no multi database in CE")
database_name = "pastries"
db.cypher_query(f"CREATE DATABASE {database_name} IF NOT EXISTS")
db.close_connection()
@@ -105,6 +109,7 @@ def test_connect_to_non_default_database():
config.DATABASE_NAME = None
+@mark_sync_test
@pytest.mark.parametrize(
"url", ["bolt://user:password", "http://user:password@localhost:7687"]
)
@@ -116,6 +121,7 @@ def test_wrong_url_format(url):
db.set_connection(url=url)
+@mark_sync_test
@pytest.mark.parametrize("protocol", ["neo4j+s", "neo4j+ssc", "bolt+s", "bolt+ssc"])
def test_connect_to_aura(protocol):
cypher_return = "hello world"
diff --git a/test/sync_/test_contrib/__init__.py b/test/sync_/test_contrib/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/test/test_contrib/test_semi_structured.py b/test/sync_/test_contrib/test_semi_structured.py
similarity index 87%
rename from test/test_contrib/test_semi_structured.py
rename to test/sync_/test_contrib/test_semi_structured.py
index fe73a2bd..f4b9746b 100644
--- a/test/test_contrib/test_semi_structured.py
+++ b/test/sync_/test_contrib/test_semi_structured.py
@@ -1,3 +1,5 @@
+from test._async_compat import mark_sync_test
+
from neomodel import IntegerProperty, StringProperty
from neomodel.contrib import SemiStructuredNode
@@ -11,11 +13,13 @@ class Dummy(SemiStructuredNode):
pass
+@mark_sync_test
def test_to_save_to_model_with_required_only():
u = UserProf(email="dummy@test.com")
assert u.save()
+@mark_sync_test
def test_save_to_model_with_extras():
u = UserProf(email="jim@test.com", age=3, bar=99)
u.foo = True
@@ -25,6 +29,7 @@ def test_save_to_model_with_extras():
assert u.bar == 99
+@mark_sync_test
def test_save_empty_model():
dummy = Dummy()
assert dummy.save()
diff --git a/test/sync_/test_contrib/test_spatial_datatypes.py b/test/sync_/test_contrib/test_spatial_datatypes.py
new file mode 100644
index 00000000..b35ced3a
--- /dev/null
+++ b/test/sync_/test_contrib/test_spatial_datatypes.py
@@ -0,0 +1,397 @@
+"""
+Provides a test case for data types required by issue 374 - "Support for Point property type".
+
+At the moment, only one new datatype is offered: NeomodelPoint
+
+For more information please see: https://github.com/neo4j-contrib/neomodel/issues/374
+"""
+
+import os
+
+import pytest
+import shapely
+
+import neomodel
+import neomodel.contrib.spatial_properties
+from neomodel.util import version_tag_to_integer
+
+
+def check_and_skip_neo4j_least_version(required_least_neo4j_version, message):
+ """
+ Checks if the NEO4J_VERSION is at least `required_least_neo4j_version` and skips a test if not.
+
+ WARNING: If the NEO4J_VERSION variable is not set, this function does not skip, allowing the test to go ahead.
+
+ :param required_least_neo4j_version: The least version to check. This must be the numeric representation of the
+ version. That is: '3.4.0' would be passed as 340.
+ :type required_least_neo4j_version: int
+ :param message: An informative message as to why the calling test had to be skipped.
+ :type message: str
+ :return: None. The calling test is skipped via pytest.skip if the reported version is below the required one.
+ """
+ if "NEO4J_VERSION" in os.environ:
+ if (
+ version_tag_to_integer(os.environ["NEO4J_VERSION"])
+ < required_least_neo4j_version
+ ):
+ pytest.skip(
+ "Neo4j version: {}. {}."
+ "Skipping test.".format(os.environ["NEO4J_VERSION"], message)
+ )
+
+
+def basic_type_assertions(
+ ground_truth, tested_object, test_description, check_neo4j_points=False
+):
+ """
+ Tests that `tested_object` has been created as intended.
+
+ :param ground_truth: The object as it is supposed to have been created.
+ :type ground_truth: NeomodelPoint or neo4j.v1.spatial.Point
+ :param tested_object: The object as it results from one of the constructors.
+ :type tested_object: NeomodelPoint or neo4j.v1.spatial.Point
+ :param test_description: A brief description of the test being performed.
+ :type test_description: str
+ :param check_neo4j_points: Whether to assert between NeomodelPoint or neo4j.v1.spatial.Point objects.
+ :type check_neo4j_points: bool
+ :return:
+ """
+ if check_neo4j_points:
+ assert isinstance(
+ tested_object, type(ground_truth)
+ ), "{} did not return Neo4j Point".format(test_description)
+ assert (
+ tested_object.srid == ground_truth.srid
+ ), "{} does not have the expected SRID({})".format(
+ test_description, ground_truth.srid
+ )
+ assert len(tested_object) == len(
+ ground_truth
+ ), "Dimensionality mismatch. Expected {}, had {}".format(
+ len(ground_truth.coords), len(tested_object.coords)
+ )
+ else:
+ assert isinstance(
+ tested_object, type(ground_truth)
+ ), "{} did not return NeomodelPoint".format(test_description)
+ assert (
+ tested_object.crs == ground_truth.crs
+ ), "{} does not have the expected CRS({})".format(
+ test_description, ground_truth.crs
+ )
+ assert len(tested_object.coords[0]) == len(
+ ground_truth.coords[0]
+ ), "Dimensionality mismatch. Expected {}, had {}".format(
+ len(ground_truth.coords[0]), len(tested_object.coords[0])
+ )
+
+
+# Object Construction
+def test_coord_constructor():
+ """
+ Tests all the possible ways by which a NeomodelPoint can be instantiated successfully via passing coordinates.
+ :return:
+ """
+
+ # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test.
+ check_and_skip_neo4j_least_version(
+ 340, "This version does not support spatial data types."
+ )
+
+ # Implicit cartesian point with coords
+ ground_truth_object = neomodel.contrib.spatial_properties.NeomodelPoint((0.0, 0.0))
+ new_point = neomodel.contrib.spatial_properties.NeomodelPoint((0.0, 0.0))
+ basic_type_assertions(
+ ground_truth_object,
+ new_point,
+ "Implicit 2d cartesian point instantiation",
+ )
+
+ ground_truth_object = neomodel.contrib.spatial_properties.NeomodelPoint(
+ (0.0, 0.0, 0.0)
+ )
+ new_point = neomodel.contrib.spatial_properties.NeomodelPoint((0.0, 0.0, 0.0))
+ basic_type_assertions(
+ ground_truth_object,
+ new_point,
+ "Implicit 3d cartesian point instantiation",
+ )
+
+ # Explicit geographical point with coords
+ ground_truth_object = neomodel.contrib.spatial_properties.NeomodelPoint(
+ (0.0, 0.0), crs="wgs-84"
+ )
+ new_point = neomodel.contrib.spatial_properties.NeomodelPoint(
+ (0.0, 0.0), crs="wgs-84"
+ )
+ basic_type_assertions(
+ ground_truth_object,
+ new_point,
+ "Explicit 2d geographical point with tuple of coords instantiation",
+ )
+
+ ground_truth_object = neomodel.contrib.spatial_properties.NeomodelPoint(
+ (0.0, 0.0, 0.0), crs="wgs-84-3d"
+ )
+ new_point = neomodel.contrib.spatial_properties.NeomodelPoint(
+ (0.0, 0.0, 0.0), crs="wgs-84-3d"
+ )
+ basic_type_assertions(
+ ground_truth_object,
+ new_point,
+ "Explicit 3d geographical point with tuple of coords instantiation",
+ )
+
+ # Cartesian point with named arguments
+ ground_truth_object = neomodel.contrib.spatial_properties.NeomodelPoint(
+ x=0.0, y=0.0
+ )
+ new_point = neomodel.contrib.spatial_properties.NeomodelPoint(x=0.0, y=0.0)
+ basic_type_assertions(
+ ground_truth_object,
+ new_point,
+ "Cartesian 2d point with named arguments",
+ )
+
+ ground_truth_object = neomodel.contrib.spatial_properties.NeomodelPoint(
+ x=0.0, y=0.0, z=0.0
+ )
+ new_point = neomodel.contrib.spatial_properties.NeomodelPoint(x=0.0, y=0.0, z=0.0)
+ basic_type_assertions(
+ ground_truth_object,
+ new_point,
+ "Cartesian 3d point with named arguments",
+ )
+
+ # Geographical point with named arguments
+ ground_truth_object = neomodel.contrib.spatial_properties.NeomodelPoint(
+ longitude=0.0, latitude=0.0
+ )
+ new_point = neomodel.contrib.spatial_properties.NeomodelPoint(
+ longitude=0.0, latitude=0.0
+ )
+ basic_type_assertions(
+ ground_truth_object,
+ new_point,
+ "Geographical 2d point with named arguments",
+ )
+
+ ground_truth_object = neomodel.contrib.spatial_properties.NeomodelPoint(
+ longitude=0.0, latitude=0.0, height=0.0
+ )
+ new_point = neomodel.contrib.spatial_properties.NeomodelPoint(
+ longitude=0.0, latitude=0.0, height=0.0
+ )
+ basic_type_assertions(
+ ground_truth_object,
+ new_point,
+ "Geographical 3d point with named arguments",
+ )
+
+
+def test_copy_constructors():
+ """
+ Tests all the possible ways by which a NeomodelPoint can be instantiated successfully via a copy constructor call.
+
+ :return:
+ """
+
+ # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test.
+ check_and_skip_neo4j_least_version(
+ 340, "This version does not support spatial data types."
+ )
+
+ # Instantiate from Shapely point
+
+ # Implicit cartesian from shapely point
+ ground_truth = neomodel.contrib.spatial_properties.NeomodelPoint(
+ (0.0, 0.0), crs="cartesian"
+ )
+ shapely_point = shapely.geometry.Point((0.0, 0.0))
+ new_point = neomodel.contrib.spatial_properties.NeomodelPoint(shapely_point)
+ basic_type_assertions(
+ ground_truth, new_point, "Implicit cartesian by shapely Point"
+ )
+
+ # Explicit geographical by shapely point
+ ground_truth = neomodel.contrib.spatial_properties.NeomodelPoint(
+ (0.0, 0.0, 0.0), crs="wgs-84-3d"
+ )
+ shapely_point = shapely.geometry.Point((0.0, 0.0, 0.0))
+ new_point = neomodel.contrib.spatial_properties.NeomodelPoint(
+ shapely_point, crs="wgs-84-3d"
+ )
+ basic_type_assertions(
+ ground_truth, new_point, "Explicit geographical by shapely Point"
+ )
+
+ # Copy constructor for NeomodelPoints
+ ground_truth = neomodel.contrib.spatial_properties.NeomodelPoint((0.0, 0.0))
+ other_neomodel_point = neomodel.contrib.spatial_properties.NeomodelPoint((0.0, 0.0))
+ new_point = neomodel.contrib.spatial_properties.NeomodelPoint(other_neomodel_point)
+ basic_type_assertions(ground_truth, new_point, "NeomodelPoint copy constructor")
+
+
+def test_prohibited_constructor_forms():
+ """
+ Tests all the possible forms by which construction of NeomodelPoints should fail.
+
+ :return:
+ """
+
+ # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test.
+ check_and_skip_neo4j_least_version(
+ 340, "This version does not support spatial data types."
+ )
+
+ # Absurd CRS
+ with pytest.raises(ValueError, match=r"Invalid CRS\(blue_hotel\)"):
+ _ = neomodel.contrib.spatial_properties.NeomodelPoint((0, 0), crs="blue_hotel")
+
+ # Absurd coord dimensionality
+ with pytest.raises(
+ ValueError,
+ ):
+ _ = neomodel.contrib.spatial_properties.NeomodelPoint(
+ (0, 0, 0, 0, 0, 0, 0), crs="cartesian"
+ )
+
+ # Absurd datatype passed to copy constructor
+ with pytest.raises(
+ TypeError,
+ ):
+ _ = neomodel.contrib.spatial_properties.NeomodelPoint(
+ "it don't mean a thing if it ain't got that swing",
+ crs="cartesian",
+ )
+
+ # Trying to instantiate a point with any of BOTH x,y,z or longitude, latitude, height
+ with pytest.raises(ValueError, match="Invalid instantiation via arguments"):
+ _ = neomodel.contrib.spatial_properties.NeomodelPoint(
+ x=0.0,
+ y=0.0,
+ longitude=0.0,
+ latitude=2.0,
+ height=-2.0,
+ crs="cartesian",
+ )
+
+ # Trying to instantiate a point with absolutely NO parameters
+ with pytest.raises(ValueError, match="Invalid instantiation via no arguments"):
+ _ = neomodel.contrib.spatial_properties.NeomodelPoint()
+
+
+def test_property_accessors_depending_on_crs_shapely_lt_2():
+ """
+ Tests that points are accessed via their respective accessors.
+
+ :return:
+ """
+
+ # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test.
+ check_and_skip_neo4j_least_version(
+ 340, "This version does not support spatial data types."
+ )
+
+ # Check the version of Shapely installed to run the appropriate tests:
+ try:
+ from shapely import __version__
+ except ImportError:
+ pytest.skip("Shapely not installed")
+
+ if int("".join(__version__.split(".")[0:3])) >= 200:
+ pytest.skip("Shapely 2 is installed, skipping earlier version test")
+
+ # Geometrical points only have x,y,z coordinates
+ new_point = neomodel.contrib.spatial_properties.NeomodelPoint(
+ (0.0, 0.0, 0.0), crs="cartesian-3d"
+ )
+ with pytest.raises(AttributeError, match=r'Invalid coordinate \("longitude"\)'):
+ new_point.longitude
+ with pytest.raises(AttributeError, match=r'Invalid coordinate \("latitude"\)'):
+ new_point.latitude
+ with pytest.raises(AttributeError, match=r'Invalid coordinate \("height"\)'):
+ new_point.height
+
+ # Geographical points only have longitude, latitude, height coordinates
+ new_point = neomodel.contrib.spatial_properties.NeomodelPoint(
+ (0.0, 0.0, 0.0), crs="wgs-84-3d"
+ )
+ with pytest.raises(AttributeError, match=r'Invalid coordinate \("x"\)'):
+ new_point.x
+ with pytest.raises(AttributeError, match=r'Invalid coordinate \("y"\)'):
+ new_point.y
+ with pytest.raises(AttributeError, match=r'Invalid coordinate \("z"\)'):
+ new_point.z
+
+
+def test_property_accessors_depending_on_crs_shapely_gte_2():
+ """
+ Tests that points are accessed via their respective accessors.
+
+ :return:
+ """
+
+ # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test.
+ check_and_skip_neo4j_least_version(
+ 340, "This version does not support spatial data types."
+ )
+
+ # Check the version of Shapely installed to run the appropriate tests:
+ try:
+ from shapely import __version__
+ except ImportError:
+ pytest.skip("Shapely not installed")
+
+ if int("".join(__version__.split(".")[0:3])) < 200:
+ pytest.skip("Shapely < 2.0.0 is installed, skipping test")
+ # Geometrical points only have x,y,z coordinates
+ new_point = neomodel.contrib.spatial_properties.NeomodelPoint(
+ (0.0, 0.0, 0.0), crs="cartesian-3d"
+ )
+ with pytest.raises(TypeError, match=r'Invalid coordinate \("longitude"\)'):
+ new_point.longitude
+ with pytest.raises(TypeError, match=r'Invalid coordinate \("latitude"\)'):
+ new_point.latitude
+ with pytest.raises(TypeError, match=r'Invalid coordinate \("height"\)'):
+ new_point.height
+
+ # Geographical points only have longitude, latitude, height coordinates
+ new_point = neomodel.contrib.spatial_properties.NeomodelPoint(
+ (0.0, 0.0, 0.0), crs="wgs-84-3d"
+ )
+ with pytest.raises(TypeError, match=r'Invalid coordinate \("x"\)'):
+ new_point.x
+ with pytest.raises(TypeError, match=r'Invalid coordinate \("y"\)'):
+ new_point.y
+ with pytest.raises(TypeError, match=r'Invalid coordinate \("z"\)'):
+ new_point.z
+
+
+def test_property_accessors():
+ """
+ Tests that points are accessed via their respective accessors and that these accessors return the right values.
+
+ :return:
+ """
+
+ # Neo4j versions lower than 3.4.0 do not support Point. In that case, skip the test.
+ check_and_skip_neo4j_least_version(
+ 340, "This version does not support spatial data types."
+ )
+
+ # Geometrical points
+ new_point = neomodel.contrib.spatial_properties.NeomodelPoint(
+ (0.0, 1.0, 2.0), crs="cartesian-3d"
+ )
+ assert new_point.x == 0.0, "Expected x coordinate to be 0.0"
+ assert new_point.y == 1.0, "Expected y coordinate to be 1.0"
+ assert new_point.z == 2.0, "Expected z coordinate to be 2.0"
+
+ # Geographical points
+ new_point = neomodel.contrib.spatial_properties.NeomodelPoint(
+ (0.0, 1.0, 2.0), crs="wgs-84-3d"
+ )
+ assert new_point.longitude == 0.0, "Expected longitude to be 0.0"
+ assert new_point.latitude == 1.0, "Expected latitude to be 1.0"
+ assert new_point.height == 2.0, "Expected height to be 2.0"
diff --git a/test/test_contrib/test_spatial_properties.py b/test/sync_/test_contrib/test_spatial_properties.py
similarity index 96%
rename from test/test_contrib/test_spatial_properties.py
rename to test/sync_/test_contrib/test_spatial_properties.py
index 03177c02..f33f4fb6 100644
--- a/test/test_contrib/test_spatial_properties.py
+++ b/test/sync_/test_contrib/test_spatial_properties.py
@@ -4,8 +4,8 @@
For more information please see: https://github.com/neo4j-contrib/neomodel/issues/374
"""
-import os
import random
+from test._async_compat import mark_sync_test
import neo4j.spatial
import pytest
@@ -156,6 +156,7 @@ def test_deflate():
)
+@mark_sync_test
def test_default_value():
"""
Tests that the default value passing mechanism works as expected with NeomodelPoint values.
@@ -194,6 +195,7 @@ class LocalisableEntity(neomodel.StructuredNode):
), ("Default value assignment failed.")
+@mark_sync_test
def test_array_of_points():
"""
Tests that Arrays of Points work as expected.
@@ -236,6 +238,7 @@ class AnotherLocalisableEntity(neomodel.StructuredNode):
], "Array of Points incorrect values."
+@mark_sync_test
def test_simple_storage_retrieval():
"""
Performs a simple Create, Retrieve via .save(), .get() which, due to the way Q objects operate, tests the
@@ -264,6 +267,7 @@ class TestStorageRetrievalProperty(neomodel.StructuredNode):
assert a_restaurant.description == a_property.description
+
def test_equality_with_other_objects():
"""
Performs equality tests and ensures tha ``NeomodelPoint`` can be compared with ShapelyPoint and NeomodelPoint only.
@@ -277,6 +281,9 @@ def test_equality_with_other_objects():
if int("".join(__version__.split(".")[0:3])) < 200:
pytest.skip(f"Shapely 2.0 not present (Current version is {__version__}")
- assert neomodel.contrib.spatial_properties.NeomodelPoint((0,0)) == neomodel.contrib.spatial_properties.NeomodelPoint(x=0, y=0)
- assert neomodel.contrib.spatial_properties.NeomodelPoint((0,0)) == shapely.geometry.Point((0,0))
-
+ assert neomodel.contrib.spatial_properties.NeomodelPoint(
+ (0, 0)
+ ) == neomodel.contrib.spatial_properties.NeomodelPoint(x=0, y=0)
+ assert neomodel.contrib.spatial_properties.NeomodelPoint(
+ (0, 0)
+ ) == shapely.geometry.Point((0, 0))
diff --git a/test/test_cypher.py b/test/sync_/test_cypher.py
similarity index 82%
rename from test/test_cypher.py
rename to test/sync_/test_cypher.py
index 7c2e6fd6..944959c4 100644
--- a/test/test_cypher.py
+++ b/test/sync_/test_cypher.py
@@ -1,12 +1,13 @@
import builtins
+from test._async_compat import mark_sync_test
import pytest
from neo4j.exceptions import ClientError as CypherError
from numpy import ndarray
from pandas import DataFrame, Series
-from neomodel import StringProperty, StructuredNode
-from neomodel.core import db
+from neomodel import StringProperty, StructuredNode, db
+from neomodel._async_compat.util import Util
class User2(StructuredNode):
@@ -36,6 +37,7 @@ def mocked_import(name, *args, **kwargs):
monkeypatch.setattr(builtins, "__import__", mocked_import)
+@mark_sync_test
def test_cypher():
"""
test result format is backward compatible with earlier versions of neomodel
@@ -58,6 +60,7 @@ def test_cypher():
assert "a" in meta and "b" in meta
+@mark_sync_test
def test_cypher_syntax_error():
jim = User2(email="jim1@test.com").save()
try:
@@ -69,8 +72,13 @@ def test_cypher_syntax_error():
assert False, "CypherError not raised."
+@mark_sync_test
@pytest.mark.parametrize("hide_available_pkg", ["pandas"], indirect=True)
def test_pandas_not_installed(hide_available_pkg):
+ # We run only the async version, because this fails on second run
+ # because import error is thrown only when pandas.py is imported
+ if not Util.is_async_code:
+ pytest.skip("This test is only run in the async version")
with pytest.raises(ImportError):
with pytest.warns(
UserWarning,
@@ -81,6 +89,7 @@ def test_pandas_not_installed(hide_available_pkg):
_ = to_dataframe(db.cypher_query("MATCH (a) RETURN a.name AS name"))
+@mark_sync_test
def test_pandas_integration():
from neomodel.integration.pandas import to_dataframe, to_series
@@ -113,18 +122,24 @@ def test_pandas_integration():
assert df["name"].tolist() == ["jimla", "jimlo"]
+@mark_sync_test
@pytest.mark.parametrize("hide_available_pkg", ["numpy"], indirect=True)
def test_numpy_not_installed(hide_available_pkg):
+ # We run only the async version, because this fails on second run
+ # because import error is thrown only when numpy.py is imported
+ if not Util.is_async_code:
+ pytest.skip("This test is only run in the async version")
with pytest.raises(ImportError):
with pytest.warns(
UserWarning,
- match="The neomodel.integration.numpy module expects pandas to be installed",
+ match="The neomodel.integration.numpy module expects numpy to be installed",
):
from neomodel.integration.numpy import to_ndarray
_ = to_ndarray(db.cypher_query("MATCH (a) RETURN a.name AS name"))
+@mark_sync_test
def test_numpy_integration():
from neomodel.integration.numpy import to_ndarray
@@ -132,9 +147,11 @@ def test_numpy_integration():
jimlu = UserNP(email="jimlu@test.com", name="jimlu").save()
array = to_ndarray(
- db.cypher_query("MATCH (a:UserNP) RETURN a.name AS name, a.email AS email")
+ db.cypher_query(
+ "MATCH (a:UserNP) RETURN a.name AS name, a.email AS email ORDER BY name"
+ )
)
assert isinstance(array, ndarray)
assert array.shape == (2, 2)
- assert array[0][0] == "jimly"
+ assert array[0][0] == "jimlu"
diff --git a/test/test_database_management.py b/test/sync_/test_database_management.py
similarity index 84%
rename from test/test_database_management.py
rename to test/sync_/test_database_management.py
index 2a2ece34..9f663994 100644
--- a/test/test_database_management.py
+++ b/test/sync_/test_database_management.py
@@ -1,3 +1,5 @@
+from test._async_compat import mark_sync_test
+
import pytest
from neo4j.exceptions import AuthError
@@ -8,7 +10,6 @@
StructuredNode,
StructuredRel,
db,
- util,
)
@@ -26,26 +27,28 @@ class Venue(StructuredNode):
in_city = RelationshipTo(City, relation_type="IN", model=InCity)
+@mark_sync_test
def test_clear_database():
venue = Venue(name="Royal Albert Hall", creator="Queen Victoria").save()
city = City(name="London").save()
venue.in_city.connect(city)
# Clear only the data
- util.clear_neo4j_database(db)
+ db.clear_neo4j_database()
database_is_populated, _ = db.cypher_query(
"MATCH (a) return count(a)>0 as database_is_populated"
)
assert database_is_populated[0][0] is False
+ db.install_all_labels()
indexes = db.list_indexes(exclude_token_lookup=True)
constraints = db.list_constraints()
assert len(indexes) > 0
assert len(constraints) > 0
# Clear constraints and indexes too
- util.clear_neo4j_database(db, clear_constraints=True, clear_indexes=True)
+ db.clear_neo4j_database(clear_constraints=True, clear_indexes=True)
indexes = db.list_indexes(exclude_token_lookup=True)
constraints = db.list_constraints()
@@ -53,13 +56,14 @@ def test_clear_database():
assert len(constraints) == 0
+@mark_sync_test
def test_change_password():
prev_password = "foobarbaz"
new_password = "newpassword"
prev_url = f"bolt://neo4j:{prev_password}@localhost:7687"
new_url = f"bolt://neo4j:{new_password}@localhost:7687"
- util.change_neo4j_password(db, "neo4j", new_password)
+ db.change_neo4j_password("neo4j", new_password)
db.close_connection()
db.set_connection(url=new_url)
@@ -71,7 +75,7 @@ def test_change_password():
db.close_connection()
db.set_connection(url=new_url)
- util.change_neo4j_password(db, "neo4j", prev_password)
+ db.change_neo4j_password("neo4j", prev_password)
db.close_connection()
db.set_connection(url=prev_url)
diff --git a/test/test_dbms_awareness.py b/test/sync_/test_dbms_awareness.py
similarity index 69%
rename from test/test_dbms_awareness.py
rename to test/sync_/test_dbms_awareness.py
index 93dc032d..f0f7fb68 100644
--- a/test/test_dbms_awareness.py
+++ b/test/sync_/test_dbms_awareness.py
@@ -1,14 +1,17 @@
-from pytest import mark
+from test._async_compat import mark_sync_test
+
+import pytest
from neomodel import db
from neomodel.util import version_tag_to_integer
-@mark.skipif(
- db.database_version != "5.7.0", reason="Testing a specific database version"
-)
+@mark_sync_test
def test_version_awareness():
- assert db.database_version == "5.7.0"
+ db_version = db.database_version
+ if db_version != "5.7.0":
+ pytest.skip("Testing a specific database version")
+ assert db_version == "5.7.0"
assert db.version_is_higher_than("5.7")
assert db.version_is_higher_than("5.6.0")
assert db.version_is_higher_than("5")
@@ -17,8 +20,10 @@ def test_version_awareness():
assert not db.version_is_higher_than("5.8")
+@mark_sync_test
def test_edition_awareness():
- if db.database_edition == "enterprise":
+ db_edition = db.database_edition
+ if db_edition == "enterprise":
assert db.edition_is_enterprise()
else:
assert not db.edition_is_enterprise()
diff --git a/test/test_driver_options.py b/test/sync_/test_driver_options.py
similarity index 69%
rename from test/test_driver_options.py
rename to test/sync_/test_driver_options.py
index 26f16640..f244d174 100644
--- a/test/test_driver_options.py
+++ b/test/sync_/test_driver_options.py
@@ -1,3 +1,5 @@
+from test._async_compat import mark_sync_test
+
import pytest
from neo4j.exceptions import ClientError
from pytest import raises
@@ -6,28 +8,28 @@
from neomodel.exceptions import FeatureNotSupported
-@pytest.mark.skipif(
- not db.edition_is_enterprise(), reason="Skipping test for community edition"
-)
+@mark_sync_test
def test_impersonate():
+ if not db.edition_is_enterprise():
+ pytest.skip("Skipping test for community edition")
with db.impersonate(user="troygreene"):
results, _ = db.cypher_query("RETURN 'Doo Wacko !'")
assert results[0][0] == "Doo Wacko !"
-@pytest.mark.skipif(
- not db.edition_is_enterprise(), reason="Skipping test for community edition"
-)
+@mark_sync_test
def test_impersonate_unauthorized():
+ if not db.edition_is_enterprise():
+ pytest.skip("Skipping test for community edition")
with db.impersonate(user="unknownuser"):
with raises(ClientError):
_ = db.cypher_query("RETURN 'Gabagool'")
-@pytest.mark.skipif(
- not db.edition_is_enterprise(), reason="Skipping test for community edition"
-)
+@mark_sync_test
def test_impersonate_multiple_transactions():
+ if not db.edition_is_enterprise():
+ pytest.skip("Skipping test for community edition")
with db.impersonate(user="troygreene"):
with db.transaction:
results, _ = db.cypher_query("RETURN 'Doo Wacko !'")
@@ -41,10 +43,10 @@ def test_impersonate_multiple_transactions():
assert results[0][0] == "neo4j"
-@pytest.mark.skipif(
- db.edition_is_enterprise(), reason="Skipping test for enterprise edition"
-)
+@mark_sync_test
def test_impersonate_community():
+ if db.edition_is_enterprise():
+ pytest.skip("Skipping test for enterprise edition")
with raises(FeatureNotSupported):
with db.impersonate(user="troygreene"):
_ = db.cypher_query("RETURN 'Gabagoogoo'")
diff --git a/test/test_exceptions.py b/test/sync_/test_exceptions.py
similarity index 93%
rename from test/test_exceptions.py
rename to test/sync_/test_exceptions.py
index 546c13fe..fe8cfe36 100644
--- a/test/test_exceptions.py
+++ b/test/sync_/test_exceptions.py
@@ -1,4 +1,5 @@
import pickle
+from test._async_compat import mark_sync_test
from neomodel import DoesNotExist, StringProperty, StructuredNode
@@ -7,6 +8,7 @@ class EPerson(StructuredNode):
name = StringProperty(unique_index=True)
+@mark_sync_test
def test_object_does_not_exist():
try:
EPerson.nodes.get(name="johnny")
diff --git a/test/test_hooks.py b/test/sync_/test_hooks.py
similarity index 92%
rename from test/test_hooks.py
rename to test/sync_/test_hooks.py
index 158db079..a6f742e0 100644
--- a/test/test_hooks.py
+++ b/test/sync_/test_hooks.py
@@ -1,3 +1,5 @@
+from test._async_compat import mark_sync_test
+
from neomodel import StringProperty, StructuredNode
HOOKS_CALLED = {}
@@ -22,6 +24,7 @@ def post_delete(self):
HOOKS_CALLED["post_delete"] = 1
+@mark_sync_test
def test_hooks():
ht = HookTest(name="k").save()
ht.delete()
diff --git a/test/test_indexing.py b/test/sync_/test_indexing.py
similarity index 83%
rename from test/test_indexing.py
rename to test/sync_/test_indexing.py
index 0b5e8fba..f39c22ef 100644
--- a/test/test_indexing.py
+++ b/test/sync_/test_indexing.py
@@ -1,14 +1,9 @@
+from test._async_compat import mark_sync_test
+
import pytest
from pytest import raises
-from neomodel import (
- IntegerProperty,
- StringProperty,
- StructuredNode,
- UniqueProperty,
- install_labels,
-)
-from neomodel.core import db
+from neomodel import IntegerProperty, StringProperty, StructuredNode, UniqueProperty, db
from neomodel.exceptions import ConstraintValidationFailed
@@ -17,8 +12,9 @@ class Human(StructuredNode):
age = IntegerProperty(index=True)
+@mark_sync_test
def test_unique_error():
- install_labels(Human)
+ db.install_labels(Human)
Human(name="j1m", age=13).save()
try:
Human(name="j1m", age=14).save()
@@ -29,10 +25,10 @@ def test_unique_error():
assert False, "UniqueProperty not raised."
-@pytest.mark.skipif(
- not db.edition_is_enterprise(), reason="Skipping test for community edition"
-)
+@mark_sync_test
def test_existence_constraint_error():
+ if not db.edition_is_enterprise():
+ pytest.skip("Skipping test for community edition")
db.cypher_query(
"CREATE CONSTRAINT test_existence_constraint FOR (n:Human) REQUIRE n.age IS NOT NULL"
)
@@ -42,6 +38,7 @@ def test_existence_constraint_error():
db.cypher_query("DROP CONSTRAINT test_existence_constraint")
+@mark_sync_test
def test_optional_properties_dont_get_indexed():
Human(name="99", age=99).save()
h = Human.nodes.get(age=99)
@@ -54,19 +51,21 @@ def test_optional_properties_dont_get_indexed():
assert h.name == "98"
+@mark_sync_test
def test_escaped_chars():
_name = "sarah:test"
Human(name=_name, age=3).save()
r = Human.nodes.filter(name=_name)
- assert r
assert r[0].name == _name
+@mark_sync_test
def test_does_not_exist():
with raises(Human.DoesNotExist):
Human.nodes.get(name="XXXX")
+@mark_sync_test
def test_custom_label_name():
class Giraffe(StructuredNode):
__label__ = "Giraffes"
diff --git a/test/test_issue112.py b/test/sync_/test_issue112.py
similarity index 82%
rename from test/test_issue112.py
rename to test/sync_/test_issue112.py
index 3e379932..f580f146 100644
--- a/test/test_issue112.py
+++ b/test/sync_/test_issue112.py
@@ -1,3 +1,5 @@
+from test._async_compat import mark_sync_test
+
from neomodel import RelationshipTo, StructuredNode
@@ -5,6 +7,7 @@ class SomeModel(StructuredNode):
test = RelationshipTo("SomeModel", "SELF")
+@mark_sync_test
def test_len_relationship():
t1 = SomeModel().save()
t2 = SomeModel().save()
diff --git a/test/test_issue283.py b/test/sync_/test_issue283.py
similarity index 67%
rename from test/test_issue283.py
rename to test/sync_/test_issue283.py
index b7a63f8f..a059f7f2 100644
--- a/test/test_issue283.py
+++ b/test/sync_/test_issue283.py
@@ -9,14 +9,24 @@
idea remains the same: "Instantiate the correct type of node at the end of
a relationship as specified by the model"
"""
-
-import datetime
-import os
import random
+from test._async_compat import mark_sync_test
import pytest
-import neomodel
+from neomodel import (
+ DateTimeProperty,
+ FloatProperty,
+ RelationshipClassNotDefined,
+ RelationshipClassRedefined,
+ RelationshipTo,
+ StringProperty,
+ StructuredNode,
+ StructuredRel,
+ UniqueIdProperty,
+ db,
+)
+from neomodel.exceptions import NodeClassAlreadyDefined, NodeClassNotDefined
try:
basestring
@@ -25,7 +35,7 @@
# Set up a very simple model for the tests
-class PersonalRelationship(neomodel.StructuredRel):
+class PersonalRelationship(StructuredRel):
"""
A very simple relationship between two basePersons that simply records
the date at which an acquaintance was established.
@@ -33,16 +43,16 @@ class PersonalRelationship(neomodel.StructuredRel):
basePerson without any further effort.
"""
- on_date = neomodel.DateTimeProperty(default_now=True)
+ on_date = DateTimeProperty(default_now=True)
-class BasePerson(neomodel.StructuredNode):
+class BasePerson(StructuredNode):
"""
Base class for defining some basic sort of an actor.
"""
- name = neomodel.StringProperty(required=True, unique_index=True)
- friends_with = neomodel.RelationshipTo(
+ name = StringProperty(required=True, unique_index=True)
+ friends_with = RelationshipTo(
"BasePerson", "FRIENDS_WITH", model=PersonalRelationship
)
@@ -52,7 +62,7 @@ class TechnicalPerson(BasePerson):
A Technical person specialises BasePerson by adding their expertise.
"""
- expertise = neomodel.StringProperty(required=True)
+ expertise = StringProperty(required=True)
class PilotPerson(BasePerson):
@@ -61,15 +71,15 @@ class PilotPerson(BasePerson):
can operate.
"""
- airplane = neomodel.StringProperty(required=True)
+ airplane = StringProperty(required=True)
-class BaseOtherPerson(neomodel.StructuredNode):
+class BaseOtherPerson(StructuredNode):
"""
An obviously "wrong" class of actor to befriend BasePersons with.
"""
- car_color = neomodel.StringProperty(required=True)
+ car_color = StringProperty(required=True)
class SomePerson(BaseOtherPerson):
@@ -81,6 +91,7 @@ class SomePerson(BaseOtherPerson):
# Test cases
+@mark_sync_test
def test_automatic_result_resolution():
"""
Node objects at the end of relationships are instantiated to their
@@ -88,30 +99,29 @@ def test_automatic_result_resolution():
"""
# Create a few entities
- A = TechnicalPerson.get_or_create(
- {"name": "Grumpy", "expertise": "Grumpiness"}
- )[0]
- B = TechnicalPerson.get_or_create(
- {"name": "Happy", "expertise": "Unicorns"}
- )[0]
- C = TechnicalPerson.get_or_create(
- {"name": "Sleepy", "expertise": "Pillows"}
- )[0]
+ A = (TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"}))[
+ 0
+ ]
+ B = (TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"}))[0]
+ C = (TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"}))[0]
# Add connections
A.friends_with.connect(B)
B.friends_with.connect(C)
C.friends_with.connect(A)
+ test = A.friends_with
+
# If A is friends with B, then A's friends_with objects should be
# TechnicalPerson (!NOT basePerson!)
- assert type(A.friends_with[0]) is TechnicalPerson
+ assert type((A.friends_with)[0]) is TechnicalPerson
A.delete()
B.delete()
C.delete()
+@mark_sync_test
def test_recursive_automatic_result_resolution():
"""
Node objects are instantiated to native Python objects, both at the top
@@ -120,21 +130,17 @@ def test_recursive_automatic_result_resolution():
"""
# Create a few entities
- A = TechnicalPerson.get_or_create(
- {"name": "Grumpier", "expertise": "Grumpiness"}
- )[0]
- B = TechnicalPerson.get_or_create(
- {"name": "Happier", "expertise": "Grumpiness"}
- )[0]
- C = TechnicalPerson.get_or_create(
- {"name": "Sleepier", "expertise": "Pillows"}
- )[0]
- D = TechnicalPerson.get_or_create(
- {"name": "Sneezier", "expertise": "Pillows"}
+ A = (
+ TechnicalPerson.get_or_create({"name": "Grumpier", "expertise": "Grumpiness"})
)[0]
+ B = (TechnicalPerson.get_or_create({"name": "Happier", "expertise": "Grumpiness"}))[
+ 0
+ ]
+ C = (TechnicalPerson.get_or_create({"name": "Sleepier", "expertise": "Pillows"}))[0]
+ D = (TechnicalPerson.get_or_create({"name": "Sneezier", "expertise": "Pillows"}))[0]
# Retrieve mixed results, both at the top level and nested
- L, _ = neomodel.db.cypher_query(
+ L, _ = db.cypher_query(
"MATCH (a:TechnicalPerson) "
"WHERE a.expertise='Grumpiness' "
"WITH collect(a) as Alpha "
@@ -158,6 +164,7 @@ def test_recursive_automatic_result_resolution():
D.delete()
+@mark_sync_test
def test_validation_with_inheritance_from_db():
"""
Objects descending from the specified class of a relationship's end-node are
@@ -166,22 +173,22 @@ def test_validation_with_inheritance_from_db():
# Create a few entities
# Technical Persons
- A = TechnicalPerson.get_or_create(
- {"name": "Grumpy", "expertise": "Grumpiness"}
- )[0]
- B = TechnicalPerson.get_or_create(
- {"name": "Happy", "expertise": "Unicorns"}
- )[0]
- C = TechnicalPerson.get_or_create(
- {"name": "Sleepy", "expertise": "Pillows"}
- )[0]
+ A = (TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"}))[
+ 0
+ ]
+ B = (TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"}))[0]
+ C = (TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"}))[0]
# Pilot Persons
- D = PilotPerson.get_or_create(
- {"name": "Porco Rosso", "airplane": "Savoia-Marchetti"}
+ D = (
+ PilotPerson.get_or_create(
+ {"name": "Porco Rosso", "airplane": "Savoia-Marchetti"}
+ )
)[0]
- E = PilotPerson.get_or_create(
- {"name": "Jack Dalton", "airplane": "Beechcraft Model 18"}
+ E = (
+ PilotPerson.get_or_create(
+ {"name": "Jack Dalton", "airplane": "Beechcraft Model 18"}
+ )
)[0]
# TechnicalPersons can befriend PilotPersons and vice-versa and that's fine
@@ -201,13 +208,13 @@ def test_validation_with_inheritance_from_db():
# This now means that friends_with of a TechnicalPerson can
# either be TechnicalPerson or Pilot Person (!NOT basePerson!)
- assert (type(A.friends_with[0]) is TechnicalPerson) or (
- type(A.friends_with[0]) is PilotPerson
+ assert (type((A.friends_with)[0]) is TechnicalPerson) or (
+ type((A.friends_with)[0]) is PilotPerson
)
- assert (type(A.friends_with[1]) is TechnicalPerson) or (
- type(A.friends_with[1]) is PilotPerson
+ assert (type((A.friends_with)[1]) is TechnicalPerson) or (
+ type((A.friends_with)[1]) is PilotPerson
)
- assert type(D.friends_with[0]) is PilotPerson
+ assert type((D.friends_with)[0]) is PilotPerson
A.delete()
B.delete()
@@ -216,6 +223,7 @@ def test_validation_with_inheritance_from_db():
E.delete()
+@mark_sync_test
def test_validation_enforcement_to_db():
"""
If a connection between wrong types is attempted, raise an exception
@@ -223,22 +231,22 @@ def test_validation_enforcement_to_db():
# Create a few entities
# Technical Persons
- A = TechnicalPerson.get_or_create(
- {"name": "Grumpy", "expertise": "Grumpiness"}
- )[0]
- B = TechnicalPerson.get_or_create(
- {"name": "Happy", "expertise": "Unicorns"}
- )[0]
- C = TechnicalPerson.get_or_create(
- {"name": "Sleepy", "expertise": "Pillows"}
- )[0]
+ A = (TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"}))[
+ 0
+ ]
+ B = (TechnicalPerson.get_or_create({"name": "Happy", "expertise": "Unicorns"}))[0]
+ C = (TechnicalPerson.get_or_create({"name": "Sleepy", "expertise": "Pillows"}))[0]
# Pilot Persons
- D = PilotPerson.get_or_create(
- {"name": "Porco Rosso", "airplane": "Savoia-Marchetti"}
+ D = (
+ PilotPerson.get_or_create(
+ {"name": "Porco Rosso", "airplane": "Savoia-Marchetti"}
+ )
)[0]
- E = PilotPerson.get_or_create(
- {"name": "Jack Dalton", "airplane": "Beechcraft Model 18"}
+ E = (
+ PilotPerson.get_or_create(
+ {"name": "Jack Dalton", "airplane": "Beechcraft Model 18"}
+ )
)[0]
# Some Person
@@ -265,6 +273,7 @@ def test_validation_enforcement_to_db():
F.delete()
+@mark_sync_test
def test_failed_result_resolution():
"""
A Neo4j driver node FROM the database contains labels that are unaware to
@@ -273,39 +282,39 @@ def test_failed_result_resolution():
"""
class RandomPerson(BasePerson):
- randomness = neomodel.FloatProperty(default=random.random)
+ randomness = FloatProperty(default=random.random)
# A Technical Person...
- A = TechnicalPerson.get_or_create(
- {"name": "Grumpy", "expertise": "Grumpiness"}
- )[0]
+ A = (TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"}))[
+ 0
+ ]
# A Random Person...
- B = RandomPerson.get_or_create({"name": "Mad Hatter"})[0]
+ B = (RandomPerson.get_or_create({"name": "Mad Hatter"}))[0]
A.friends_with.connect(B)
# Simulate the condition where the definition of class RandomPerson is not
# known yet.
- del neomodel.db._NODE_CLASS_REGISTRY[
- frozenset(["RandomPerson", "BasePerson"])
- ]
+ del db._NODE_CLASS_REGISTRY[frozenset(["RandomPerson", "BasePerson"])]
# Now try to instantiate a RandomPerson
- A = TechnicalPerson.get_or_create(
- {"name": "Grumpy", "expertise": "Grumpiness"}
- )[0]
+ A = (TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"}))[
+ 0
+ ]
with pytest.raises(
- neomodel.exceptions.NodeClassNotDefined,
+ NodeClassNotDefined,
match=r"Node with labels .* does not resolve to any of the known objects.*",
):
- for some_friend in A.friends_with:
+ friends = A.friends_with.all()
+ for some_friend in friends:
print(some_friend.name)
A.delete()
B.delete()
+@mark_sync_test
def test_node_label_mismatch():
"""
A Neo4j driver node FROM the database contains a superset of the known
@@ -313,23 +322,21 @@ def test_node_label_mismatch():
"""
class SuperTechnicalPerson(TechnicalPerson):
- superness = neomodel.FloatProperty(default=1.0)
+ superness = FloatProperty(default=1.0)
class UltraTechnicalPerson(SuperTechnicalPerson):
- ultraness = neomodel.FloatProperty(default=3.1415928)
+ ultraness = FloatProperty(default=3.1415928)
# Create a TechnicalPerson...
- A = TechnicalPerson.get_or_create(
- {"name": "Grumpy", "expertise": "Grumpiness"}
- )[0]
+ A = (TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"}))[
+ 0
+ ]
# ...that is connected to an UltraTechnicalPerson
- F = UltraTechnicalPerson(
- name="Chewbaka", expertise="Aarrr wgh ggwaaah"
- ).save()
+ F = UltraTechnicalPerson(name="Chewbaka", expertise="Aarrr wgh ggwaaah").save()
A.friends_with.connect(F)
# Forget about the UltraTechnicalPerson
- del neomodel.db._NODE_CLASS_REGISTRY[
+ del db._NODE_CLASS_REGISTRY[
frozenset(
[
"UltraTechnicalPerson",
@@ -343,17 +350,18 @@ class UltraTechnicalPerson(SuperTechnicalPerson):
# Recall a TechnicalPerson and enumerate its friends.
# One of them is UltraTechnicalPerson which would be returned as a valid
# node to a friends_with query but is currently unknown to the node class registry.
- A = TechnicalPerson.get_or_create(
- {"name": "Grumpy", "expertise": "Grumpiness"}
- )[0]
- with pytest.raises(neomodel.exceptions.NodeClassNotDefined):
- for some_friend in A.friends_with:
+ A = (TechnicalPerson.get_or_create({"name": "Grumpy", "expertise": "Grumpiness"}))[
+ 0
+ ]
+ with pytest.raises(NodeClassNotDefined):
+ friends = A.friends_with.all()
+ for some_friend in friends:
print(some_friend.name)
def test_attempted_class_redefinition():
"""
- A neomodel.StructuredNode class is attempted to be redefined.
+ A StructuredNode class is attempted to be redefined.
"""
def redefine_class_locally():
@@ -361,23 +369,22 @@ def redefine_class_locally():
# SomePerson here.
# The internal structure of the SomePerson entity does not matter at all here.
class SomePerson(BaseOtherPerson):
- uid = neomodel.UniqueIdProperty()
+ uid = UniqueIdProperty()
with pytest.raises(
- neomodel.exceptions.NodeClassAlreadyDefined,
+ NodeClassAlreadyDefined,
match=r"Class .* with labels .* already defined:.*",
):
redefine_class_locally()
+@mark_sync_test
def test_relationship_result_resolution():
"""
A query returning a "Relationship" object can now instantiate it to a data model class
"""
# Test specific data
- A = PilotPerson(
- name="Zantford Granville", airplane="Gee Bee Model R"
- ).save()
+ A = PilotPerson(name="Zantford Granville", airplane="Gee Bee Model R").save()
B = PilotPerson(name="Thomas Granville", airplane="Gee Bee Model R").save()
C = PilotPerson(name="Robert Granville", airplane="Gee Bee Model R").save()
D = PilotPerson(name="Mark Granville", airplane="Gee Bee Model R").save()
@@ -388,7 +395,7 @@ def test_relationship_result_resolution():
C.friends_with.connect(D)
D.friends_with.connect(E)
- query_data = neomodel.db.cypher_query(
+ query_data = db.cypher_query(
"MATCH (a:PilotPerson)-[r:FRIENDS_WITH]->(b:PilotPerson) "
"WHERE a.airplane='Gee Bee Model R' and b.airplane='Gee Bee Model R' "
"RETURN DISTINCT r",
@@ -399,6 +406,7 @@ def test_relationship_result_resolution():
assert isinstance(query_data[0][0][0], PersonalRelationship)
+@mark_sync_test
def test_properly_inherited_relationship():
"""
A relationship class extends an existing relationship model that must extended the same previously associated
@@ -408,11 +416,11 @@ def test_properly_inherited_relationship():
# Extends an existing relationship by adding the "relationship_strength" attribute.
# `ExtendedPersonalRelationship` will now substitute `PersonalRelationship` EVERYWHERE in the system.
class ExtendedPersonalRelationship(PersonalRelationship):
- relationship_strength = neomodel.FloatProperty(default=random.random)
+ relationship_strength = FloatProperty(default=random.random)
# Extends SomePerson, establishes "enriched" relationships with any BaseOtherPerson
class ExtendedSomePerson(SomePerson):
- friends_with = neomodel.RelationshipTo(
+ friends_with = RelationshipTo(
"BaseOtherPerson",
"FRIENDS_WITH",
model=ExtendedPersonalRelationship,
@@ -426,7 +434,7 @@ class ExtendedSomePerson(SomePerson):
A.friends_with.connect(B)
A.friends_with.connect(C)
- query_data = neomodel.db.cypher_query(
+ query_data = db.cypher_query(
"MATCH (:ExtendedSomePerson)-[r:FRIENDS_WITH]->(:ExtendedSomePerson) "
"RETURN DISTINCT r",
resolve_objects=True,
@@ -441,20 +449,21 @@ def test_improperly_inherited_relationship():
:return:
"""
- class NewRelationship(neomodel.StructuredRel):
- profile_match_factor = neomodel.FloatProperty()
+ class NewRelationship(StructuredRel):
+ profile_match_factor = FloatProperty()
with pytest.raises(
- neomodel.RelationshipClassRedefined,
+ RelationshipClassRedefined,
match=r"Relationship of type .* redefined as .*",
):
class NewSomePerson(SomePerson):
- friends_with = neomodel.RelationshipTo(
+ friends_with = RelationshipTo(
"BaseOtherPerson", "FRIENDS_WITH", model=NewRelationship
)
+@mark_sync_test
def test_resolve_inexistent_relationship():
"""
Attempting to resolve an inexistent relationship should raise an exception
@@ -462,13 +471,13 @@ def test_resolve_inexistent_relationship():
"""
# Forget about the FRIENDS_WITH Relationship.
- del neomodel.db._NODE_CLASS_REGISTRY[frozenset(["FRIENDS_WITH"])]
+ del db._NODE_CLASS_REGISTRY[frozenset(["FRIENDS_WITH"])]
with pytest.raises(
- neomodel.RelationshipClassNotDefined,
+ RelationshipClassNotDefined,
match=r"Relationship of type .* does not resolve to any of the known objects.*",
):
- query_data = neomodel.db.cypher_query(
+ query_data = db.cypher_query(
"MATCH (:ExtendedSomePerson)-[r:FRIENDS_WITH]->(:ExtendedSomePerson) "
"RETURN DISTINCT r",
resolve_objects=True,
diff --git a/test/test_issue600.py b/test/sync_/test_issue600.py
similarity index 61%
rename from test/test_issue600.py
rename to test/sync_/test_issue600.py
index 6851efd2..f6b5a10b 100644
--- a/test/test_issue600.py
+++ b/test/sync_/test_issue600.py
@@ -4,13 +4,9 @@
The issue is outlined here: https://github.com/neo4j-contrib/neomodel/issues/600
"""
-import datetime
-import os
-import random
+from test._async_compat import mark_sync_test
-import pytest
-
-import neomodel
+from neomodel import Relationship, StructuredNode, StructuredRel
try:
basestring
@@ -18,7 +14,7 @@
basestring = str
-class Class1(neomodel.StructuredRel):
+class Class1(StructuredRel):
pass
@@ -30,36 +26,37 @@ class SubClass2(Class1):
pass
-class RelationshipDefinerSecondSibling(neomodel.StructuredNode):
- rel_1 = neomodel.Relationship(
+class RelationshipDefinerSecondSibling(StructuredNode):
+ rel_1 = Relationship(
"RelationshipDefinerSecondSibling", "SOME_REL_LABEL", model=Class1
)
- rel_2 = neomodel.Relationship(
+ rel_2 = Relationship(
"RelationshipDefinerSecondSibling", "SOME_REL_LABEL", model=SubClass1
)
- rel_3 = neomodel.Relationship(
+ rel_3 = Relationship(
"RelationshipDefinerSecondSibling", "SOME_REL_LABEL", model=SubClass2
)
-class RelationshipDefinerParentLast(neomodel.StructuredNode):
- rel_2 = neomodel.Relationship(
+class RelationshipDefinerParentLast(StructuredNode):
+ rel_2 = Relationship(
"RelationshipDefinerParentLast", "SOME_REL_LABEL", model=SubClass1
)
- rel_3 = neomodel.Relationship(
+ rel_3 = Relationship(
"RelationshipDefinerParentLast", "SOME_REL_LABEL", model=SubClass2
)
- rel_1 = neomodel.Relationship(
+ rel_1 = Relationship(
"RelationshipDefinerParentLast", "SOME_REL_LABEL", model=Class1
)
# Test cases
+@mark_sync_test
def test_relationship_definer_second_sibling():
# Create a few entities
- A = RelationshipDefinerSecondSibling.get_or_create({})[0]
- B = RelationshipDefinerSecondSibling.get_or_create({})[0]
- C = RelationshipDefinerSecondSibling.get_or_create({})[0]
+ A = (RelationshipDefinerSecondSibling.get_or_create({}))[0]
+ B = (RelationshipDefinerSecondSibling.get_or_create({}))[0]
+ C = (RelationshipDefinerSecondSibling.get_or_create({}))[0]
# Add connections
A.rel_1.connect(B)
@@ -72,11 +69,12 @@ def test_relationship_definer_second_sibling():
C.delete()
+@mark_sync_test
def test_relationship_definer_parent_last():
# Create a few entities
- A = RelationshipDefinerParentLast.get_or_create({})[0]
- B = RelationshipDefinerParentLast.get_or_create({})[0]
- C = RelationshipDefinerParentLast.get_or_create({})[0]
+ A = (RelationshipDefinerParentLast.get_or_create({}))[0]
+ B = (RelationshipDefinerParentLast.get_or_create({}))[0]
+ C = (RelationshipDefinerParentLast.get_or_create({}))[0]
# Add connections
A.rel_1.connect(B)
diff --git a/test/test_label_drop.py b/test/sync_/test_label_drop.py
similarity index 87%
rename from test/test_label_drop.py
rename to test/sync_/test_label_drop.py
index 389d19e0..e4834817 100644
--- a/test/test_label_drop.py
+++ b/test/sync_/test_label_drop.py
@@ -1,9 +1,8 @@
-from neo4j.exceptions import ClientError
+from test._async_compat import mark_sync_test
-from neomodel import StringProperty, StructuredNode, config
-from neomodel.core import db, remove_all_labels
+from neo4j.exceptions import ClientError
-config.AUTO_INSTALL_LABELS = True
+from neomodel import StringProperty, StructuredNode, db
class ConstraintAndIndex(StructuredNode):
@@ -11,14 +10,16 @@ class ConstraintAndIndex(StructuredNode):
last_name = StringProperty(index=True)
+@mark_sync_test
def test_drop_labels():
+ db.install_labels(ConstraintAndIndex)
constraints_before = db.list_constraints()
indexes_before = db.list_indexes(exclude_token_lookup=True)
assert len(constraints_before) > 0
assert len(indexes_before) > 0
- remove_all_labels()
+ db.remove_all_labels()
constraints = db.list_constraints()
indexes = db.list_indexes(exclude_token_lookup=True)
diff --git a/test/test_label_install.py b/test/sync_/test_label_install.py
similarity index 78%
rename from test/test_label_install.py
rename to test/sync_/test_label_install.py
index 46f55467..14bfe107 100644
--- a/test/test_label_install.py
+++ b/test/sync_/test_label_install.py
@@ -1,3 +1,5 @@
+from test._async_compat import mark_sync_test
+
import pytest
from neomodel import (
@@ -6,15 +8,10 @@
StructuredNode,
StructuredRel,
UniqueIdProperty,
- config,
- install_all_labels,
- install_labels,
+ db,
)
-from neomodel.core import db, drop_constraints
from neomodel.exceptions import ConstraintValidationFailed, FeatureNotSupported
-config.AUTO_INSTALL_LABELS = False
-
class NodeWithIndex(StructuredNode):
name = StringProperty(index=True)
@@ -47,24 +44,13 @@ class SomeNotUniqueNode(StructuredNode):
id_ = UniqueIdProperty(db_property="id")
-config.AUTO_INSTALL_LABELS = True
-
-
-def test_labels_were_not_installed():
- bob = NodeWithConstraint(name="bob").save()
- bob2 = NodeWithConstraint(name="bob").save()
- bob3 = NodeWithConstraint(name="bob").save()
- assert bob.element_id != bob3.element_id
-
- for n in NodeWithConstraint.nodes.all():
- n.delete()
-
-
+@mark_sync_test
def test_install_all():
- drop_constraints()
- install_labels(AbstractNode)
+ db.drop_constraints()
+ db.drop_indexes()
+ db.install_labels(AbstractNode)
# run install all labels
- install_all_labels()
+ db.install_all_labels()
indexes = db.list_indexes()
index_names = [index["name"] for index in indexes]
@@ -79,25 +65,28 @@ def test_install_all():
_drop_constraints_for_label_and_property("NoConstraintsSetup", "name")
+@mark_sync_test
def test_install_label_twice(capsys):
+ db.drop_constraints()
+ db.drop_indexes()
expected_std_out = (
"{code: Neo.ClientError.Schema.EquivalentSchemaRuleAlreadyExists}"
)
- install_labels(AbstractNode)
- install_labels(AbstractNode)
+ db.install_labels(AbstractNode)
+ db.install_labels(AbstractNode)
- install_labels(NodeWithIndex)
- install_labels(NodeWithIndex, quiet=False)
+ db.install_labels(NodeWithIndex)
+ db.install_labels(NodeWithIndex, quiet=False)
captured = capsys.readouterr()
assert expected_std_out in captured.out
- install_labels(NodeWithConstraint)
- install_labels(NodeWithConstraint, quiet=False)
+ db.install_labels(NodeWithConstraint)
+ db.install_labels(NodeWithConstraint, quiet=False)
captured = capsys.readouterr()
assert expected_std_out in captured.out
- install_labels(OtherNodeWithRelationship)
- install_labels(OtherNodeWithRelationship, quiet=False)
+ db.install_labels(OtherNodeWithRelationship)
+ db.install_labels(OtherNodeWithRelationship, quiet=False)
captured = capsys.readouterr()
assert expected_std_out in captured.out
@@ -111,15 +100,16 @@ class OtherNodeWithUniqueIndexRelationship(StructuredNode):
NodeWithRelationship, "UNIQUE_INDEX_REL", model=UniqueIndexRelationship
)
- install_labels(OtherNodeWithUniqueIndexRelationship)
- install_labels(OtherNodeWithUniqueIndexRelationship, quiet=False)
+ db.install_labels(OtherNodeWithUniqueIndexRelationship)
+ db.install_labels(OtherNodeWithUniqueIndexRelationship, quiet=False)
captured = capsys.readouterr()
assert expected_std_out in captured.out
+@mark_sync_test
def test_install_labels_db_property(capsys):
- drop_constraints()
- install_labels(SomeNotUniqueNode, quiet=False)
+ db.drop_constraints()
+ db.install_labels(SomeNotUniqueNode, quiet=False)
captured = capsys.readouterr()
assert "id" in captured.out
# make sure that the id_ constraint doesn't exist
@@ -131,8 +121,11 @@ def test_install_labels_db_property(capsys):
_drop_constraints_for_label_and_property("SomeNotUniqueNode", "id")
-@pytest.mark.skipif(db.version_is_higher_than("5.7"), reason="Not supported before 5.7")
+@mark_sync_test
def test_relationship_unique_index_not_supported():
+ if db.version_is_higher_than("5.7"):
+        pytest.skip("Relationship unique indexes are supported from 5.7; this failure test only applies to older versions")
+
class UniqueIndexRelationship(StructuredRel):
name = StringProperty(unique_index=True)
@@ -150,9 +143,14 @@ class NodeWithUniqueIndexRelationship(StructuredNode):
model=UniqueIndexRelationship,
)
+ db.install_labels(NodeWithUniqueIndexRelationship)
-@pytest.mark.skipif(not db.version_is_higher_than("5.7"), reason="Supported from 5.7")
+
+@mark_sync_test
def test_relationship_unique_index():
+ if not db.version_is_higher_than("5.7"):
+ pytest.skip("Not supported before 5.7")
+
class UniqueIndexRelationshipBis(StructuredRel):
name = StringProperty(unique_index=True)
@@ -166,7 +164,7 @@ class NodeWithUniqueIndexRelationship(StructuredNode):
model=UniqueIndexRelationshipBis,
)
- install_labels(UniqueIndexRelationshipBis)
+ db.install_labels(NodeWithUniqueIndexRelationship)
node1 = NodeWithUniqueIndexRelationship().save()
node2 = TargetNodeForUniqueIndexRelationship().save()
node3 = TargetNodeForUniqueIndexRelationship().save()
diff --git a/test/test_match_api.py b/test/sync_/test_match_api.py
similarity index 82%
rename from test/test_match_api.py
rename to test/sync_/test_match_api.py
index ee6b337e..fe63badd 100644
--- a/test/test_match_api.py
+++ b/test/sync_/test_match_api.py
@@ -1,4 +1,5 @@
from datetime import datetime
+from test._async_compat import mark_sync_test
from pytest import raises
@@ -13,8 +14,9 @@
StructuredNode,
StructuredRel,
)
+from neomodel._async_compat.util import Util
from neomodel.exceptions import MultipleNodesReturned
-from neomodel.match import NodeSet, Optional, QueryBuilder, Traversal
+from neomodel.sync_.match import NodeSet, Optional, QueryBuilder, Traversal
class SupplierRel(StructuredRel):
@@ -45,6 +47,7 @@ class Extension(StructuredNode):
extension = RelationshipTo("Extension", "extension")
+@mark_sync_test
def test_filter_exclude_via_labels():
Coffee(name="Java", price=99).save()
@@ -71,6 +74,7 @@ def test_filter_exclude_via_labels():
assert results[0].name == "Kenco"
+@mark_sync_test
def test_simple_has_via_label():
nescafe = Coffee(name="Nescafe", price=99).save()
tesco = Supplier(name="Tesco", delivery_cost=2).save()
@@ -91,6 +95,7 @@ def test_simple_has_via_label():
assert "NOT" in qb._ast.where[0]
+@mark_sync_test
def test_get():
Coffee(name="1", price=3).save()
assert Coffee.nodes.get(name="1")
@@ -104,6 +109,7 @@ def test_get():
Coffee.nodes.get(price=3)
+@mark_sync_test
def test_simple_traverse_with_filter():
nescafe = Coffee(name="Nescafe2", price=99).save()
tesco = Supplier(name="Sainsburys", delivery_cost=2).save()
@@ -111,7 +117,8 @@ def test_simple_traverse_with_filter():
qb = QueryBuilder(NodeSet(source=nescafe).suppliers.match(since__lt=datetime.now()))
- results = qb.build_ast()._execute()
+ _ast = qb.build_ast()
+ results = _ast._execute()
assert qb._ast.lookup
assert qb._ast.match
@@ -120,6 +127,7 @@ def test_simple_traverse_with_filter():
assert results[0].name == "Sainsburys"
+@mark_sync_test
def test_double_traverse():
nescafe = Coffee(name="Nescafe plus", price=99).save()
tesco = Supplier(name="Asda", delivery_cost=2).save()
@@ -135,19 +143,23 @@ def test_double_traverse():
assert results[1].name == "Nescafe plus"
+@mark_sync_test
def test_count():
Coffee(name="Nescafe Gold", price=99).save()
- count = QueryBuilder(NodeSet(source=Coffee)).build_ast()._count()
+ ast = QueryBuilder(NodeSet(source=Coffee)).build_ast()
+ count = ast._count()
assert count > 0
Coffee(name="Kawa", price=27).save()
node_set = NodeSet(source=Coffee)
node_set.skip = 1
node_set.limit = 1
- count = QueryBuilder(node_set).build_ast()._count()
+ ast = QueryBuilder(node_set).build_ast()
+ count = ast._count()
assert count == 1
+@mark_sync_test
def test_len_and_iter_and_bool():
iterations = 0
@@ -162,6 +174,7 @@ def test_len_and_iter_and_bool():
assert len(Coffee.nodes) == 0
+@mark_sync_test
def test_slice():
for c in Coffee.nodes:
c.delete()
@@ -170,13 +183,22 @@ def test_slice():
Coffee(name="Britains finest").save()
Coffee(name="Japans finest").save()
- assert len(list(Coffee.nodes.all()[1:])) == 2
- assert len(list(Coffee.nodes.all()[:1])) == 1
- assert isinstance(Coffee.nodes[1], Coffee)
- assert isinstance(Coffee.nodes[0], Coffee)
- assert len(list(Coffee.nodes.all()[1:2])) == 1
-
-
+ # Branching tests because async needs extra brackets
+ if Util.is_async_code:
+ assert len(list((Coffee.nodes)[1:])) == 2
+ assert len(list((Coffee.nodes)[:1])) == 1
+ assert isinstance((Coffee.nodes)[1], Coffee)
+ assert isinstance((Coffee.nodes)[0], Coffee)
+ assert len(list((Coffee.nodes)[1:2])) == 1
+ else:
+ assert len(list(Coffee.nodes[1:])) == 2
+ assert len(list(Coffee.nodes[:1])) == 1
+ assert isinstance(Coffee.nodes[1], Coffee)
+ assert isinstance(Coffee.nodes[0], Coffee)
+ assert len(list(Coffee.nodes[1:2])) == 1
+
+
+@mark_sync_test
def test_issue_208():
# calls to match persist across queries.
@@ -191,13 +213,16 @@ def test_issue_208():
assert len(b.suppliers.match(courier="dhl"))
+@mark_sync_test
def test_issue_589():
node1 = Extension().save()
node2 = Extension().save()
+ assert node2 not in node1.extension
node1.extension.connect(node2)
- assert node2 in node1.extension.all()
+ assert node2 in node1.extension
+@mark_sync_test
def test_contains():
expensive = Coffee(price=1000, name="Pricey").save()
asda = Coffee(name="Asda", price=1).save()
@@ -206,14 +231,21 @@ def test_contains():
assert asda not in Coffee.nodes.filter(price__gt=999)
# bad value raises
- with raises(ValueError):
- 2 in Coffee.nodes
+ with raises(ValueError, match=r"Expecting StructuredNode instance"):
+ if Util.is_async_code:
+ assert Coffee.nodes.__contains__(2)
+ else:
+ assert 2 in Coffee.nodes
# unsaved
- with raises(ValueError):
- Coffee() in Coffee.nodes
+ with raises(ValueError, match=r"Unsaved node"):
+ if Util.is_async_code:
+ assert Coffee.nodes.__contains__(Coffee())
+ else:
+ assert Coffee() in Coffee.nodes
+@mark_sync_test
def test_order_by():
for c in Coffee.nodes:
c.delete()
@@ -222,8 +254,12 @@ def test_order_by():
c2 = Coffee(name="Britains finest", price=10).save()
c3 = Coffee(name="Japans finest", price=35).save()
- assert Coffee.nodes.order_by("price").all()[0].price == 5
- assert Coffee.nodes.order_by("-price").all()[0].price == 35
+ if Util.is_async_code:
+ assert ((Coffee.nodes.order_by("price"))[0]).price == 5
+ assert ((Coffee.nodes.order_by("-price"))[0]).price == 35
+ else:
+ assert (Coffee.nodes.order_by("price")[0]).price == 5
+ assert (Coffee.nodes.order_by("-price")[0]).price == 35
ns = Coffee.nodes.order_by("-price")
qb = QueryBuilder(ns).build_ast()
@@ -248,12 +284,13 @@ def test_order_by():
l.coffees.connect(c2)
l.coffees.connect(c3)
- ordered_n = [n for n in l.coffees.order_by("name").all()]
+ ordered_n = [n for n in l.coffees.order_by("name")]
assert ordered_n[0] == c2
assert ordered_n[1] == c1
assert ordered_n[2] == c3
+@mark_sync_test
def test_extra_filters():
for c in Coffee.nodes:
c.delete()
@@ -263,22 +300,22 @@ def test_extra_filters():
c3 = Coffee(name="Japans finest", price=35, id_=3).save()
c4 = Coffee(name="US extra-fine", price=None, id_=4).save()
- coffees_5_10 = Coffee.nodes.filter(price__in=[10, 5]).all()
+ coffees_5_10 = Coffee.nodes.filter(price__in=[10, 5])
assert len(coffees_5_10) == 2, "unexpected number of results"
assert c1 in coffees_5_10, "doesnt contain 5 price coffee"
assert c2 in coffees_5_10, "doesnt contain 10 price coffee"
- finest_coffees = Coffee.nodes.filter(name__iendswith=" Finest").all()
+ finest_coffees = Coffee.nodes.filter(name__iendswith=" Finest")
assert len(finest_coffees) == 3, "unexpected number of results"
assert c1 in finest_coffees, "doesnt contain 1st finest coffee"
assert c2 in finest_coffees, "doesnt contain 2nd finest coffee"
assert c3 in finest_coffees, "doesnt contain 3rd finest coffee"
- unpriced_coffees = Coffee.nodes.filter(price__isnull=True).all()
+ unpriced_coffees = Coffee.nodes.filter(price__isnull=True)
assert len(unpriced_coffees) == 1, "unexpected number of results"
assert c4 in unpriced_coffees, "doesnt contain unpriced coffee"
- coffees_with_id_gte_3 = Coffee.nodes.filter(id___gte=3).all()
+ coffees_with_id_gte_3 = Coffee.nodes.filter(id___gte=3)
assert len(coffees_with_id_gte_3) == 2, "unexpected number of results"
assert c3 in coffees_with_id_gte_3
assert c4 in coffees_with_id_gte_3
@@ -317,6 +354,7 @@ def test_traversal_definition_keys_are_valid():
)
+@mark_sync_test
def test_empty_filters():
"""Test this case:
```
@@ -353,6 +391,7 @@ def test_empty_filters():
), "doesnt contain c1 in ``filter_empty_filter``"
+@mark_sync_test
def test_q_filters():
# Test where no children and self.connector != conn ?
for c in Coffee.nodes:
@@ -418,7 +457,7 @@ def test_q_filters():
combined_coffees = Coffee.nodes.filter(
Q(price=35), Q(name="Latte") | Q(name="Cappuccino")
- )
+ ).all()
assert len(combined_coffees) == 2
assert c5 in combined_coffees
assert c6 in combined_coffees
@@ -443,6 +482,7 @@ def test_qbase():
assert len(test_hash) == 1
+@mark_sync_test
def test_traversal_filter_left_hand_statement():
nescafe = Coffee(name="Nescafe2", price=99).save()
nescafe_gold = Coffee(name="Nescafe gold", price=11).save()
@@ -462,6 +502,7 @@ def test_traversal_filter_left_hand_statement():
assert lidl in lidl_supplier
+@mark_sync_test
def test_fetch_relations():
arabica = Species(name="Arabica").save()
robusta = Species(name="Robusta").save()
@@ -491,14 +532,27 @@ def test_fetch_relations():
)
assert result[0][0] is None
- # len() should only consider Suppliers
- count = len(
- Supplier.nodes.filter(name="Sainsburys")
- .fetch_relations("coffees__species")
- .all()
- )
- assert count == 1
+ if Util.is_async_code:
+ count = (
+ Supplier.nodes.filter(name="Sainsburys")
+ .fetch_relations("coffees__species")
+ .__len__()
+ )
+ assert count == 1
- assert tesco in Supplier.nodes.fetch_relations("coffees__species").filter(
- name="Sainsburys"
- )
+ assert (
+ Supplier.nodes.fetch_relations("coffees__species")
+ .filter(name="Sainsburys")
+ .__contains__(tesco)
+ )
+ else:
+ count = len(
+ Supplier.nodes.filter(name="Sainsburys")
+ .fetch_relations("coffees__species")
+ .all()
+ )
+ assert count == 1
+
+ assert tesco in Supplier.nodes.fetch_relations("coffees__species").filter(
+ name="Sainsburys"
+ )
diff --git a/test/test_migration_neo4j_5.py b/test/sync_/test_migration_neo4j_5.py
similarity index 84%
rename from test/test_migration_neo4j_5.py
rename to test/sync_/test_migration_neo4j_5.py
index a4730329..83e090cb 100644
--- a/test/test_migration_neo4j_5.py
+++ b/test/sync_/test_migration_neo4j_5.py
@@ -1,3 +1,5 @@
+from test._async_compat import mark_sync_test
+
import pytest
from neomodel import (
@@ -6,8 +8,8 @@
StringProperty,
StructuredNode,
StructuredRel,
+ db,
)
-from neomodel.core import db
class Album(StructuredNode):
@@ -23,25 +25,27 @@ class Band(StructuredNode):
released = RelationshipTo(Album, relation_type="RELEASED", model=Released)
+@mark_sync_test
def test_read_elements_id():
the_hives = Band(name="The Hives").save()
lex_hives = Album(name="Lex Hives").save()
released_rel = the_hives.released.connect(lex_hives)
# Validate element_id properties
- assert lex_hives.element_id == the_hives.released.single().element_id
+ assert lex_hives.element_id == (the_hives.released.single()).element_id
assert released_rel._start_node_element_id == the_hives.element_id
assert released_rel._end_node_element_id == lex_hives.element_id
# Validate id properties
# Behaviour is dependent on Neo4j version
- if db.database_version.startswith("4"):
+ db_version = db.database_version
+ if db_version.startswith("4"):
# Nodes' ids
assert lex_hives.id == int(lex_hives.element_id)
- assert lex_hives.id == the_hives.released.single().id
+ assert lex_hives.id == (the_hives.released.single()).id
# Relationships' ids
- assert isinstance(released_rel.element_id, int)
- assert released_rel.element_id == released_rel.id
+ assert isinstance(released_rel.element_id, str)
+ assert int(released_rel.element_id) == released_rel.id
assert released_rel._start_node_id == int(the_hives.element_id)
assert released_rel._end_node_id == int(lex_hives.element_id)
else:
diff --git a/test/test_models.py b/test/sync_/test_models.py
similarity index 90%
rename from test/test_models.py
rename to test/sync_/test_models.py
index 827c705a..3698b612 100644
--- a/test/test_models.py
+++ b/test/sync_/test_models.py
@@ -1,6 +1,7 @@
from __future__ import print_function
from datetime import datetime
+from test._async_compat import mark_sync_test
from pytest import raises
@@ -10,9 +11,8 @@
StringProperty,
StructuredNode,
StructuredRel,
- install_labels,
+ db,
)
-from neomodel.core import db
from neomodel.exceptions import RequiredProperty, UniqueProperty
@@ -33,6 +33,7 @@ class NodeWithoutProperty(StructuredNode):
pass
+@mark_sync_test
def test_issue_233():
class BaseIssue233(StructuredNode):
__abstract_node__ = True
@@ -52,22 +53,19 @@ def test_issue_72():
assert user.age is None
+@mark_sync_test
def test_required():
- try:
+ with raises(RequiredProperty):
User(age=3).save()
- except RequiredProperty:
- assert True
- else:
- assert False
def test_repr_and_str():
u = User(email="robin@test.com", age=3)
- print(repr(u))
- print(str(u))
- assert True
+    assert repr(u) == "<User: {'email': 'robin@test.com', 'age': 3}>"
+ assert str(u) == "{'email': 'robin@test.com', 'age': 3}"
+@mark_sync_test
def test_get_and_get_or_none():
u = User(email="robin@test.com", age=3)
assert u.save()
@@ -82,6 +80,7 @@ def test_get_and_get_or_none():
assert n is None
+@mark_sync_test
def test_first_and_first_or_none():
u = User(email="matt@test.com", age=24)
assert u.save()
@@ -103,9 +102,10 @@ def test_bare_init_without_save():
If a node model is initialised without being saved, accessing its `element_id` should
return None.
"""
- assert(User().element_id is None)
+ assert User().element_id is None
+@mark_sync_test
def test_save_to_model():
u = User(email="jim@test.com", age=3)
assert u.save()
@@ -114,24 +114,28 @@ def test_save_to_model():
assert u.age == 3
+@mark_sync_test
def test_save_node_without_properties():
n = NodeWithoutProperty()
assert n.save()
assert n.element_id is not None
+@mark_sync_test
def test_unique():
- install_labels(User)
+ db.install_labels(User)
User(email="jim1@test.com", age=3).save()
with raises(UniqueProperty):
User(email="jim1@test.com", age=3).save()
+@mark_sync_test
def test_update_unique():
u = User(email="jimxx@test.com", age=3).save()
u.save() # this shouldn't fail
+@mark_sync_test
def test_update():
user = User(email="jim2@test.com", age=3).save()
assert user
@@ -142,6 +146,7 @@ def test_update():
assert jim.email == "jim2000@test.com"
+@mark_sync_test
def test_save_through_magic_property():
user = User(email_alias="blah@test.com", age=8).save()
assert user.email_alias == "blah@test.com"
@@ -163,19 +168,21 @@ class Customer2(StructuredNode):
age = IntegerProperty(index=True)
+@mark_sync_test
def test_not_updated_on_unique_error():
- install_labels(Customer2)
+ db.install_labels(Customer2)
Customer2(email="jim@bob.com", age=7).save()
test = Customer2(email="jim1@bob.com", age=2).save()
test.email = "jim@bob.com"
with raises(UniqueProperty):
test.save()
- customers = Customer2.nodes.all()
+ customers = Customer2.nodes
assert customers[0].email != customers[1].email
- assert Customer2.nodes.get(email="jim@bob.com").age == 7
- assert Customer2.nodes.get(email="jim1@bob.com").age == 2
+ assert (Customer2.nodes.get(email="jim@bob.com")).age == 7
+ assert (Customer2.nodes.get(email="jim1@bob.com")).age == 2
+@mark_sync_test
def test_label_not_inherited():
class Customer3(Customer2):
address = StringProperty()
@@ -191,6 +198,7 @@ class Customer3(Customer2):
assert "Customer3" in c.labels()
+@mark_sync_test
def test_refresh():
c = Customer2(email="my@email.com", age=16).save()
c.my_custom_prop = "value"
@@ -210,7 +218,8 @@ def test_refresh():
assert c.age == 20
- if db.database_version.startswith("4"):
+ _db_version = db.database_version
+ if _db_version.startswith("4"):
c = Customer2.inflate(999)
else:
c = Customer2.inflate("4:xxxxxx:999")
@@ -218,6 +227,7 @@ def test_refresh():
c.refresh()
+@mark_sync_test
def test_setting_value_to_none():
c = Customer2(email="alice@bob.com", age=42).save()
assert c.age is not None
@@ -229,6 +239,7 @@ def test_setting_value_to_none():
assert copy.age is None
+@mark_sync_test
def test_inheritance():
class User(StructuredNode):
__abstract_node__ = True
@@ -248,9 +259,10 @@ def credit_account(self, amount):
assert jim.balance == 350
assert len(jim.inherited_labels()) == 1
assert len(jim.labels()) == 1
- assert jim.labels()[0] == "Shopper"
+ assert (jim.labels())[0] == "Shopper"
+@mark_sync_test
def test_inherited_optional_labels():
class BaseOptional(StructuredNode):
__optional_labels__ = ["Alive"]
@@ -275,6 +287,7 @@ def credit_account(self, amount):
assert set(henry.inherited_optional_labels()) == {"Alive", "RewardsMember"}
+@mark_sync_test
def test_mixins():
class UserMixin:
name = StringProperty(unique_index=True)
@@ -297,9 +310,10 @@ class Shopper2(StructuredNode, UserMixin, CreditMixin):
assert jim.balance == 350
assert len(jim.inherited_labels()) == 1
assert len(jim.labels()) == 1
- assert jim.labels()[0] == "Shopper2"
+ assert (jim.labels())[0] == "Shopper2"
+@mark_sync_test
def test_date_property():
class DateTest(StructuredNode):
birthdate = DateProperty()
diff --git a/test/test_multiprocessing.py b/test/sync_/test_multiprocessing.py
similarity index 78%
rename from test/test_multiprocessing.py
rename to test/sync_/test_multiprocessing.py
index fb00675d..2d9167f9 100644
--- a/test/test_multiprocessing.py
+++ b/test/sync_/test_multiprocessing.py
@@ -1,4 +1,5 @@
from multiprocessing.pool import ThreadPool as Pool
+from test._async_compat import mark_sync_test
from neomodel import StringProperty, StructuredNode, db
@@ -13,9 +14,11 @@ def thing_create(name):
return thing.name, name
+@mark_sync_test
def test_concurrency():
with Pool(5) as p:
results = p.map(thing_create, range(50))
- for returned, sent in results:
+ for to_unpack in results:
+ returned, sent = to_unpack
assert returned == sent
db.close_connection()
diff --git a/test/test_paths.py b/test/sync_/test_paths.py
similarity index 67%
rename from test/test_paths.py
rename to test/sync_/test_paths.py
index 8c6fef28..8e0ccf90 100644
--- a/test/test_paths.py
+++ b/test/sync_/test_paths.py
@@ -1,41 +1,56 @@
-from neomodel import (StringProperty, StructuredNode, UniqueIdProperty,
- db, RelationshipTo, IntegerProperty, NeomodelPath, StructuredRel)
+from test._async_compat import mark_sync_test
+
+from neomodel import (
+ IntegerProperty,
+ NeomodelPath,
+ RelationshipTo,
+ StringProperty,
+ StructuredNode,
+ StructuredRel,
+ UniqueIdProperty,
+ db,
+)
+
class PersonLivesInCity(StructuredRel):
"""
Relationship with data that will be instantiated as "stand-alone"
"""
+
some_num = IntegerProperty(index=True, default=12)
+
class CountryOfOrigin(StructuredNode):
code = StringProperty(unique_index=True, required=True)
+
class CityOfResidence(StructuredNode):
name = StringProperty(required=True)
- country = RelationshipTo(CountryOfOrigin, 'FROM_COUNTRY')
+ country = RelationshipTo(CountryOfOrigin, "FROM_COUNTRY")
+
class PersonOfInterest(StructuredNode):
uid = UniqueIdProperty()
name = StringProperty(unique_index=True)
age = IntegerProperty(index=True, default=0)
- country = RelationshipTo(CountryOfOrigin, 'IS_FROM')
- city = RelationshipTo(CityOfResidence, 'LIVES_IN', model=PersonLivesInCity)
+ country = RelationshipTo(CountryOfOrigin, "IS_FROM")
+ city = RelationshipTo(CityOfResidence, "LIVES_IN", model=PersonLivesInCity)
+@mark_sync_test
def test_path_instantiation():
"""
- Neo4j driver paths should be instantiated as neomodel paths, with all of
- their nodes and relationships resolved to their Python objects wherever
+ Neo4j driver paths should be instantiated as neomodel paths, with all of
+ their nodes and relationships resolved to their Python objects wherever
such a mapping is available.
"""
- c1=CountryOfOrigin(code="GR").save()
- c2=CountryOfOrigin(code="FR").save()
-
- ct1 = CityOfResidence(name="Athens", country = c1).save()
- ct2 = CityOfResidence(name="Paris", country = c2).save()
+ c1 = CountryOfOrigin(code="GR").save()
+ c2 = CountryOfOrigin(code="FR").save()
+ ct1 = CityOfResidence(name="Athens", country=c1).save()
+ ct2 = CityOfResidence(name="Paris", country=c2).save()
p1 = PersonOfInterest(name="Bill", age=22).save()
p1.country.connect(c1)
@@ -54,7 +69,10 @@ def test_path_instantiation():
p4.city.connect(ct2)
# Retrieve a single path
- q = db.cypher_query("MATCH p=(:CityOfResidence)<-[:LIVES_IN]-(:PersonOfInterest)-[:IS_FROM]->(:CountryOfOrigin) RETURN p LIMIT 1", resolve_objects = True)
+ q = db.cypher_query(
+ "MATCH p=(:CityOfResidence)<-[:LIVES_IN]-(:PersonOfInterest)-[:IS_FROM]->(:CountryOfOrigin) RETURN p LIMIT 1",
+ resolve_objects=True,
+ )
path_object = q[0][0][0]
path_nodes = path_object.nodes
@@ -76,4 +94,3 @@ def test_path_instantiation():
p2.delete()
p3.delete()
p4.delete()
-
diff --git a/test/test_properties.py b/test/sync_/test_properties.py
similarity index 94%
rename from test/test_properties.py
rename to test/sync_/test_properties.py
index 454ada26..28866738 100644
--- a/test/test_properties.py
+++ b/test/sync_/test_properties.py
@@ -1,9 +1,10 @@
from datetime import date, datetime
+from test._async_compat import mark_sync_test
from pytest import mark, raises
from pytz import timezone
-from neomodel import StructuredNode, config, db
+from neomodel import StructuredNode, db
from neomodel.exceptions import (
DeflateError,
InflateError,
@@ -25,8 +26,6 @@
)
from neomodel.util import _get_node_properties
-config.AUTO_INSTALL_LABELS = True
-
class FooBar:
pass
@@ -60,6 +59,7 @@ def test_string_property_exceeds_max_length():
), "StringProperty max_length test passed but values do not match."
+@mark_sync_test
def test_string_property_w_choice():
class TestChoices(StructuredNode):
SEXES = {"F": "Female", "M": "Male", "O": "Other"}
@@ -84,7 +84,6 @@ def test_deflate_inflate():
try:
prop.inflate("six")
except InflateError as e:
- assert True
assert "inflate property" in str(e)
else:
assert False, "DeflateError not raised."
@@ -185,6 +184,7 @@ def test_json():
assert prop.inflate('{"test": [1, 2, 3]}') == value
+@mark_sync_test
def test_default_value():
class DefaultTestValue(StructuredNode):
name_xx = StringProperty(default="jim", index=True)
@@ -194,6 +194,7 @@ class DefaultTestValue(StructuredNode):
a.save()
+@mark_sync_test
def test_default_value_callable():
def uid_generator():
return "xx"
@@ -205,6 +206,7 @@ class DefaultTestValueTwo(StructuredNode):
assert a.uid == "xx"
+@mark_sync_test
def test_default_value_callable_type():
# check our object gets converted to str without serializing and reload
def factory():
@@ -225,6 +227,7 @@ class DefaultTestValueThree(StructuredNode):
assert x.uid == "123"
+@mark_sync_test
def test_independent_property_name():
class TestDBNamePropertyNode(StructuredNode):
name_ = StringProperty(db_property="name")
@@ -242,12 +245,13 @@ class TestDBNamePropertyNode(StructuredNode):
assert not "name_" in node_properties
assert not hasattr(x, "name")
assert hasattr(x, "name_")
- assert TestDBNamePropertyNode.nodes.filter(name_="jim").all()[0].name_ == x.name_
- assert TestDBNamePropertyNode.nodes.get(name_="jim").name_ == x.name_
+ assert (TestDBNamePropertyNode.nodes.filter(name_="jim").all())[0].name_ == x.name_
+ assert (TestDBNamePropertyNode.nodes.get(name_="jim")).name_ == x.name_
x.delete()
+@mark_sync_test
def test_independent_property_name_get_or_create():
class TestNode(StructuredNode):
uid = UniqueIdProperty()
@@ -256,10 +260,10 @@ class TestNode(StructuredNode):
# create the node
TestNode.get_or_create({"uid": 123, "name_": "jim"})
# test that the node is retrieved correctly
- x = TestNode.get_or_create({"uid": 123, "name_": "jim"})[0]
+ x = (TestNode.get_or_create({"uid": 123, "name_": "jim"}))[0]
# check database property name on low level
- results, meta = db.cypher_query("MATCH (n:TestNode) RETURN n")
+ results, _ = db.cypher_query("MATCH (n:TestNode) RETURN n")
node_properties = _get_node_properties(results[0][0])
assert node_properties["name"] == "jim"
assert "name_" not in node_properties
@@ -331,6 +335,7 @@ def test_email_property():
prop.deflate("foo@example")
+@mark_sync_test
def test_uid_property():
prop = UniqueIdProperty()
prop.name = "uid"
@@ -351,6 +356,7 @@ class ArrayProps(StructuredNode):
typed_arr = ArrayProperty(IntegerProperty())
+@mark_sync_test
def test_array_properties():
# untyped
ap1 = ArrayProps(uid="1", untyped_arr=["Tim", "Bob"]).save()
@@ -377,6 +383,7 @@ def test_illegal_array_base_prop_raises():
ArrayProperty(StringProperty(index=True))
+@mark_sync_test
def test_indexed_array():
class IndexArray(StructuredNode):
ai = ArrayProperty(unique_index=True)
@@ -386,6 +393,7 @@ class IndexArray(StructuredNode):
assert b.element_id == c.element_id
+@mark_sync_test
def test_unique_index_prop_not_required():
class ConstrainedTestNode(StructuredNode):
required_property = StringProperty(required=True)
@@ -414,10 +422,12 @@ class ConstrainedTestNode(StructuredNode):
x.delete()
+@mark_sync_test
def test_unique_index_prop_enforced():
class UniqueNullableNameNode(StructuredNode):
name = StringProperty(unique_index=True)
+ db.install_labels(UniqueNullableNameNode)
# Nameless
x = UniqueNullableNameNode()
x.save()
@@ -432,7 +442,7 @@ class UniqueNullableNameNode(StructuredNode):
a.save()
# Check nodes are in database
- results, meta = db.cypher_query("MATCH (n:UniqueNullableNameNode) RETURN n")
+ results, _ = db.cypher_query("MATCH (n:UniqueNullableNameNode) RETURN n")
assert len(results) == 3
# Delete nodes afterwards
diff --git a/test/test_relationship_models.py b/test/sync_/test_relationship_models.py
similarity index 90%
rename from test/test_relationship_models.py
rename to test/sync_/test_relationship_models.py
index 82760c73..837f53d6 100644
--- a/test/test_relationship_models.py
+++ b/test/sync_/test_relationship_models.py
@@ -1,4 +1,5 @@
from datetime import datetime
+from test._async_compat import mark_sync_test
import pytz
from pytest import raises
@@ -12,6 +13,7 @@
StructuredNode,
StructuredRel,
)
+from neomodel._async_compat.util import Util
HOOKS_CALLED = {"pre_save": 0, "post_save": 0}
@@ -41,6 +43,7 @@ class Stoat(StructuredNode):
hates = RelationshipTo("Badger", "HATES", model=HatesRel)
+@mark_sync_test
def test_either_connect_with_rel_model():
paul = Badger(name="Paul").save()
tom = Badger(name="Tom").save()
@@ -63,6 +66,7 @@ def test_either_connect_with_rel_model():
assert tom.name == "Paul"
+@mark_sync_test
def test_direction_connect_with_rel_model():
paul = Badger(name="Paul the badger").save()
ian = Stoat(name="Ian the stoat").save()
@@ -103,6 +107,7 @@ def test_direction_connect_with_rel_model():
)
+@mark_sync_test
def test_traversal_where_clause():
phill = Badger(name="Phill the badger").save()
tim = Badger(name="Tim the badger").save()
@@ -113,9 +118,10 @@ def test_traversal_where_clause():
rel2 = tim.friend.connect(phill)
assert rel2.since > now
friends = tim.friend.match(since__gt=now)
- assert len(friends) == 1
+ assert len(friends.all()) == 1
+@mark_sync_test
def test_multiple_rels_exist_issue_223():
# check a badger can dislike a stoat for multiple reasons
phill = Badger(name="Phill").save()
@@ -125,11 +131,16 @@ def test_multiple_rels_exist_issue_223():
rel_b = phill.hates.connect(ian, {"reason": "b"})
assert rel_a.element_id != rel_b.element_id
- ian_a = phill.hates.match(reason="a")[0]
- ian_b = phill.hates.match(reason="b")[0]
+ if Util.is_async_code:
+ ian_a = (phill.hates.match(reason="a"))[0]
+ ian_b = (phill.hates.match(reason="b"))[0]
+ else:
+ ian_a = phill.hates.match(reason="a")[0]
+ ian_b = phill.hates.match(reason="b")[0]
assert ian_a.element_id == ian_b.element_id
+@mark_sync_test
def test_retrieve_all_rels():
tom = Badger(name="tom").save()
ian = Stoat(name="ian").save()
@@ -143,6 +154,7 @@ def test_retrieve_all_rels():
assert rels[1].element_id in [rel_a.element_id, rel_b.element_id]
+@mark_sync_test
def test_save_hook_on_rel_model():
HOOKS_CALLED["pre_save"] = 0
HOOKS_CALLED["post_save"] = 0
diff --git a/test/test_relationships.py b/test/sync_/test_relationships.py
similarity index 91%
rename from test/test_relationships.py
rename to test/sync_/test_relationships.py
index 92d75064..39057a39 100644
--- a/test/test_relationships.py
+++ b/test/sync_/test_relationships.py
@@ -1,3 +1,5 @@
+from test._async_compat import mark_sync_test
+
from pytest import raises
from neomodel import (
@@ -10,8 +12,8 @@
StringProperty,
StructuredNode,
StructuredRel,
+ db,
)
-from neomodel.core import db
class PersonWithRels(StructuredNode):
@@ -41,6 +43,7 @@ def special_power(self):
return "I have powers"
+@mark_sync_test
def test_actions_on_deleted_node():
u = PersonWithRels(name="Jim2", age=3).save()
u.delete()
@@ -54,6 +57,7 @@ def test_actions_on_deleted_node():
u.save()
+@mark_sync_test
def test_bidirectional_relationships():
u = PersonWithRels(name="Jim", age=3).save()
assert u
@@ -61,26 +65,27 @@ def test_bidirectional_relationships():
de = Country(code="DE").save()
assert de
- assert not u.is_from
+ assert not u.is_from.all()
assert u.is_from.__class__.__name__ == "ZeroOrMore"
u.is_from.connect(de)
- assert len(u.is_from) == 1
+ assert len(u.is_from.all()) == 1
assert u.is_from.is_connected(de)
- b = u.is_from.all()[0]
+ b = (u.is_from.all())[0]
assert b.__class__.__name__ == "Country"
assert b.code == "DE"
- s = b.inhabitant.all()[0]
+ s = (b.inhabitant.all())[0]
assert s.name == "Jim"
u.is_from.disconnect(b)
assert not u.is_from.is_connected(b)
+@mark_sync_test
def test_either_direction_connect():
rey = PersonWithRels(name="Rey", age=3).save()
sakis = PersonWithRels(name="Sakis", age=3).save()
@@ -94,7 +99,7 @@ def test_either_direction_connect():
f"""MATCH (us), (them)
WHERE {db.get_id_method()}(us)=$self and {db.get_id_method()}(them)=$them
MATCH (us)-[r:KNOWS]-(them) RETURN COUNT(r)""",
- {"them": rey.element_id},
+ {"them": db.parse_element_id(rey.element_id)},
)
assert int(result[0][0]) == 1
@@ -105,6 +110,7 @@ def test_either_direction_connect():
assert isinstance(rels[0], StructuredRel)
+@mark_sync_test
def test_search_and_filter_and_exclude():
fred = PersonWithRels(name="Fred", age=13).save()
zz = Country(code="ZZ").save()
@@ -129,6 +135,7 @@ def test_search_and_filter_and_exclude():
assert len(result) == 3
+@mark_sync_test
def test_custom_methods():
u = PersonWithRels(name="Joe90", age=13).save()
assert u.special_power() == "I have no powers"
@@ -137,6 +144,7 @@ def test_custom_methods():
assert u.special_name == "Joe91"
+@mark_sync_test
def test_valid_reconnection():
p = PersonWithRels(name="ElPresidente", age=93).save()
assert p
@@ -159,6 +167,7 @@ def test_valid_reconnection():
assert c.president.is_connected(pp)
+@mark_sync_test
def test_valid_replace():
brady = PersonWithRels(name="Tom Brady", age=40).save()
assert brady
@@ -174,17 +183,18 @@ def test_valid_replace():
brady.knows.connect(gronk)
brady.knows.connect(colbert)
- assert len(brady.knows) == 2
+ assert len(brady.knows.all()) == 2
assert brady.knows.is_connected(gronk)
assert brady.knows.is_connected(colbert)
brady.knows.replace(hanks)
- assert len(brady.knows) == 1
+ assert len(brady.knows.all()) == 1
assert brady.knows.is_connected(hanks)
assert not brady.knows.is_connected(gronk)
assert not brady.knows.is_connected(colbert)
+@mark_sync_test
def test_props_relationship():
u = PersonWithRels(name="Mar", age=20).save()
assert u
diff --git a/test/test_relative_relationships.py b/test/sync_/test_relative_relationships.py
similarity index 83%
rename from test/test_relative_relationships.py
rename to test/sync_/test_relative_relationships.py
index db78d038..a01e28f9 100644
--- a/test/test_relative_relationships.py
+++ b/test/sync_/test_relative_relationships.py
@@ -1,6 +1,7 @@
-from neomodel import RelationshipTo, StringProperty, StructuredNode
+from test._async_compat import mark_sync_test
+from test.sync_.test_relationships import Country
-from .test_relationships import Country
+from neomodel import RelationshipTo, StringProperty, StructuredNode
class Cat(StructuredNode):
@@ -9,6 +10,7 @@ class Cat(StructuredNode):
is_from = RelationshipTo(".test_relationships.Country", "IS_FROM")
+@mark_sync_test
def test_relative_relationship():
a = Cat(name="snufkin").save()
assert a
diff --git a/test/test_transactions.py b/test/sync_/test_transactions.py
similarity index 82%
rename from test/test_transactions.py
rename to test/sync_/test_transactions.py
index 0481e2a7..834b538e 100644
--- a/test/test_transactions.py
+++ b/test/sync_/test_transactions.py
@@ -1,22 +1,18 @@
+from test._async_compat import mark_sync_test
+
import pytest
from neo4j.api import Bookmarks
from neo4j.exceptions import ClientError, TransactionError
from pytest import raises
-from neomodel import (
- StringProperty,
- StructuredNode,
- UniqueProperty,
- config,
- db,
- install_labels,
-)
+from neomodel import StringProperty, StructuredNode, UniqueProperty, db
class APerson(StructuredNode):
name = StringProperty(unique_index=True)
+@mark_sync_test
def test_rollback_and_commit_transaction():
for p in APerson.nodes:
p.delete()
@@ -42,14 +38,14 @@ def in_a_tx(*names):
APerson(name=n).save()
+@mark_sync_test
def test_transaction_decorator():
- install_labels(APerson)
+ db.install_labels(APerson)
for p in APerson.nodes:
p.delete()
# should work
in_a_tx("Roger")
- assert True
# should bail but raise correct error
with raises(UniqueProperty):
@@ -58,6 +54,7 @@ def test_transaction_decorator():
assert "Jim" not in [p.name for p in APerson.nodes]
+@mark_sync_test
def test_transaction_as_a_context():
with db.transaction:
APerson(name="Tim").save()
@@ -69,6 +66,7 @@ def test_transaction_as_a_context():
APerson(name="Tim").save()
+@mark_sync_test
def test_query_inside_transaction():
for p in APerson.nodes:
p.delete()
@@ -80,11 +78,12 @@ def test_query_inside_transaction():
assert len([p.name for p in APerson.nodes]) == 2
+@mark_sync_test
def test_read_transaction():
APerson(name="Johnny").save()
with db.read_transaction:
- people = APerson.nodes.all()
+ people = APerson.nodes
assert people
with raises(TransactionError):
@@ -94,6 +93,7 @@ def test_read_transaction():
assert e.value.code == "Neo.ClientError.Statement.AccessMode"
+@mark_sync_test
def test_write_transaction():
with db.write_transaction:
APerson(name="Amelia").save()
@@ -102,6 +102,7 @@ def test_write_transaction():
assert amelia
+@mark_sync_test
def double_transaction():
db.begin()
with raises(SystemError, match=r"Transaction in progress"):
@@ -111,28 +112,30 @@ def double_transaction():
@db.transaction.with_bookmark
-def in_a_tx(*names):
+def in_a_tx_with_bookmark(*names):
for n in names:
APerson(name=n).save()
-def test_bookmark_transaction_decorator(skip_neo4j_before_330):
+@mark_sync_test
+def test_bookmark_transaction_decorator():
for p in APerson.nodes:
p.delete()
# should work
- result, bookmarks = in_a_tx("Ruth", bookmarks=None)
+ result, bookmarks = in_a_tx_with_bookmark("Ruth", bookmarks=None)
assert result is None
assert isinstance(bookmarks, Bookmarks)
# should bail but raise correct error
with raises(UniqueProperty):
- in_a_tx("Jane", "Ruth")
+ in_a_tx_with_bookmark("Jane", "Ruth")
assert "Jane" not in [p.name for p in APerson.nodes]
-def test_bookmark_transaction_as_a_context(skip_neo4j_before_330):
+@mark_sync_test
+def test_bookmark_transaction_as_a_context():
with db.transaction as transaction:
APerson(name="Tanya").save()
assert isinstance(transaction.last_bookmark, Bookmarks)
@@ -158,12 +161,13 @@ def begin_spy(*args, **kwargs):
return spy_calls
-def test_bookmark_passed_in_to_context(skip_neo4j_before_330, spy_on_db_begin):
+@mark_sync_test
+def test_bookmark_passed_in_to_context(spy_on_db_begin):
transaction = db.transaction
with transaction:
pass
- assert spy_on_db_begin[-1] == ((), {"access_mode": None, "bookmarks": None})
+ assert (spy_on_db_begin)[-1] == ((), {"access_mode": None, "bookmarks": None})
last_bookmark = transaction.last_bookmark
transaction.bookmarks = last_bookmark
@@ -175,7 +179,8 @@ def test_bookmark_passed_in_to_context(skip_neo4j_before_330, spy_on_db_begin):
)
-def test_query_inside_bookmark_transaction(skip_neo4j_before_330):
+@mark_sync_test
+def test_query_inside_bookmark_transaction():
for p in APerson.nodes:
p.delete()
diff --git a/test/test_scripts.py b/test/test_scripts.py
index e30fd213..5583af47 100644
--- a/test/test_scripts.py
+++ b/test/test_scripts.py
@@ -8,10 +8,8 @@
StructuredNode,
StructuredRel,
config,
- db,
- install_labels,
- util,
)
+from neomodel.sync_.core import db
class ScriptsTestRel(StructuredRel):
@@ -108,9 +106,9 @@ def test_neomodel_inspect_database(script_flavour):
assert "usage: neomodel_inspect_database" in result.stdout
assert result.returncode == 0
- util.clear_neo4j_database(db)
- install_labels(ScriptsTestNode)
- install_labels(ScriptsTestRel)
+ db.clear_neo4j_database()
+ db.install_labels(ScriptsTestNode)
+ db.install_labels(ScriptsTestRel)
# Create a few nodes and a rel, with indexes and constraints
node1 = ScriptsTestNode(personal_id="1", name="test").save()